Source code for coix.api

# Copyright 2024 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Program combinators.

The implement is pretty much backend-agnostic. We just assume that the core
backend supports the following functionality:

  + `suffix(p)`: rename latent variables of the program `p`,
  + `traced_evaluate(p, latents=None)`: execute `p` and collect trace, metrics,
    optionally we can substitute values in `latents` to `p`,
  + `empirical(out, trace, metrics)`: create a delta program given output,
    trace, and metrics. Inputs of `empirical` are outputs of `traced_evaluate`.
"""

import functools

from coix import core
from coix import util
import jax
import jax.numpy as jnp
import numpy as np

# pytype: disable=module-attr
try:
  wrap_key_data = jax.random.wrap_key_data
except AttributeError:
  try:
    wrap_key_data = jax.extend.random.wrap_key_data
  except AttributeError:

    def _identity(k):
      return k

    wrap_key_data = _identity


# pytype: enable=module-attr


__all__ = [
    "compose",
    "extend",
    "fori_loop",
    "propose",
    "resample",
]


[docs] def compose(q2, q1, suffix=True): r"""Executes q2(\*q1(...)). Note: We only allow at most one of `q1` or `q2` is weighted. Args: q2: a program q1: a program suffix: whether to add suffix `\_PREV\_` to variables in `q1` Returns: q: the composed program """ def wrapped(*args, **kwargs): q = core.suffix(q1) if suffix else q1 return q2(*q(*args, **kwargs)) return wrapped
[docs] def extend(p, f): r"""Executes f(\*p(...)) with random variables in f marked as auxiliary. Note: We don't allow recursively marginalize out `p` yet. Args: p: a target program f: an auxiliary program Returns: p_new: the extended program """ def wrapped(*args, **kwargs): args = p(*args, **kwargs) core.suffix(f)(*args) return args return wrapped
def _reshape_key(key, shape): if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): return jnp.reshape(key, shape) else: return jnp.reshape(key, shape + (2,)) def _split_key(key): keys = jax.vmap(jax.random.split, out_axes=1)(_reshape_key(key, (-1,))) return keys[0].reshape(key.shape), keys[1].reshape(key.shape) def _fold_in_key(key, i): key_new = jax.vmap(jax.random.fold_in, (0, None))(_reshape_key(key, (-1,)), i) return key_new.reshape(key.shape)
[docs] def propose(p, q, *, loss_fn=None, detach=False, chain=False): """Returns a new program with important weight. We assume the leftmost batch dimension is the particle dimension. You can add additional batch dimensions to the whole program by using `vmap`, e.g. `vmap(propose(p, q))`. Args: p: a target program q: a proposal program loss_fn: a function that computes loss of this propose combinator detach: whether to detach `value` of the returned program chain: if True, we will use output of `q` as input of `p` Returns: q_new: the proposed program """ def wrapped(*args, **kwargs): if util.can_extract_key(args) and not chain: key_p, key_q = _split_key(args[0]) p_args = (key_p,) + args[1:] q_args = (key_q,) + args[1:] else: p_args = q_args = args q_out, q_trace, q_metrics = core.traced_evaluate(q)(*q_args, **kwargs) if chain: p_args = q_out metrics = q_metrics.copy() q_latents = { name: util.get_site_value(site) for name, site in q_trace.items() if not util.is_observed_site(site) } traced_p = core.traced_evaluate(p, latents=q_latents) out, p_trace, _ = traced_p(*p_args, **kwargs) p_log_probs = { name: util.get_site_log_prob(site) for name, site in p_trace.items() } q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } log_probs = list(p_log_probs.values()) + list(q_log_probs.values()) batch_ndims = util.get_batch_ndims(log_probs) if "log_weight" in q_metrics: in_log_weight = q_metrics["log_weight"] in_log_weight = jnp.sum( in_log_weight, axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), ) else: in_log_weight = util.get_log_weight(q_trace, batch_ndims) p_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() if util.is_observed_site(p_trace[name]) or (name in q_trace) ) q_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in q_log_probs.items() if util.is_observed_site(q_trace[name]) or (name in p_trace) ) incremental_log_weight = p_log_weight - q_log_weight log_weight = in_log_weight + incremental_log_weight metrics["log_weight"] = log_weight if batch_ndims: # leftmost dimension is particle dimension ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0) metrics["ess"] = ess.mean() log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log( log_weight.shape[0] ) metrics["log_Z"] = log_z.sum() if loss_fn is not None: if detach: p_latents = { name: util.get_site_value(site, detach=True) for name, site in p_trace.items() if not util.is_observed_site(site) } out, p_trace, _ = core.traced_evaluate(p, latents=p_latents)( *p_args, **kwargs ) loss = loss_fn(q_trace, p_trace, in_log_weight, incremental_log_weight) metrics["loss"] = q_metrics.get("loss", 0.0) + loss marginal_trace = { name: site for name, site in p_trace.items() if not name.endswith("_PREV_") } log_density = jnp.zeros((1,) * batch_ndims) + sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() if name in marginal_trace ) if batch_ndims: log_density = jnp.mean(log_density, axis=0).sum() metrics["log_density"] = log_density return core.empirical(out, marginal_trace, metrics)(*args, **kwargs) return wrapped
def _maybe_get_along_first_axis(x, idx, n, squeeze=False): """Get along the first axis of `x` if `x.shape[0] == n`.""" is_list = False if isinstance(x, list): is_list = True x = np.array(x) # Special treatment for cascades. if hasattr(x, "value"): setattr( x, "value", _maybe_get_along_first_axis( util.get_site_value(x), idx, n, squeeze=squeeze ), ) if hasattr(x, "log_density"): setattr( x, "log_density", _maybe_get_along_first_axis( util.get_site_log_prob(x), idx, n, squeeze=squeeze ), ) if ( isinstance(x, (np.ndarray, jnp.ndarray)) and (x.ndim >= 1) and (x.shape[0] == n) ): idx = idx.reshape(idx.shape + (1,) * (x.ndim - idx.ndim)) if isinstance(x, np.ndarray): y = np.take_along_axis(x, idx, axis=0) elif jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key): x_data = jax.random.key_data(x) idx = idx.reshape(idx.shape + (1,) * (x_data.ndim - idx.ndim)) y_data = jnp.take_along_axis(x_data, idx, axis=0) y_data = y_data[0] if (idx.shape[0] == 1 and squeeze) else y_data y = wrap_key_data(y_data) else: y = jnp.take_along_axis(x, idx, axis=0) y = y.tolist() if is_list else y return y[0] if (idx.shape[0] == 1 and squeeze) else y else: return x
[docs] def resample(q, num_samples=None): """Returns a new program with equally-weighted particles. Args: q: a program num_samples: the number of samples after resampling. Set this to an empty tuple to draw 1 sample without the leftmost singleton dimension. Defaults to the number of particles in `q`. Returns: q_new: the resampled program """ def fn(*args, **kwargs): if util.can_extract_key(args): key_r, key_q = _split_key(args[0]) # We just need a single key for resampling. key_r = _reshape_key(key_r, (-1,))[0] args = (key_q,) + args[1:] else: key_r = core.prng_key() out, trace, q_metrics = core.traced_evaluate(q)(*args, **kwargs) log_probs = { name: util.get_site_log_prob(site) for name, site in trace.items() } batch_ndims = util.get_batch_ndims(log_probs.values()) weighted = ("log_weight" in q_metrics) or any( util.is_observed_site(site) for site in trace.values() ) if (batch_ndims == 0) or not weighted: # resample is no-op return core.empirical(out, trace, q_metrics)(*args, **kwargs) metrics = q_metrics.copy() if "log_weight" in q_metrics: in_log_weight = q_metrics.pop("log_weight") in_log_weight = jnp.sum( in_log_weight, axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), ) else: in_log_weight = util.get_log_weight(trace, batch_ndims) n = in_log_weight.shape[0] k = n if num_samples is None else num_samples log_weight = jax.nn.logsumexp(in_log_weight, 0) - jnp.log(k if k else 1) if k: metrics["log_weight"] = jnp.broadcast_to( log_weight, (k,) + in_log_weight.shape[1:] ) metrics["ess"] = jnp.asarray(float(k)) if "log_Z" not in q_metrics: metrics["log_Z"] = log_weight.sum() log_probs = jax.nn.log_softmax(in_log_weight, axis=0) idx = util.get_systematic_resampling_indices( log_probs, rng_key=key_r, num_samples=k if k else 1 ) maybe_get_along_first_axis = functools.partial( _maybe_get_along_first_axis, idx=idx, n=n, squeeze=not k ) out = jax.tree.map( maybe_get_along_first_axis, out, is_leaf=lambda x: isinstance(x, list) ) resample_trace = jax.tree.map( maybe_get_along_first_axis, trace, is_leaf=lambda x: isinstance(x, list) ) return core.empirical(out, resample_trace, metrics)(*args, **kwargs) return fn
def _add_missing_metrics(metrics, trace): """Adds missing metrics to get consistent pytree in fori_loop.""" full_metrics = metrics.copy() log_probs = { name: util.get_site_log_prob(site) for name, site in trace.items() } if "log_weight" not in metrics: batch_ndims = min(util.get_batch_ndims(list(log_probs.values())), 1) log_weight = util.get_log_weight(trace, batch_ndims) full_metrics["log_weight"] = log_weight else: batch_ndims = metrics["log_weight"].ndim log_weight = metrics["log_weight"] # leftmost dimension is particle dimension if batch_ndims and "ess" not in metrics: assert "log_Z" not in metrics ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0) full_metrics["ess"] = ess.mean() n = log_weight.shape[0] log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n) full_metrics["log_Z"] = log_z.mean() if "loss" not in metrics: full_metrics["loss"] = jnp.array(0.0) if "log_density" not in metrics: log_density = sum(jnp.sum(lp) for lp in log_probs.values()) full_metrics["log_density"] = jnp.array(0.0) + log_density return full_metrics
[docs] def fori_loop(lower, upper, body_fun, init_program): """Returns a program which loops over programs created by body_fun. Args: lower: loop index lower bound upper: loop index upper bound (exclusive) body_fun: a function that takes a pair of inputs (index, program) and return a new program init_program: initial program for `body_fun` Returns: q: the final program """ def fn(*args, **kwargs): def trace_arg_key(fn, key): return core.traced_evaluate(fn)(key, *args[1:], **kwargs) def trace_with_seed(fn, key): return core.traced_evaluate(fn, seed=key)(*args, **kwargs) if util.can_extract_key(args): key = args[0] trace_fn = trace_arg_key else: key = core.prng_key() trace_fn = trace_with_seed key_body, key_init = _split_key(key) def jax_body_fun(i, val): q = core.empirical(*val) return trace_fn(body_fun(i, q), _fold_in_key(key_body, i)) v, trace, metrics = trace_fn(init_program, key_init) metrics = _add_missing_metrics(metrics, trace) output = jax.lax.fori_loop(lower, upper, jax_body_fun, (v, trace, metrics)) return core.empirical(*output)(key, *args, **kwargs) return fn
def _join_samples(first_set, second_set): if first_set is None: return second_set if isinstance(first_set, list): assert isinstance(second_set, list) return first_set + second_set else: return jnp.concatenate([first_set, second_set], axis=0) def memoize(p, q, memory=None, memory_size=None): """Returns a new program using additional samples from memory for proposal. Args: p: a target program q: a proposal program memory: additional samples for `q` memory_size: size of the memory to be stored in the program's metrics Returns: q_new: the proposed program """ if (memory is None) and (memory_size is None): raise ValueError("One of memory or memory_size needs to be specified.") memory = {} if memory is None else memory memory_sizes = [len(x) for x in memory.values()] if len(set(memory_sizes)) > 1: raise ValueError("We need all variables have the same memory size.") memory_size = memory_size if memory_size is not None else memory_sizes[0] def wrapped(*args, **kwargs): if util.can_extract_key(args): key = args[0] p_key, q_key = _split_key(key) p_args = (p_key,) + args[1:] q_args = (q_key,) + args[1:] else: p_args = q_args = args _, q_trace, q_metrics = core.traced_evaluate(q)(*q_args, **kwargs) metrics = q_metrics.copy() q_latents = { name: _join_samples(memory.get(name, None), util.get_site_value(site)) for name, site in q_trace.items() if not util.is_observed_site(site) } traced_p = core.traced_evaluate(p, latents=q_latents) out, p_trace, _ = traced_p(*p_args, **kwargs) p_log_probs = { name: util.get_site_log_prob(site) for name, site in p_trace.items() } batch_ndims = util.get_batch_ndims(p_log_probs.values()) p_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for lp in p_log_probs.values() ) marginal_trace = { name: site for name, site in p_trace.items() if not util.is_observed_site(site) } new_memory = { name: util.get_site_value(site) for name, site in marginal_trace.items() } assert not isinstance(p_log_weight, int) num_particles = p_log_weight.shape[0] batch_dim = p_log_weight.ndim flat_memory = { k: np.array(v) if isinstance(v, list) else v for k, v in new_memory.items() } flat_memory = { k: v.reshape((num_particles, -1) + v.shape[batch_dim:]) for k, v in flat_memory.items() } flat_log_weight = p_log_weight.reshape(p_log_weight.shape[:1] + (-1,)) idxs = [] for i in range(flat_log_weight.shape[1]): w = flat_log_weight[:, i] mem = [ [flat_memory[k][j, i] for k in sorted(flat_memory)] for j in range(num_particles) ] unique_idx = np.unique(mem, axis=0, return_index=True)[1] sorted_idx = jnp.argsort(w[unique_idx])[::-1][:memory_size] idxs.append(unique_idx[sorted_idx]) idxs = jnp.stack(idxs, -1).reshape((memory_size,) + p_log_weight.shape[1:]) maybe_get_along_first_axis = functools.partial( _maybe_get_along_first_axis, idx=idxs, n=num_particles ) metrics["log_weight"] = maybe_get_along_first_axis(p_log_weight) out = jax.tree.map( maybe_get_along_first_axis, out, is_leaf=lambda x: isinstance(x, list) ) marginal_trace = jax.tree.map( maybe_get_along_first_axis, marginal_trace, is_leaf=lambda x: isinstance(x, list), ) metrics["memory"] = jax.tree.map( maybe_get_along_first_axis, new_memory, is_leaf=lambda x: isinstance(x, list), ) return core.empirical(out, marginal_trace, metrics)(*args, **kwargs) return wrapped