Source code for coix.util

# Copyright 2024 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities."""

import functools
import time

import jax
from jax import random
import jax.numpy as jnp
import numpy as np


def get_systematic_resampling_indices(log_weights, rng_key, num_samples):
    """Gets resampling indices based on systematic resampling.

    Args:
      log_weights: array of unnormalized log weights. The leading axis indexes
        the candidates to resample from; all trailing axes are flattened and
        handled as independent batches.
      rng_key: JAX PRNG key used to draw one uniform offset per batch. If
        `None`, the offsets come from NumPy's global random state instead.
      num_samples: number of indices to draw for each batch.

    Returns:
      An integer array of shape `(num_samples,) + log_weights.shape[1:]` whose
      entries index into the leading axis of `log_weights`.
    """
    num_candidates = log_weights.shape[0]
    # TODO(phandu): It might be more numerically stable if we work in log space.
    probs = jax.nn.softmax(log_weights, axis=0)
    cdf = probs.cumsum(axis=0)
    cdf = cdf / cdf[-1]
    # Lay the CDFs out as one row per batch: (num_batches, num_candidates).
    cdf = cdf.reshape((num_candidates, -1)).swapaxes(0, 1)
    num_batches = cdf.shape[0]

    if rng_key is None:
        offsets = np.random.rand(num_batches)
    else:
        offsets = jax.random.uniform(rng_key, (num_batches,))

    # Evenly spaced points that share a single random offset per batch.
    grid = (offsets[:, None] + np.arange(num_samples)) / num_samples

    # Every CDF row and every grid row lies within [0, 1]. Adding 2 * batch_id
    # moves each batch into its own disjoint interval, so one flat
    # `searchsorted` call serves all batches at once.
    batch_ids = np.arange(num_batches)[:, None]
    flat_cdf = (cdf + 2 * batch_ids).reshape(-1)
    flat_grid = (grid + 2 * batch_ids).reshape(-1)
    flat_index = flat_cdf.searchsorted(flat_grid)

    # Undo the flattening offset to recover per-batch candidate indices.
    index = flat_index.reshape(num_batches, num_samples)
    index = (index - num_candidates * batch_ids).swapaxes(0, 1)
    return index.reshape((num_samples,) + log_weights.shape[1:])
def get_site_log_prob(site):
    """Returns the log probability recorded at a trace site.

    Args:
      site: either an object exposing a `log_density` attribute, or a mapping
        with a "log_prob" entry.

    Returns:
      `site.log_density` when that attribute exists, else `site["log_prob"]`.
    """
    if hasattr(site, "log_density"):
        return site.log_density
    return site["log_prob"]


def get_site_value(site, detach=False):
    """Returns the value recorded at a trace site.

    Args:
      site: either an object exposing a `value` attribute, or a mapping with a
        "value" entry.
      detach: if True and the value is a JAX array, gradients are blocked with
        `jax.lax.stop_gradient`. Other value types are returned unchanged.

    Returns:
      The site's value, optionally detached from the gradient computation.
    """
    value = site.value if hasattr(site, "value") else site["value"]
    if detach and isinstance(value, jnp.ndarray):
        return jax.lax.stop_gradient(value)
    return value


def is_observed_site(site):
    """Tells whether a trace site is an observed site.

    Objects with a `tag` attribute are observed when the tag equals
    "observed". Any other site is treated as a mapping and is observed when it
    contains an "is_observed" key (the key's value is not inspected).
    """
    if hasattr(site, "tag"):
        return site.tag == "observed"
    return "is_observed" in site


