MultiHMCGibbs package
A NumPyro Gibbs sampler that uses conditioned HMC kernels for each Gibbs step.
- class MultiHMCGibbs.multihmcgibbs.MultiHMCGibbs(inner_kernels, gibbs_sites_list)
Bases: MCMCKernel
Multi-HMC-within-Gibbs. This interface allows the user to combine multiple general-purpose gradient-based inference kernels (HMC or NUTS), each conditioned on a different subset of sample sites, as steps in a Gibbs sampler.
Note that it is the user’s responsibility to ensure that every sample site is included in the gibbs_sites_list parameter and that each of the inner_kernels uses the same NumPyro model function.
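The requirement on gibbs_sites_list can be checked by hand before building the kernel. The following is a minimal sketch in plain Python; the helper name find_site_partition_errors and its arguments are hypothetical and are not part of this package (in particular, it is not the check_gibbs_sites method listed below). It assumes you already know the names of all sample sites in your model.

```python
from collections import Counter


def find_site_partition_errors(all_sites, gibbs_sites_list):
    """Hypothetical helper: report sites that are missing from, duplicated in,
    or unknown to a proposed grouping of sample sites."""
    # Count how many times each site name appears across all groups.
    counts = Counter(site for group in gibbs_sites_list for site in group)
    missing = sorted(set(all_sites) - set(counts))
    duplicated = sorted(site for site, n in counts.items() if n > 1)
    unknown = sorted(set(counts) - set(all_sites))
    return missing, duplicated, unknown


# A valid grouping: every site appears in exactly one group.
print(find_site_partition_errors(["x", "y"], [["y"], ["x"]]))
# ([], [], [])

# An invalid grouping: "z" is never listed and "y" is listed twice.
print(find_site_partition_errors(["x", "y", "z"], [["y"], ["x", "y"]]))
# (['z'], ['y'], [])
```

A grouping is acceptable under the rule above only when all three returned lists are empty.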
- Parameters:
inner_kernels (list of HMC/NUTS kernels) – One kernel for each of the lists in gibbs_sites_list. All kernels must use the same `model`, but any of the other parameters can be different (e.g. target_accept_prob).
gibbs_sites_list (list of lists of site names) – Each inner list names the sites that are free parameters for one Gibbs step; all other sample sites are fixed to their current values for that step. Each inner list is updated as a group, and the groups are updated in order. Every sample site in the model must be explicitly listed in exactly one of the groups.
**Example** –
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS
>>> from MultiHMCGibbs import MultiHMCGibbs
...
>>> def model():
...     x = numpyro.sample("x", dist.Normal(0.0, 2.0))
...     y = numpyro.sample("y", dist.Normal(0.0, 2.0))
...     numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))
...
>>> inner_kernels = [NUTS(model), NUTS(model)]
>>> outer_kernel = MultiHMCGibbs(inner_kernels, [['y'], ['x']])
>>> mcmc = MCMC(outer_kernel, num_warmup=100, num_samples=100, progress_bar=False)
>>> mcmc.run(random.PRNGKey(0))
>>> mcmc.print_summary()
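To illustrate the update order described for gibbs_sites_list (each group is updated in turn while all other sites stay fixed), here is a rough, dependency-free sketch of Gibbs sweeps for the same two-site model. It is not how this package works internally: the HMC/NUTS step is swapped for an exact draw from each conditional, which is only possible because this toy model is Gaussian. The function name gibbs_sweeps is hypothetical, and the sketch assumes every group contains a single site.

```python
import random


def gibbs_sweeps(groups, num_sweeps, seed=0):
    """Hypothetical sketch: sweep over single-site groups in order, drawing
    each site from its exact conditional (a stand-in for an HMC/NUTS step)."""
    rng = random.Random(seed)
    z = {"x": 0.0, "y": 0.0}  # current values of all sites
    obs = 1.0
    draws = []
    for _ in range(num_sweeps):
        for group in groups:          # groups are updated in order
            for site in group:        # assumption: one site per group
                other = "y" if site == "x" else "x"
                # Prior Normal(0, 2) has variance 4; the likelihood
                # Normal(site + other, 1) has variance 1.
                precision = 1.0 / 4.0 + 1.0
                mean = (obs - z[other]) / precision
                z[site] = rng.gauss(mean, precision ** -0.5)
        draws.append(dict(z))
    return draws


draws = gibbs_sweeps([["y"], ["x"]], num_sweeps=5000)
mean_sum = sum(d["x"] + d["y"] for d in draws) / len(draws)
print(mean_sum)  # the analytic posterior mean of x + y is 8/9
```

In the real sampler each inner kernel replaces the exact conditional draw, so the same sweep structure applies to models where the conditionals are not available in closed form.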
- check_gibbs_sites(model_args, model_kwargs)
- property default_fields
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is called).
- get_diagnostics_str(state)
Given the current state, returns the diagnostics string to be added to the progress bar.
- init(rng_key, num_warmup, init_params, model_args, model_kwargs)
Initialize the MCMCKernel and return an initial state to begin sampling from.
- Parameters:
rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
num_warmup (int) – Number of warmup steps. This can be useful when doing adaptation during warmup.
init_params (tuple) – Initial parameters to begin sampling. The type must be consistent with the input type to potential_fn.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
The initial state representing the state of the kernel. This can be any class that is registered as a pytree.
- property model
- postprocess_fn(args, kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s support, in addition to returning deterministic sites in the model.
- Parameters:
model_args – Arguments to the model.
model_kwargs – Keyword arguments to the model.
- sample(state, model_args, model_kwargs)
Given the current state, return the next state using the given transition kernel.
- Parameters:
state – A pytree class representing the state for the kernel. For HMC, this is given by HMCState. In general, this could be any class that supports getattr.
model_args – Arguments provided to the model.
model_kwargs – Keyword arguments provided to the model.
- Returns:
Next state.
- sample_field = 'z'
- class MultiHMCGibbs.multihmcgibbs.MultiHMCGibbsState(z, hmc_states, diverging, rng_key)
Bases: tuple
z - a dict of the current latent values (all sites)
hmc_states - list of the current HMCState objects (one per Gibbs step)
diverging - a list of boolean values indicating whether the current trajectory is diverging
rng_key - random number generator seed used for the iteration
- diverging
Alias for field number 2
- hmc_states
Alias for field number 1
- rng_key
Alias for field number 3
- z
Alias for field number 0
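Because the state is a tuple with named fields, each field can be read either by name or by the field number given above. The sketch below uses a stand-in namedtuple with the same field order rather than the package's own class; the name StateSketch and the placeholder values are assumptions made purely for illustration.

```python
from collections import namedtuple

# Stand-in with the same field order as MultiHMCGibbsState (not the real class).
StateSketch = namedtuple("StateSketch", ["z", "hmc_states", "diverging", "rng_key"])

state = StateSketch(
    z={"x": 0.1, "y": -0.3},     # placeholder latent values
    hmc_states=[None, None],     # placeholders for one HMCState per Gibbs step
    diverging=[False, False],
    rng_key=0,                   # placeholder for a PRNG key
)

# Access by name and by field number refer to the same objects.
assert state.z is state[0]
assert state.hmc_states is state[1]
assert state.diverging is state[2]
assert state.rng_key == state[3]

# Tuple unpacking follows the same order.
z, hmc_states, diverging, rng_key = state
print(sorted(z))
```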