MultiHMCGibbs package

A NumPyro Gibbs sampler that uses conditioned HMC kernels for each Gibbs step.

class MultiHMCGibbs.multihmcgibbs.MultiHMCGibbs(inner_kernels, gibbs_sites_list)

Bases: MCMCKernel

Multi-HMC-within-Gibbs. This interface lets the user combine multiple general-purpose gradient-based inference kernels (HMC or NUTS), each conditioned on a different subset of sample sites, as steps in a Gibbs sampler.

Note that it is the user’s responsibility to ensure that every sample site is included in the gibbs_sites_list parameter and that each of the inner_kernels uses the same NumPyro model function.

Parameters:
  • inner_kernels – List of HMC/NUTS kernels, one for each of the lists in gibbs_sites_list. All kernels must use the same `model`, but any of the other parameters can be different (e.g. target_accept_prob).

  • gibbs_sites_list – List of lists of site names that are free parameters for each Gibbs step; all other sample sites are fixed to their current values for the step. Each inner list is updated as a group, and the groups are updated in order. Every sample site in the model must be explicitly listed in exactly one of the groups.

Example:

    >>> from jax import random
    >>> import jax.numpy as jnp
    >>> import numpyro
    >>> import numpyro.distributions as dist
    >>> from numpyro.infer import MCMC, NUTS
    >>> from MultiHMCGibbs import MultiHMCGibbs
    ...
    >>> def model():
    ...     x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    ...     y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    ...     numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))
    ...
    >>> inner_kernels = [NUTS(model), NUTS(model)]
    >>> outer_kernel = MultiHMCGibbs(inner_kernels, [['y'], ['x']])
    >>> mcmc = MCMC(outer_kernel, num_warmup=100, num_samples=100, progress_bar=False)
    >>> mcmc.run(random.PRNGKey(0))
    >>> mcmc.print_summary()
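
The requirement that every sample site appears in exactly one group can be checked before building the kernel. The following is a minimal sketch, not the package's check_gibbs_sites implementation; the helper name `check_partition` and the explicit `all_sites` argument are assumptions made for illustration.

```python
from collections import Counter


def check_partition(all_sites, gibbs_sites_list):
    """Raise ValueError unless gibbs_sites_list splits all_sites into disjoint groups.

    Hypothetical helper: it is not part of MultiHMCGibbs.
    """
    # Count how many times each site name appears across all groups.
    counts = Counter(site for group in gibbs_sites_list for site in group)
    missing = sorted(set(all_sites) - set(counts))
    unknown = sorted(set(counts) - set(all_sites))
    repeated = sorted(site for site, n in counts.items() if n > 1)
    if missing or unknown or repeated:
        raise ValueError(
            f"missing={missing}, unknown={unknown}, repeated={repeated}"
        )


# The grouping from the example above covers both latent sites exactly once.
check_partition(["x", "y"], [["y"], ["x"]])
```

A grouping such as `[["y"]]` would raise, because the site "x" is never assigned to a Gibbs step.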
    

check_gibbs_sites(model_args, model_kwargs)

property default_fields

The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).

get_diagnostics_str(state)

Given the current state, returns the diagnostics string to be added to the progress bar for diagnostic purposes.

init(rng_key, num_warmup, init_params, model_args, model_kwargs)

Initialize the MCMCKernel and return an initial state to begin sampling from.

Parameters:
  • rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.

  • num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.

  • init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.

  • model_args – Arguments provided to the model.

  • model_kwargs – Keyword arguments provided to the model.

Returns:

The initial state representing the state of the kernel. This can be any class that is registered as a pytree.

property model

postprocess_fn(args, kwargs)

Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.

Parameters:
  • model_args – Arguments to the model.

  • model_kwargs – Keyword arguments to the model.
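
As a rough illustration of what such a function does, the sketch below maps an unconstrained value onto a positive-support site with an exponential transform. The site name `sigma`, the helper `make_postprocess_fn`, and the choice of transform are assumptions for illustration; in practice the transform is determined by each site's support, and deterministic sites are omitted here.

```python
import math


def make_postprocess_fn(transforms):
    """Return a function mapping unconstrained site values to constrained ones.

    Hypothetical helper: `transforms` maps a site name to a callable; sites
    without an entry are passed through unchanged.
    """
    def postprocess(unconstrained):
        return {
            name: transforms.get(name, lambda v: v)(value)
            for name, value in unconstrained.items()
        }
    return postprocess


# "sigma" is assumed to have positive support, so exp maps the real line onto it.
postprocess = make_postprocess_fn({"sigma": math.exp})
constrained = postprocess({"x": -1.5, "sigma": 0.0})
```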

sample(state, model_args, model_kwargs)

Given the current state, return the next state using the given transition kernel.

Parameters:
  • state – A pytree class representing the state for the kernel. For HMC, this is given by HMCState. In general, this could be any class that supports getattr.

  • model_args – Arguments provided to the model.

  • model_kwargs – Keyword arguments provided to the model.

Returns:

Next state.
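
The update order described for gibbs_sites_list (each group updated in turn, with all other sites held fixed) can be illustrated with a small control-flow sketch. Plain Python callables stand in for the conditioned HMC/NUTS steps; `gibbs_sweep` and the toy step functions are hypothetical and do not reflect the package's internals.

```python
def gibbs_sweep(z, gibbs_sites_list, step_fns):
    """Update each group of sites in order, holding the other sites fixed.

    Hypothetical illustration: step_fns[i](free, fixed) stands in for the
    i-th conditioned HMC/NUTS step and returns new values for `free`.
    """
    z = dict(z)  # copy so the caller's state is not mutated
    for sites, step in zip(gibbs_sites_list, step_fns):
        free = {name: z[name] for name in sites}
        fixed = {name: value for name, value in z.items() if name not in sites}
        z.update(step(free, fixed))
    return z


# Toy deterministic "steps": each moves its site halfway toward 1 - (other site).
def step_y(free, fixed):
    return {"y": 0.5 * (free["y"] + 1.0 - fixed["x"])}


def step_x(free, fixed):
    return {"x": 0.5 * (free["x"] + 1.0 - fixed["y"])}


start = {"x": 0.0, "y": 0.0}
state = gibbs_sweep(start, [["y"], ["x"]], [step_y, step_x])
```

Because the groups are processed in order, the step for "x" already sees the value of "y" produced earlier in the same sweep.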

sample_field = 'z'

class MultiHMCGibbs.multihmcgibbs.MultiHMCGibbsState(z, hmc_states, diverging, rng_key)

Bases: tuple

  • z - a dict of the current latent values (all sites)

  • hmc_states - a list of the current HMCState objects (one per Gibbs step)

  • diverging - a list of boolean values indicating whether the current trajectory is diverging.

  • rng_key - random number generator seed used for the iteration.

diverging

Alias for field number 2

hmc_states

Alias for field number 1

rng_key

Alias for field number 3

z

Alias for field number 0
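
Because the state is a tuple subclass with named fields, each field can be read by name or by position. The sketch below uses a stand-in namedtuple with the documented field order rather than importing the real class; the placeholder values are assumptions chosen only for illustration.

```python
from collections import namedtuple

# Stand-in with the same field order as documented above (not the real class).
State = namedtuple("State", ["z", "hmc_states", "diverging", "rng_key"])

state = State(
    z={"x": 0.1, "y": -0.3},   # current latent values for all sites
    hmc_states=[None, None],   # placeholders for one HMCState per Gibbs step
    diverging=[False, False],  # placeholder divergence flags
    rng_key=None,              # placeholder for a PRNG key
)

# Field access by name and by position refers to the same objects.
assert state.z is state[0]
assert state.diverging is state[2]

# namedtuples are immutable; _replace returns an updated copy.
next_state = state._replace(z={"x": 0.2, "y": -0.1})
```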