Utilities
Utilities.
- class BindModule(module, params)
Like Flax’s module.bind(params) but composed with JAX transforms.
- get_log_weight(trace, batch_ndims)
Computes the log weight of the trace, keeping its leading batch_ndims batch dimensions.
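
The following is a minimal sketch of what "keeps its batch dimensions" can mean in practice. It is not the library's implementation. It assumes, for illustration only, that a trace is a dict mapping site names to dicts holding a "log_prob" array, and that the first batch_ndims axes of each array are batch axes. The helper name sum_log_probs_keep_batch is hypothetical, and the real get_log_weight may select or combine sites differently.

```python
import jax.numpy as jnp


def sum_log_probs_keep_batch(trace, batch_ndims):
    """Hypothetical helper: reduce each site's log_prob over event axes only."""
    total = jnp.zeros(())
    for site in trace.values():
        log_prob = site["log_prob"]
        # Flatten every axis after the leading `batch_ndims` axes into one
        # trailing axis, then sum it away so only the batch shape remains.
        batch_shape = log_prob.shape[:batch_ndims]
        total = total + log_prob.reshape(batch_shape + (-1,)).sum(-1)
    return total


# Assumed trace layout: site name -> {"log_prob": array}.
trace = {
    "z": {"log_prob": jnp.full((4, 3), -0.5)},  # batch of 4, event size 3
    "x": {"log_prob": jnp.full((4,), -1.0)},    # batch of 4, scalar event
}
log_weight = sum_log_probs_keep_batch(trace, batch_ndims=1)
print(log_weight.shape)
```

With batch_ndims=1 the result has one entry per batch element instead of a single scalar, which is useful when each batch element (for example, each particle) needs its own weight.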