Neal’s funnel

Now let's take a distribution where a Gibbs sampler is needed to get a decent result: Neal's funnel. In general, a Gibbs sampler is useful for hierarchical models, with each level of the hierarchy treated as a separate step.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import arviz
import matplotlib.pyplot as plt

from numpyro.infer import MCMC, NUTS
from jax import random

from MultiHMCGibbs import MultiHMCGibbs

The model

We will set up a 5-dimensional funnel, something that regular HMC struggles with. For this model we know that the true marginal distribution of the y variable is Normal(0, 3); we will use this to see how well each sampler does.

def model(dim=5):
    # y controls the scale of the x variables: sd(x_i | y) = exp(y / 2)
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))


# Grid and density of the true y marginal, Normal(0, 3), used in the plots below.
x_marginal_true = jnp.linspace(-10, 10, 1000)
y_marginal_true = jnp.exp(dist.Normal(0, 3).log_prob(x_marginal_true))


def run_inference(kernel, chain_method, rng_key):
    # Shared helper: the same warmup/sample budget and 4 chains for every kernel.
    mcmc = MCMC(
        kernel,
        num_warmup=3000,
        num_samples=5000,
        num_chains=4,
        chain_method=chain_method,
        progress_bar=False
    )
    mcmc.run(rng_key)
    return mcmc

NUTS

We will use a large target_accept_prob to get rid of most of the divergent transitions, and a large number of warmup and sampling iterations to get the r_hat values close to 1.

rng_key = random.PRNGKey(0)
hmc_key, gibbs_key = random.split(rng_key)
funnel_mcmc_hmc = run_inference(NUTS(model, target_accept_prob=0.995), 'vectorized', hmc_key)
inf_funnel_hmc = arviz.from_numpyro(funnel_mcmc_hmc)
print(f'divergences per chain: {inf_funnel_hmc.sample_stats.diverging.values.sum(axis=1)}')
display(arviz.summary(inf_funnel_hmc))
divergences per chain: [ 72  24   7 185]
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
x[0]   0.004  7.151  -8.680    8.001      0.064    0.102   11381.0    1811.0   1.01
x[1]  -0.061  8.437  -8.011    9.004      0.076    0.107   14583.0    1785.0   1.01
x[2]   0.051  7.553  -7.819    9.547      0.075    0.104   12495.0    1753.0   1.02
x[3]  -0.047  7.840  -8.481    8.486      0.083    0.093   13393.0    2163.0   1.01
y      0.243  2.791  -4.883    4.681      0.256    0.182      95.0      25.0   1.03

Given the large number of divergent samples in chain 4, let's remove that chain from the plots.

x_model_hmc = inf_funnel_hmc.isel(chain=[0, 1, 2]).posterior.x[..., 0].data.flatten()
y_model_hmc = inf_funnel_hmc.isel(chain=[0, 1, 2]).posterior.y.data.flatten()

plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(x_model_hmc, y_model_hmc, '.')
plt.xlabel('x[0]')
plt.ylabel('y')
plt.xlim(-100, 100)
plt.subplot(122)
plt.hist(y_model_hmc, bins=30, histtype='step', density=True, label='HMC')
plt.plot(x_marginal_true, y_marginal_true, color='k', label='True marginal')
plt.xlabel('y')
plt.legend();
[Figure: left, scatter of y against x[0] for the NUTS draws; right, histogram of the NUTS y draws against the true Normal(0, 3) marginal.]

We can see that NUTS is struggling with this model: the y marginal is still missing mass at negative values, i.e. at the narrow bottom of the funnel (note also the very low effective sample size for y in the summary above).
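
To put a rough number on the missing tail, we can compare the empirical fraction of y draws below a threshold against the true probability under Normal(0, 3). This is a minimal sketch, assuming the y_model_hmc array from above is in scope; the threshold of -5 is an arbitrary choice for illustration.

# Fraction of y draws below an (arbitrary) threshold vs. the true
# Normal(0, 3) tail probability; an underestimate signals a missed tail.
threshold = -5.0
true_tail = dist.Normal(0, 3).cdf(threshold)
hmc_tail = (y_model_hmc < threshold).mean()
print(f'P(y < {threshold}): true {float(true_tail):.4f}, NUTS {float(hmc_tail):.4f}')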

MultiHMCGibbs

For the MultiHMCGibbs sampler we will only put a large target_accept_prob on the kernel that steps the x values (as these are the difficult ones to draw) and keep the default for the y step. To keep it on the same footing as the previous run we will use the same number of warmup and sampling draws.

funnel_mcmc_gibbs = run_inference(
    MultiHMCGibbs(
        [NUTS(model, target_accept_prob=0.995), NUTS(model)],
        [['x'], ['y']]
    ),
    'vectorized',
    gibbs_key
)
inf_funnel_gibbs = arviz.from_numpyro(funnel_mcmc_gibbs)
print(f'divergences per chain per step:\n {inf_funnel_gibbs.sample_stats.diverging.values.sum(axis=1).T}')
display(arviz.summary(inf_funnel_gibbs))
divergences per chain per step:
 [[  44    0   14 1192]
 [   0    0    0    0]]
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
x[0]  -0.007  5.228  -6.449    6.068      0.085    0.091    6372.0    1006.0   1.01
x[1]   0.050  5.379  -6.195    6.524      0.077    0.074    8707.0    1184.0   1.01
x[2]  -0.135  6.341  -6.192    6.487      0.132    0.128    9453.0    1112.0   1.01
x[3]  -0.042  6.220  -5.783    7.172      0.094    0.101    9905.0    1209.0   1.01
y     -0.455  2.934  -5.943    4.847      0.267    0.189     123.0     495.0   1.02

Again chain 4 has a large number of divergences, so let’s remove it from the plots.

x_model_gibbs = inf_funnel_gibbs.isel(chain=[0, 1, 2]).posterior.x[..., 0].data.flatten()
y_model_gibbs = inf_funnel_gibbs.isel(chain=[0, 1, 2]).posterior.y.data.flatten()

plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(x_model_hmc, y_model_hmc, '.', label='HMC', zorder=2)
plt.plot(x_model_gibbs, y_model_gibbs, '.', label='Gibbs', zorder=1)
plt.xlabel('x[0]')
plt.ylabel('y')
plt.legend()
plt.xlim(-100, 100)
plt.subplot(122)
plt.hist(y_model_hmc, bins=30, histtype='step', label='HMC', density=True)
plt.hist(y_model_gibbs, bins=30, histtype='step', label='Gibbs', density=True)
plt.plot(x_marginal_true, y_marginal_true, color='k', label='True marginal')
plt.xlabel('y')
plt.legend();
[Figure: left, scatter of y against x[0] for the HMC and Gibbs draws; right, histograms of the y draws from both samplers against the true Normal(0, 3) marginal.]

We can see that with the same setup MultiHMCGibbs was able to reach deeper into the funnel and pull out the negative y values missed by NUTS.
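
To make "reach deeper" concrete, one can compare the lower quantiles of each sampler's y draws with the true Normal(0, 3) quantiles. A minimal sketch, assuming y_model_hmc and y_model_gibbs from above are in scope (the quantile levels are an arbitrary choice):

# Lower quantiles of the sampled y values vs. the true Normal(0, 3) quantiles.
qs = jnp.array([0.01, 0.05, 0.10])
true_q = dist.Normal(0, 3).icdf(qs)
hmc_q = jnp.quantile(jnp.asarray(y_model_hmc), qs)
gibbs_q = jnp.quantile(jnp.asarray(y_model_gibbs), qs)
for q, t, h, g in zip(qs, true_q, hmc_q, gibbs_q):
    print(f'q={float(q):.2f}  true {float(t):6.2f}  NUTS {float(h):6.2f}  Gibbs {float(g):6.2f}')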

Note: Even if I did not remove chain 4 from either run, the results would be much the same. Under this parameterization the model is hard for both samplers to draw from; MultiHMCGibbs just does a better job with the same setup.

Other notes

  • You can use as many inner_kernels as you want.

  • The order the kernels are stepped in is set by the order of the parameter list (in the example above x steps first, followed by y); see the sketch after this list.

  • The order matters! Typically you want to step the parameters closest to the likelihood first and the hyper-parameters second. But for some models this might not be so clear, so some experimentation could be needed.
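
As a hypothetical sketch of both points (the model and site names hier_model, mu, sigma, and beta are made up for illustration, not part of the example above), a three-level model could be split into three steps, with the kernels stepped in the order given by the parameter lists:

# Hypothetical three-step setup: beta (closest to the likelihood) steps first,
# then sigma, then the hyper-parameter mu.
def hier_model():
    mu = numpyro.sample('mu', dist.Normal(0, 1))                # hyper-parameter
    sigma = numpyro.sample('sigma', dist.HalfNormal(1))         # group scale
    numpyro.sample('beta', dist.Normal(mu, sigma).expand([3]))  # group-level effects

three_step_kernel = MultiHMCGibbs(
    [NUTS(hier_model), NUTS(hier_model), NUTS(hier_model)],
    [['beta'], ['sigma'], ['mu']],
)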