
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/anneal_oryx.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_anneal_oryx.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_anneal_oryx.py:


Example: Annealed Variational Inference in Oryx
===============================================

This example illustrates how to construct an inference program based on the
Nested Variational Inference (NVI) algorithm [1] for Annealed Variational
Inference (AVI). The details of AVI can be found in Section E.1 of the
reference. We will use the Oryx backend for this example.

**References**

    1. Zimmermann, Heiko, et al. "Nested variational inference." NeurIPS 2021.

.. image:: ../_static/anneal_oryx.png
    :align: center

.. GENERATED FROM PYTHON SOURCE LINES 31-48

.. code-block:: Python


    import argparse
    from functools import partial

    import coix
    import coix.oryx as coryx
    import flax
    import flax.linen as nn
    import jax
    from jax import random
    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    import numpy as np
    import numpyro
    import numpyro.distributions as dist
    import optax


.. GENERATED FROM PYTHON SOURCE LINES 49-50

First, we define the neural networks for the targets and kernels.

.. GENERATED FROM PYTHON SOURCE LINES 50-118

.. code-block:: Python



    class AnnealKernel(nn.Module):
      """Small MLP that produces the parameters of a Gaussian transition kernel.

      Given a point ``x`` the module returns ``(loc, scale)``: ``loc`` is ``x``
      plus a learned residual and ``scale`` is made positive with a softplus.
      """

      @nn.compact
      def __call__(self, x):
        # Shared hidden representation (one 50-unit ReLU layer).
        hidden = nn.relu(nn.Dense(50)(x))

        # Both output heads use zero-initialised weight matrices.
        zeros = nn.initializers.zeros
        residual = nn.Dense(2, kernel_init=zeros)(hidden)
        raw_scale = nn.Dense(2, kernel_init=zeros)(hidden)

        return residual + x, nn.softplus(raw_scale)


    class AnnealDensity(nn.Module):
      """Annealed log-density between an initial proposal and a ring of Gaussians.

      For step ``index`` the returned value is
      ``beta_k * log_target(x) + (1 - beta_k) * log_init(x)``, where the schedule
      ``beta`` has ``M`` entries, starts at 0, ends at 1 and has ``M - 2`` learned
      intermediate values.
      """

      # Number of annealing steps; also the number of mixture modes of the target.
      M = 8

      @nn.compact
      def __call__(self, x, index=0):
        # Learned schedule: a base logit plus a cumulative sum of softplus
        # increments, squashed by a sigmoid, so the intermediate betas are
        # non-decreasing values in (0, 1).
        raw = self.param("beta_raw", lambda _: -jnp.ones(self.M - 2))
        increments = jnp.cumsum(nn.softplus(raw[1:]))
        logits = raw[0] + jnp.pad(increments, (1, 0))
        # Pin the end points: beta[0] = 0 (initial proposal), beta[-1] = 1 (target).
        schedule = jnp.pad(nn.sigmoid(logits), (1, 1), constant_values=(0, 1))
        beta_k = schedule[index]

        # Target: M Gaussian modes equally spaced on a circle of radius 10, each
        # with standard deviation sqrt(0.5). The logsumexp is taken over the modes
        # without mixture weights, so this is an unnormalised log-density.
        angles = 2 * jnp.arange(1, self.M + 1) * jnp.pi / self.M
        centers = 10 * jnp.stack([jnp.sin(angles), jnp.cos(angles)], -1)
        modes = dist.Normal(centers, jnp.sqrt(0.5))
        per_mode = modes.log_prob(x[..., None, :]).sum(-1)
        log_target = nn.logsumexp(per_mode, -1)

        # Initial proposal: Normal(0, 5) log-density summed over the last axis.
        log_init = dist.Normal(0, 5).log_prob(x).sum(-1)

        return beta_k * log_target + (1 - beta_k) * log_init


    class AnnealKernelList(nn.Module):
      """A stack of ``M - 1`` ``AnnealKernel`` networks addressed by ``index``.

      The parameters of all kernels are stored stacked along a leading axis under
      the name ``"kernel"``; a call evaluates only the kernel at ``index``.
      """

      # Number of annealing steps; there is one kernel per transition (M - 1).
      M = 8

      @nn.compact
      def __call__(self, x, index=0):
        def pick(tree):
          # Select entry `index` along the leading (stacked) axis of every leaf.
          return jax.tree.map(lambda leaf: leaf[index], tree)

        if not self.is_mutable_collection("params"):
          # Parameters already exist: slice out one kernel's parameters and run a
          # single, un-vmapped AnnealKernel with them.
          stacked = self.scope.get_variable("params", "kernel")
          variables = flax.core.freeze({"params": pick(stacked)})
          return AnnealKernel(name="kernel").apply(variables, x)

        # "params" is mutable (e.g. during init): build every kernel at once with
        # a vmapped module so the parameters are created stacked on axis 0, each
        # with its own RNG split.
        stacked_kernel = nn.vmap(
            AnnealKernel, variable_axes={"params": 0}, split_rngs={"params": True}
        )
        tiled = jnp.broadcast_to(x, (self.M - 1,) + x.shape)
        return pick(stacked_kernel(name="kernel")(tiled))


    class AnnealNetwork(nn.Module):
      """Bundles the forward kernels, reverse kernels and the annealed density.

      The three submodules are exposed as attributes (``forward_kernels``,
      ``reverse_kernels``, ``anneal_density``) so that the program functions in
      this example can call them individually.
      """

      def setup(self):
        # Two independent kernel stacks (forward and reverse transitions) plus
        # the density that interpolates between proposal and target.
        self.forward_kernels = AnnealKernelList()
        self.reverse_kernels = AnnealKernelList()
        self.anneal_density = AnnealDensity()

      def __call__(self, x):
        # The outputs of the reverse kernels and the density are discarded here;
        # they are called presumably so that `init` touches (and therefore
        # creates parameters for) all three submodules.
        self.reverse_kernels(x)
        self.anneal_density(x)
        return self.forward_kernels(x)



