# Source code for coix.core

# Copyright 2024 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Program transforms."""

import importlib

__all__ = [
    "detach",
    "empirical",
    "prng_key",
    "register_backend",
    "set_backend",
    "stick_the_landing",
    "suffix",
    "traced_evaluate",
]

_BACKENDS = {}
_COIX_BACKEND = None


# pylint:disable=redefined-outer-name
def register_backend(
    backend,
    traced_evaluate=None,
    empirical=None,
    suffix=None,
    prng_key=None,
    detach=None,
    stick_the_landing=None,
):
  """Registers the transform implementations of a backend.

  The implementations are stored in the module-level registry under the key
  `backend`, replacing any entry previously stored under that key.

  Args:
    backend: key identifying the backend in the registry.
    traced_evaluate: implementation stored under "traced_evaluate".
    empirical: implementation stored under "empirical".
    suffix: implementation stored under "suffix".
    prng_key: implementation stored under "prng_key".
    detach: implementation stored under "detach".
    stick_the_landing: implementation stored under "stick_the_landing".

  Any implementation left unspecified is recorded as `None`.
  """
  _BACKENDS[backend] = dict(
      traced_evaluate=traced_evaluate,
      empirical=empirical,
      suffix=suffix,
      prng_key=prng_key,
      detach=detach,
      stick_the_landing=stick_the_landing,
  )
# pylint:enable=redefined-outer-name
def set_backend(backend):
  """Selects `backend` as the active backend.

  If `backend` is not in the registry yet, it is treated as a module name:
  the module is imported and each transform function it defines is
  registered (transforms the module lacks are registered as `None`).

  Args:
    backend: registry key of the backend, or an importable module name.
  """
  global _COIX_BACKEND
  if backend not in _BACKENDS:
    module = importlib.import_module(backend)
    transform_names = (
        "traced_evaluate",
        "empirical",
        "suffix",
        "prng_key",
        "detach",
        "stick_the_landing",
    )
    # Missing attributes become None so every transform key is present.
    register_backend(
        backend,
        **{name: getattr(module, name, None) for name in transform_names},
    )
  _COIX_BACKEND = backend
def get_backend_name():
  """Returns the key of the active backend, or `None` if none is set."""
  return _COIX_BACKEND


def get_backend():
  """Returns the transform map of the active backend.

  When no backend has been selected, "coix.numpyro" is set as the active
  backend first and its transform map is returned.
  """
  if _COIX_BACKEND is not None:
    return _BACKENDS[_COIX_BACKEND]
  set_backend("coix.numpyro")
  return _BACKENDS["coix.numpyro"]


########################################
# Program transforms
########################################


def _remove_suffix(name):
  """Strips every trailing `_PREV_` from `name`.

  Returns:
    A pair `(raw_name, num_chars)` where `raw_name` is `name` without its
    trailing `_PREV_` markers and `num_chars` is the number of characters
    that were stripped.
  """
  marker = "_PREV_"
  raw_name = name
  while raw_name.endswith(marker):
    raw_name = raw_name[: -len(marker)]
  return raw_name, len(name) - len(raw_name)


def desuffix(trace):
  """Drops the `_PREV_` suffixes that are not needed to tell names apart.

  Names sharing the same raw name (the name without trailing `_PREV_`
  markers) are all shortened by the smallest suffix length found in that
  group. Entries whose name carried a suffix and whose value is a dict get
  `"suffix": True` set on that dict (the dict is modified in place).

  Args:
    trace: mapping from variable names to their trace entries.

  Returns:
    A new dict with the shortened names, holding the same entry objects in
    the same order as `trace`.
  """
  parsed = {name: _remove_suffix(name) for name in trace}

  # Smallest number of suffix characters seen for each raw name.
  shortest = {}
  for raw_name, num_chars in parsed.values():
    shortest[raw_name] = min(shortest.get(raw_name, num_chars), num_chars)

  cleaned = {}
  for name in trace:
    site = trace[name]
    raw_name = parsed[name][0]
    if raw_name != name and isinstance(site, dict):
      site["suffix"] = True
    cleaned[name[: len(name) - shortest[raw_name]]] = site
  return cleaned
def traced_evaluate(p, latents=None, seed=None, **kwargs):
  """Builds a traced evaluator for the program `p`.

  Args:
    p: the program to evaluate.
    latents: forwarded to the backend as the `latents` keyword.
    seed: forwarded to the backend as the `seed` keyword, but only when it
      is not `None`.
    **kwargs: extra keyword arguments forwarded to the backend.

  Returns:
    A function that calls the backend evaluator with its own arguments and
    returns `(out, trace, metrics)`, where `trace` has been passed through
    `desuffix`.
  """
  # Some backends do not accept a `seed` keyword, so only pass it on request.
  backend_kwargs = dict(kwargs)
  if seed is not None:
    backend_kwargs["seed"] = seed
  make_evaluator = get_backend()["traced_evaluate"]
  evaluator = make_evaluator(p, latents=latents, **backend_kwargs)

  def wrapped(*args, **kwargs):
    out, trace, metrics = evaluator(*args, **kwargs)
    return out, desuffix(trace), metrics

  return wrapped
def empirical(out, trace, metrics):
  """Builds an empirical program from `out`, `trace` and `metrics`.

  The three arguments are handed unchanged to the active backend's
  "empirical" implementation and its result is returned.
  """
  make_empirical = get_backend()["empirical"]
  return make_empirical(out, trace, metrics)
def suffix(p):
  """Appends the suffix `_PREV_` to the variable names of `p`.

  Returns `p` unchanged when the active backend provides no "suffix"
  implementation.
  """
  transform = get_backend()["suffix"]
  if transform is None:
    return p
  return transform(p)
def detach(p):
  """Turns the random variables of `p` into non-reparameterized ones.

  Returns `p` unchanged when the active backend provides no "detach"
  implementation.
  """
  transform = get_backend()["detach"]
  return p if transform is None else transform(p)
def stick_the_landing(p):
  """Stops gradients of distribution parameters before computing log prob.

  Returns `p` unchanged when the active backend provides no
  "stick_the_landing" implementation.
  """
  transform = get_backend()["stick_the_landing"]
  if transform is None:
    return p
  return transform(p)
def prng_key():
  """Generates a random JAX PRNGKey via the active backend.

  Returns `None` when the active backend provides no "prng_key"
  implementation.
  """
  make_key = get_backend()["prng_key"]
  return None if make_key is None else make_key()