def can_extract_key(args):
    """Tells whether the first positional argument looks like a PRNG key.

    The first element qualifies when it is a JAX array that is either a typed
    key (dtype is a sub-dtype of `jax.dtypes.prng_key`) or a raw key, i.e. a
    `uint32` array with at least one dimension whose last dimension is 2.

    Args:
      args: sequence of positional arguments; may be empty.

    Returns:
      A `bool`. An empty `args` yields `False` rather than the empty container
      itself, so callers always get a real boolean.
    """
    if not args:
        return False
    first = args[0]
    if not isinstance(first, jnp.ndarray):
        return False
    if jax.dtypes.issubdtype(first.dtype, jax.dtypes.prng_key):
        return True
    looks_like_raw_key = (
        (first.dtype == jnp.uint32)
        and (jnp.ndim(first) >= 1)
        and (first.shape[-1] == 2)
    )
    return bool(looks_like_raw_key)


class _ChildModule:
    """A child of a bind module.

    Holds the parent module, its parameters and the attribute name of one
    child. Calling the instance applies the parent module with a method that
    forwards to that child; indexing returns a callable that forwards to the
    `i`-th element of that child attribute.
    """

    def __init__(self, module, params, name):
        self.module = module
        self.params = params
        self.name = name

    def __getitem__(self, i):
        def call_indexed_child(root, *args, **kwargs):
            return getattr(root, self.name)[i](*args, **kwargs)

        return functools.partial(
            self.module.apply, self.params, method=call_indexed_child
        )

    def __call__(self, *args, **kwargs):
        def call_child(root, *a, **kw):
            return getattr(root, self.name)(*a, **kw)

        return self.module.apply(
            self.params, *args, method=call_child, **kwargs
        )
class BindModule:
    """Like Flax's `module.bind(params)` but composed with JAX transforms.

    On construction the instance gains:
      * one `_ChildModule` attribute per entry of `params["params"]`;
      * for entries named `<prefix>_<number>`, a `_ChildModule` attribute
        named `<prefix>` (unless an attribute of that name already exists);
      * a copy of every annotated field of `module` except "parent" and
        "name".

    Calling the instance applies `module` with the stored parameters.
    """

    def __init__(self, module, params):
        self.module = module
        self.params = params
        child_names = params["params"]

        # First pass: expose every parameter collection as a child.
        for child_name in child_names:
            child = _ChildModule(self.module, self.params, child_name)
            setattr(self, child_name, child)

        # Second pass: names ending in "_<number>" also register their prefix,
        # but never shadow an attribute that is already set.
        for child_name in child_names:
            prefix, separator, suffix = child_name.rpartition("_")
            if not separator or not suffix.isnumeric():
                continue
            if hasattr(self, prefix):
                continue
            setattr(
                self, prefix, _ChildModule(self.module, self.params, prefix)
            )

        # Finally mirror the module's own annotated fields.
        for field in module.__annotations__:
            if field in ("parent", "name"):
                continue
            setattr(self, field, getattr(module, field))

    def __call__(self, *args, **kwargs):
        return self.module.apply(self.params, *args, **kwargs)
def _skip_update(grad, opt_state, params):
    """Produces an all-zero update and leaves the optimizer state untouched.

    Takes the same three operands that `train` hands to `optimizer.update`, so
    it can serve as the alternative branch of the `jax.lax.cond` there.

    Args:
      grad: pytree of gradients; only its structure, shapes and dtypes matter.
      opt_state: optimizer state, returned as is.
      params: ignored.

    Returns:
      A pair `(zero_updates, opt_state)` where `zero_updates` mirrors `grad`
      with every leaf replaced by zeros.
    """
    del params  # Unused.
    zero_updates = jax.tree_util.tree_map(jnp.zeros_like, grad)
    return zero_updates, opt_state