.. GENERATED FROM PYTHON SOURCE LINES 119-120

Then, we define the targets and kernels as in Section E.1.

.. GENERATED FROM PYTHON SOURCE LINES 120-139

.. code-block:: Python



    def anneal_target(network, key, k=0):
      """Target program for annealing step ``k``.

      Draws a 2-D sample named ``"x"`` from a masked ``Normal(0, 5)`` and scores
      it with ``network.anneal_density(x, index=k)`` through a factor named
      ``"anneal_density"``.

      Returns:
        A tuple ``(next_key, {"x": x})``: a fresh PRNG key for downstream
        programs and the sampled point.
      """
      next_key, sample_key = random.split(key)

      # The base distribution is masked (`mask(False)`), so the density of this
      # target is meant to come from the factor below rather than from the
      # Normal's own log-probability.
      base = dist.Normal(0, 5).expand([2]).mask(False)
      x = coryx.rv(base, name="x")(sample_key)

      log_density = network.anneal_density(x, index=k)
      coryx.factor(log_density, name="anneal_density")
      return next_key, {"x": x}


    def anneal_forward(network, key, inputs, k=0):
      """Forward kernel ``k``: propose a new ``"x"`` given the current ``inputs["x"]``.

      The ``k``-th forward kernel network maps the current point to the location
      and scale of a Normal distribution, from which the new point is sampled.
      """
      loc, scale = network.forward_kernels(inputs["x"], index=k)
      proposal = dist.Normal(loc, scale)
      return coryx.rv(proposal, name="x")(key)


    def anneal_reverse(network, key, inputs, k=0):
      """Reverse kernel ``k``: sample ``"x"`` given the current ``inputs["x"]``.

      Mirrors ``anneal_forward`` but uses the ``k``-th network of
      ``network.reverse_kernels`` to parameterise the Normal distribution.
      """
      loc, scale = network.reverse_kernels(inputs["x"], index=k)
      proposal = dist.Normal(loc, scale)
      return coryx.rv(proposal, name="x")(key)



.. GENERATED FROM PYTHON SOURCE LINES 140-142

Finally, we create the anneal inference program, define the loss function,
run the training loop, and plot the results.

.. GENERATED FROM PYTHON SOURCE LINES 142-216

