Utilities


class BindModule(module, params)

Like Flax’s module.bind(params), but composable with JAX transforms.
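The following is a plain-Python sketch of the binding idea. It assumes the wrapper stores params once and forwards every call, and every method access, to a Flax-style module.apply(params, ..., method=name). ToyLinear and BindModuleSketch are hypothetical names. The sketch is not Coix's actual implementation, and it does not need Flax to run.

```python
class ToyLinear:
    """Hypothetical stand-in for a Flax module: apply(params, *args, method=...)."""

    def __call__(self, params, x):
        return params["w"] * x + params["b"]

    def scale(self, params, x):
        return params["w"] * x

    def apply(self, params, *args, method="__call__", **kwargs):
        # Dispatch to the named method, passing params explicitly.
        return getattr(self, method)(params, *args, **kwargs)


class BindModuleSketch:
    """Hypothetical wrapper: bind params once, route every call through apply."""

    def __init__(self, module, params):
        self.module = module
        self.params = params

    def __call__(self, *args, **kwargs):
        return self.module.apply(self.params, *args, **kwargs)

    def __getattr__(self, name):
        # Only reached for names not found normally, e.g. module methods.
        def bound(*args, **kwargs):
            return self.module.apply(self.params, *args, method=name, **kwargs)
        return bound


bound = BindModuleSketch(ToyLinear(), {"w": 2.0, "b": 1.0})
print(bound(3.0), bound.scale(3.0))  # 7.0 6.0
```

Every call goes through apply with explicit params, so each call stays a pure function of the parameters. Pure functions are what JAX transforms such as jit and grad operate on.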

get_batch_ndims(xs)

Gets the number of leading dimensions that have the same size across all elements in xs.
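The next snippet is a minimal sketch of the counting rule. It works on plain shape tuples instead of arrays in a pytree, and batch_ndims_from_shapes is a hypothetical helper, not part of the library. The sketch counts leading axes, left to right, until the sizes first disagree or one shape runs out of axes.

```python
def batch_ndims_from_shapes(shapes):
    """Counts leading axes whose sizes agree across all shapes (hypothetical)."""
    if not shapes:
        return 0
    ndims = 0
    # zip stops at the shortest shape, so a shape with fewer axes caps the count.
    for sizes in zip(*shapes):
        if all(size == sizes[0] for size in sizes):
            ndims += 1
        else:
            break
    return ndims


print(batch_ndims_from_shapes([(3, 4, 5), (3, 4), (3, 4, 2)]))  # 2
```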

get_log_weight(trace, batch_ndims)

Computes the log weight of the trace, keeping its batch dimensions.
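The sketch below shows what "keeping the batch dimensions" means in practice. It makes several assumptions:

- The trace is a hypothetical dict mapping site names to {"log_prob": nested list}.
- Every site contributes.
- All sites share the same full batch shape, so no broadcasting is needed.

Each site's log-probability is summed over every axis after the first batch_ndims, and the per-site results are added elementwise. Which sites the real function includes is not stated here.

```python
def _total(x):
    """Sums every number in an arbitrarily nested list (or returns a scalar)."""
    if isinstance(x, (int, float)):
        return float(x)
    return float(sum(_total(v) for v in x))


def _reduce_event_dims(x, batch_ndims):
    """Keeps the first batch_ndims nesting levels and sums everything below."""
    if batch_ndims == 0:
        return _total(x)
    return [_reduce_event_dims(v, batch_ndims - 1) for v in x]


def _add(a, b):
    """Elementwise sum of two equally shaped nested lists (or floats)."""
    if isinstance(a, float):
        return a + b
    return [_add(x, y) for x, y in zip(a, b)]


def log_weight_sketch(trace, batch_ndims):
    """Hypothetical: sum site log-probs over non-batch axes, then across sites."""
    log_weight = None
    for site in trace.values():
        site_lw = _reduce_event_dims(site["log_prob"], batch_ndims)
        log_weight = site_lw if log_weight is None else _add(log_weight, site_lw)
    return 0.0 if log_weight is None else log_weight


trace = {
    "x": {"log_prob": [[-1.0, -2.0], [-0.5, -0.5]]},  # batch of 2, event size 2
    "y": {"log_prob": [-0.25, -1.0]},                  # batch of 2, scalar events
}
print(log_weight_sketch(trace, batch_ndims=1))  # [-3.25, -2.0]
```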

get_systematic_resampling_indices(log_weights, rng_key, num_samples)

Gets num_samples resampling indices from log_weights using systematic resampling.
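Systematic resampling draws a single uniform offset and places num_samples evenly spaced points in [0, 1). Each point selects the index whose cumulative normalized weight first exceeds it. The sketch below implements that textbook scheme in plain Python, with one assumption: a random.Random instance stands in for the JAX rng_key. The helper name is hypothetical.

```python
import math
import random


def systematic_resampling_indices(log_weights, rng, num_samples):
    """Hypothetical sketch of systematic resampling from unnormalized log weights."""
    # Normalize in log space for numerical stability.
    max_lw = max(log_weights)
    weights = [math.exp(lw - max_lw) for lw in log_weights]
    total = sum(weights)
    # Cumulative normalized weights (the empirical CDF over indices).
    cdf, acc = [], 0.0
    for w in weights:
        acc += w / total
        cdf.append(acc)
    cdf[-1] = 1.0  # guard against round-off
    offset = rng.random()  # one uniform draw shared by all points
    indices, j = [], 0
    for i in range(num_samples):
        u = (i + offset) / num_samples  # evenly spaced points in [0, 1)
        # Advance to the first index whose CDF exceeds u; clamp to the last index.
        while j < len(cdf) - 1 and cdf[j] <= u:
            j += 1
        indices.append(j)
    return indices


print(systematic_resampling_indices([0.0, float("-inf"), 0.0], random.Random(0), 4))
# [0, 0, 2, 2]
```

Compared with multinomial resampling, one shared offset means each index is chosen either floor(N·w) or ceil(N·w) times, which lowers resampling variance.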

train(loss_fn, init_params, optimizer, num_steps, dataloader=None, seed=0, jit_compile=True, eval_fn=None, log_every=None, init_step=0, opt_state=None, **kwargs)

Optimizes the parameters, minimizing loss_fn with optimizer for num_steps steps.
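The sketch below shows the overall shape of such a loop on a scalar parameter in plain Python. It is not the library's implementation, and it relies on several hypothetical stand-ins:

- loss_and_grad_fn returns (loss, grad); in JAX this role would be played by jax.value_and_grad.
- sgd returns an (init, update) pair loosely modeled on optax.
- opt_state can be passed in to resume training.
- log_every records the loss every few steps.

The dataloader, seed, jit_compile, eval_fn and init_step arguments are deliberately left out.

```python
def sgd(learning_rate):
    """Hypothetical optimizer: an (init, update) pair loosely modeled on optax."""
    def init(params):
        return {"count": 0}

    def update(grad, state):
        # Return an additive update plus the new optimizer state.
        return -learning_rate * grad, {"count": state["count"] + 1}

    return init, update


def train_sketch(loss_and_grad_fn, init_params, optimizer, num_steps,
                 log_every=None, opt_state=None):
    """Hypothetical training loop: num_steps optimizer updates on params."""
    init, update = optimizer
    params = init_params
    if opt_state is None:
        opt_state = init(params)  # fresh state unless resuming
    history = []
    for step in range(1, num_steps + 1):
        loss, grad = loss_and_grad_fn(params)
        updates, opt_state = update(grad, opt_state)
        params = params + updates
        if log_every and step % log_every == 0:
            history.append((step, loss))  # loss before this step's update
    return params, opt_state, history


# Minimize (p - 3)^2, whose gradient is 2 * (p - 3).
params, state, history = train_sketch(
    lambda p: ((p - 3.0) ** 2, 2.0 * (p - 3.0)),
    init_params=0.0, optimizer=sgd(0.1), num_steps=100, log_every=25)
print(round(params, 6), state["count"])  # 3.0 100
```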