def train(
    loss_fn,
    init_params,
    optimizer,
    num_steps,
    dataloader=None,
    seed=0,
    jit_compile=True,
    eval_fn=None,
    log_every=None,
    init_step=0,
    opt_state=None,
    **kwargs,
):
    """Optimize the parameters.

    Args:
      loss_fn: callable invoked as `loss_fn(params, key, **kwargs)`, or
        `loss_fn(params, key, batch, **kwargs)` when a dataloader is given. It
        must return `(loss, metrics)` where `metrics` is a mutable mapping;
        the loss is differentiated with respect to `params`.
      init_params: pytree of initial parameters.
      optimizer: object providing `init(params)` and
        `update(grads, opt_state, params)`; the returned updates are added to
        the parameters.
      num_steps: index of the last training step (steps run from
        `init_step + 1` to `num_steps`, inclusive).
      dataloader: optional iterator; `next(dataloader)` supplies one batch per
        step.
      seed: an int (turned into a PRNG key) or an existing PRNG key. A
        per-step key is derived with `random.fold_in(key, step)`.
      jit_compile: if callable, it is applied to the step function; otherwise
        a truthy value selects `jax.jit` and a falsy value runs eagerly.
      eval_fn: optional callable invoked as
        `eval_fn(step, params, opt_state, metrics)` once before training (with
        `metrics=None`) and then at every logging step.
      log_every: logging period in steps; defaults to
        `max(num_steps // 20, 1)`. The final step is always logged.
      init_step: number of steps already done, e.g. when resuming.
      opt_state: optional optimizer state; defaults to
        `optimizer.init(init_params)`.
      **kwargs: extra keyword arguments for `loss_fn`. After each step, every
        keyword whose name also appears in the step's metrics is replaced by
        that metric value for the next step.

    Returns:
      A pair `(params, metrics)` holding the final parameters and the metrics
      of the last executed step (`None` if no step was executed).
    """
    # `jax.flatten_util` is a submodule that `import jax` is not guaranteed to
    # load, so import it explicitly instead of relying on another package
    # having done so.
    from jax.flatten_util import ravel_pytree  # pylint: disable=g-import-not-at-top

    def step_fn(params, opt_state, *args, **kwargs):
        (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            params, *args, **kwargs
        )
        # Keep gradient dtypes aligned with the parameters they update.
        grads = jax.tree_util.tree_map(
            lambda x, y: x.astype(y.dtype), grads, params
        )
        # Helpful metric to print out during training.
        squared_grad_norm = sum(
            jnp.square(p).sum() for p in jax.tree_util.tree_leaves(grads)
        )
        metrics["squared_grad_norm"] = squared_grad_norm
        # Apply the optimizer only when every gradient entry is finite;
        # otherwise fall back to a zero update with unchanged state.
        updates, opt_state = jax.lax.cond(
            jnp.isfinite(ravel_pytree(grads)[0]).all(),
            optimizer.update,
            _skip_update,
            grads,
            opt_state,
            params,
        )
        params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
        return params, opt_state, metrics

    if callable(jit_compile):
        maybe_jitted_step_fn = jit_compile(step_fn)
    else:
        maybe_jitted_step_fn = jax.jit(step_fn) if jit_compile else step_fn

    opt_state = optimizer.init(init_params) if opt_state is None else opt_state
    params = init_params
    run_key = random.PRNGKey(seed) if isinstance(seed, int) else seed
    log_every = max(num_steps // 20, 1) if log_every is None else log_every
    # The largest step number printed is `num_steps`, so size the column for it.
    step_width = len(str(num_steps))
    # Copy so that feeding metrics back never mutates the caller's mapping.
    kwargs = kwargs.copy()

    if eval_fn is not None:
        print("Evaluating with the initial params...", flush=True)
        tic = time.time()
        eval_fn(init_step, params, opt_state, metrics=None)
        print("Time to compile an eval step:", time.time() - tic, flush=True)

    print("Compiling the first train step...", flush=True)
    tic = time.time()
    metrics = None
    first_step = init_step + 1
    for step in range(first_step, num_steps + 1):
        key = random.fold_in(run_key, step)
        args = (key, next(dataloader)) if dataloader is not None else (key,)
        params, opt_state, metrics = maybe_jitted_step_fn(
            params, opt_state, *args, **kwargs
        )
        # Feed matching metrics back as keyword arguments of the next step.
        for name in kwargs:
            if name in metrics:
                kwargs[name] = metrics[name]

        # Report timing after the first executed step, which is not step 1
        # when training resumes from `init_step > 0`.
        if step == first_step:
            print(
                "Time to compile a train step:", time.time() - tic, flush=True
            )
            print("=====", flush=True)

        if (step == num_steps) or (step % log_every == 0):
            log = "Step {:<{}d}".format(step, step_width)
            for name, value in sorted(metrics.items()):
                # Only scalar metrics fit on the one-line summary.
                if np.isscalar(value) or (
                    isinstance(value, (np.ndarray, jnp.ndarray))
                    and (value.ndim == 0)
                ):
                    log += f" | {name} {float(value):10.4f}"
            print(log, flush=True)
            if eval_fn is not None:
                eval_fn(step, params, opt_state, metrics)

    return params, metrics
def _remove_suffix(name):
    """Strips every trailing "_PREV_" from `name`.

    Args:
      name: the string to strip.

    Returns:
      A pair `(base, removed)` where `base` is `name` without any trailing
      "_PREV_" repetitions and `removed` is the number of characters that were
      stripped (a multiple of the suffix length).
    """
    suffix = "_PREV_"
    base = name
    while base.endswith(suffix):
        base = base[: len(base) - len(suffix)]
    return base, len(name) - len(base)
def desuffix(trace):
    """Remove unnecessary suffix terms added to the trace.

    Names that differ only by their number of trailing "_PREV_" markers share
    one raw name. For each raw name, the smallest amount of suffix found among
    its variants is trimmed from all of them, so the least-suffixed variant
    ends up with the raw name and the others keep only the extra markers.

    Args:
      trace: mapping from site name to site.

    Returns:
      A new dict with the same values, in the same order, under trimmed names.
    """
    # name -> (raw name, number of suffix characters carried by name)
    split_names = {name: _remove_suffix(name) for name in trace}

    shortest_suffix = {}
    for raw_name, suffix_len in split_names.values():
        known = shortest_suffix.get(raw_name, suffix_len)
        shortest_suffix[raw_name] = min(known, suffix_len)

    renamed = {}
    for name, site in trace.items():
        raw_name, _ = split_names[name]
        keep = len(name) - shortest_suffix[raw_name]
        renamed[name[:keep]] = site
    return renamed
def get_batch_ndims(xs):
    """Gets the number of same-size leading dimensions of the elements in xs.

    Args:
      xs: sequence of arrays (or scalars).

    Returns:
      The count of leading axes on which all elements have the same size.
      Counting stops at the first axis where sizes differ or where the
      lowest-rank element runs out of axes. An empty `xs` gives 0.
    """
    if not xs:
        return 0
    shapes = [jnp.shape(x) for x in xs]
    shared = 0
    # `zip` stops at the shortest shape, i.e. at the minimum rank.
    for sizes_on_axis in zip(*shapes):
        if len(set(sizes_on_axis)) != 1:
            break
        shared += 1
    return shared
def get_log_weight(trace, batch_ndims):
    """Computes log weight of the trace and keeps its batch dimensions.

    Args:
      trace: mapping from site name to site.
      batch_ndims: number of leading dimensions to keep as batch dimensions.

    Returns:
      The sum over observed sites of their log probabilities, with every
      dimension after the first `batch_ndims` summed out. Unobserved sites
      add zeros shaped like their batch dimensions, so they influence only
      the broadcast shape of the result. An empty trace gives zeros of shape
      `(1,) * batch_ndims`.
    """
    total = jnp.zeros((1,) * batch_ndims)
    for site in trace.values():
        log_prob = get_site_log_prob(site)
        if not is_observed_site(site):
            contribution = jnp.zeros(jnp.shape(log_prob)[:batch_ndims])
        else:
            # Negative axis indices covering everything past the batch dims.
            trailing_axes = tuple(range(batch_ndims - jnp.ndim(log_prob), 0))
            contribution = jnp.sum(log_prob, axis=trailing_axes)
        total = total + contribution
    return total