.. code-block:: Python



    def make_anneal(params, unroll=False):
      """Build the annealing inference program for the given network parameters.

      Args:
        params: variables of ``AnnealNetwork`` to bind to the module.
        unroll: if True, pass explicit per-step lists of programs to
          ``coix.algo.nvi_rkl`` instead of functions of the step index ``k``.

      Returns:
        The program produced by ``coix.algo.nvi_rkl``.
      """
      network = coix.util.BindModule(AnnealNetwork(), params)

      # Derive the step counts from the network definition instead of repeating
      # the literal 8 / 7, so they cannot drift apart from `AnnealDensity.M`.
      num_targets = AnnealDensity.M
      num_kernels = num_targets - 1  # one kernel per transition between targets

      def over_particles(step_fn):
        # Add particle dimension: vmap the step program for a given index k.
        return lambda k: jax.vmap(partial(step_fn, network, k=k))

      targets = over_particles(anneal_target)
      forwards = over_particles(anneal_forward)
      reverses = over_particles(anneal_reverse)

      if unroll:  # to unroll the algorithm, we provide a list of programs
        targets = [targets(k) for k in range(num_targets)]
        forwards = [forwards(k) for k in range(num_kernels)]
        reverses = [reverses(k) for k in range(num_kernels)]

      return coix.algo.nvi_rkl(
          targets, forwards, reverses, num_targets=num_targets
      )


    def loss_fn(params, key, num_particles, unroll=False):
      """Evaluate the annealing program and return its loss.

      Args:
        params: variables of ``AnnealNetwork``.
        key: PRNG key; it is split into one key per particle.
        num_particles: number of particles (size of the vmapped dimension).
        unroll: forwarded to ``make_anneal``.

      Returns:
        ``(metrics["loss"], metrics)`` from ``coix.traced_evaluate``.
      """
      # One PRNG key per particle is the only input the program needs.
      particle_keys = random.split(key, num_particles)

      evaluate = coix.traced_evaluate(make_anneal(params, unroll=unroll))
      _, _, metrics = evaluate(particle_keys)
      return metrics["loss"], metrics


    def main(args):
      """Train the annealing network, print summary metrics and plot samples.

      Args:
        args: parsed command-line namespace providing ``learning_rate``,
          ``num_steps``, ``num_particles`` and ``unroll_loop``.
      """
      unroll = args.unroll_loop

      # Initialise the parameters with a dummy 2-D input.
      init_params = AnnealNetwork().init(random.PRNGKey(0), jnp.zeros(2))

      objective = partial(
          loss_fn, num_particles=args.num_particles, unroll=unroll
      )
      anneal_params, _ = coix.util.train(
          objective,
          init_params,
          optax.adam(args.learning_rate),
          args.num_steps,
          jit_compile=True,
      )

      # Evaluate the trained program on 100 batches of 1000 particles each; the
      # outer batch axis is handled by the extra vmap.
      eval_keys = random.split(random.PRNGKey(1), 100000).reshape((100, 1000, 2))
      batched_program = jax.vmap(make_anneal(anneal_params, unroll=unroll))
      _, trace, metrics = coix.traced_evaluate(batched_program)(eval_keys)

      # "log_weight" is dropped before the remaining metrics are averaged and
      # rounded for printing.
      metrics.pop("log_weight")
      anneal_metrics = jax.tree.map(
          lambda value: round(float(jnp.mean(value)), 4), metrics
      )
      print(anneal_metrics)

      # 2-D histogram of all final samples.
      plt.figure(figsize=(8, 8))
      samples = trace["x"]["value"].reshape((-1, 2))
      counts, _, _ = np.histogram2d(samples[:, 0], samples[:, 1], bins=100)
      plt.imshow(counts.T)
      plt.show()


    if __name__ == "__main__":
      # Command-line entry point: parse options, select the device and the coix
      # backend, then run the example.
      parser = argparse.ArgumentParser(description="Annealing example")
      # Offer the hyphenated spelling used by every other flag while keeping the
      # original underscore spelling as an alias. argparse derives `dest` from
      # the first long option, so the value is still `args.num_particles`.
      parser.add_argument(
          "--num-particles", "--num_particles", nargs="?", default=36, type=int
      )
      parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float)
      parser.add_argument("--num-steps", nargs="?", default=20000, type=int)
      parser.add_argument("--unroll-loop", action="store_true")
      parser.add_argument(
          "--device", default="cpu", type=str, help='use "cpu" or "gpu".'
      )
      args = parser.parse_args()

      numpyro.set_platform(args.device)
      coix.set_backend("coix.oryx")

      main(args)


.. _sphx_glr_download_examples_anneal_oryx.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: anneal_oryx.ipynb <anneal_oryx.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: anneal_oryx.py <anneal_oryx.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: anneal_oryx.zip <anneal_oryx.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
