BlackJAX Integration Example¤
Overview¤
This example demonstrates how to use BlackJAX samplers with Workshop's distribution framework, comparing different MCMC algorithms on both multimodal distributions and Bayesian regression tasks.
Files¤
- Python script: `examples/generative_models/sampling/blackjax_example.py`
- Jupyter notebook: `examples/generative_models/sampling/blackjax_example.ipynb`
Quick Start¤
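Run the Python script with `python examples/generative_models/sampling/blackjax_example.py`, or open the notebook in Jupyter to work through the same material interactively.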
Learning Objectives¤
After completing this example, you will:
- Understand how to use BlackJAX samplers (HMC, NUTS, MALA) with Workshop
- Learn to sample from multimodal distributions using different MCMC methods
- Implement Bayesian regression using NUTS sampling
- Compare different sampling algorithms for the same problem
- Visualize and interpret MCMC sampling results
Prerequisites¤
- Understanding of MCMC sampling concepts
- Basic knowledge of Bayesian inference
- Familiarity with probability distributions
- Workshop core sampling module
What is BlackJAX?¤
BlackJAX is a library of samplers for JAX that provides state-of-the-art MCMC algorithms. Workshop integrates BlackJAX to offer advanced sampling capabilities.
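For context, here is a minimal sketch of driving BlackJAX's NUTS kernel directly on a toy standard-normal target, using BlackJAX's public `nuts` API (as of recent BlackJAX releases); Workshop's wrappers, shown in the examples below, hide this boilerplate:

```python
import jax
import jax.numpy as jnp
import blackjax

# Toy target: standard normal in 2D.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)

# Build a NUTS kernel and initialize its state.
nuts = blackjax.nuts(logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))

# Advance the chain one transition at a time.
keys = jax.random.split(jax.random.PRNGKey(0), 100)
for key in keys:
    state, info = nuts.step(key, state)
```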
Supported Algorithms¤
| Algorithm | Description | Best For |
|---|---|---|
| HMC | Hamiltonian Monte Carlo | Smooth, continuous distributions |
| NUTS | No-U-Turn Sampler | Complex posteriors, automatic tuning |
| MALA | Metropolis-Adjusted Langevin | Gradient-based sampling |
Theory¤
Hamiltonian Monte Carlo (HMC)¤
HMC uses Hamiltonian dynamics to propose efficient moves in parameter space, simulating a particle with position \(q\) and momentum \(p\) under the Hamiltonian

\[H(q, p) = U(q) + K(p),\]
where \(U(q) = -\log p(q)\) is the potential energy and \(K(p) = \frac{1}{2}p^T M^{-1} p\) is the kinetic energy.
The algorithm simulates Hamiltonian dynamics using the leapfrog integrator with step size \(\epsilon\):

\[
p_{t+\epsilon/2} = p_t - \frac{\epsilon}{2}\nabla U(q_t), \qquad
q_{t+\epsilon} = q_t + \epsilon\, M^{-1} p_{t+\epsilon/2}, \qquad
p_{t+\epsilon} = p_{t+\epsilon/2} - \frac{\epsilon}{2}\nabla U(q_{t+\epsilon}).
\]
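To make the integrator concrete, here is a minimal leapfrog sketch in JAX, assuming an identity mass matrix \(M = I\); the function is illustrative, not Workshop's API:

```python
import jax
import jax.numpy as jnp

def leapfrog(q, p, log_prob_fn, step_size, num_steps):
    """Minimal leapfrog integrator for H(q, p) = U(q) + |p|^2 / 2, with M = I."""
    grad_U = jax.grad(lambda q: -log_prob_fn(q))  # U(q) = -log p(q)
    p = p - 0.5 * step_size * grad_U(q)           # initial half-step for momentum
    for _ in range(num_steps - 1):
        q = q + step_size * p                     # full step for position
        p = p - step_size * grad_U(q)             # full step for momentum
    q = q + step_size * p                         # final position step
    p = p - 0.5 * step_size * grad_U(q)           # final half-step for momentum
    return q, p
```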
No-U-Turn Sampler (NUTS)¤
NUTS automatically tunes the HMC trajectory length by building a tree of states until the trajectory makes a "U-turn". This eliminates the need to manually set the number of integration steps.
MALA (Metropolis-Adjusted Langevin Algorithm)¤
MALA uses Langevin dynamics for proposals:

\[x' = x + \frac{\epsilon^2}{2}\nabla \log p(x) + \epsilon\,\eta,\]

where \(\eta \sim \mathcal{N}(0, I)\) is Gaussian noise. The proposal is then accepted or rejected with a Metropolis correction, which makes the chain exact for the target \(p\).
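As a concrete sketch, the proposal step in JAX looks like this (the helper name is hypothetical, and the Metropolis accept/reject step is omitted):

```python
import jax

def mala_propose(key, x, log_prob_fn, step_size):
    """One Langevin proposal (hypothetical helper; accept/reject step omitted)."""
    grad = jax.grad(log_prob_fn)(x)          # gradient of log p at x
    noise = jax.random.normal(key, x.shape)  # eta ~ N(0, I)
    return x + 0.5 * step_size**2 * grad + step_size * noise
```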
Code Walkthrough¤
Example 1: Multimodal Distribution Sampling¤
The first example compares four sampling methods on a bimodal Gaussian mixture:
```python
import jax
import jax.numpy as jnp

# Sampling helpers come from Workshop's core sampling module
# (see Code References below).
from workshop.generative_models.core.sampling import hmc_sampling, mcmc_sampling

key = jax.random.PRNGKey(0)
init_state = jnp.array(0.0)

# Define a bimodal log probability function: an equal-weight Gaussian
# mixture with modes at x = -2 and x = +2 (variance 0.5 each).
def log_prob_fn(x):
    log_prob1 = -0.5 * ((x - 2.0) ** 2) / 0.5
    log_prob2 = -0.5 * ((x + 2.0) ** 2) / 0.5
    return jnp.logaddexp(log_prob1, log_prob2)

# Sample using Metropolis-Hastings
mh_samples = mcmc_sampling(
    log_prob_fn=log_prob_fn,
    init_state=init_state,
    key=key,
    n_samples=2000,
    n_burnin=500,
    step_size=0.5,
)

# Sample using HMC (BlackJAX)
hmc_samples = hmc_sampling(
    log_prob_fn=log_prob_fn,
    init_state=init_state,
    key=key,
    n_samples=2000,
    n_burnin=500,
    step_size=0.1,
    num_integration_steps=10,
)
```
Key Points:
- The bimodal distribution has two modes at \(x = -2\) and \(x = 2\)
- HMC uses gradient information for more efficient exploration
- NUTS automatically tunes trajectory length
- MALA balances speed and efficiency
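A quick sanity check on mode coverage is the fraction of samples on each side of zero, since the modes sit at \(\pm 2\); a sketch using the samples above:

```python
# Fraction of samples in each mode; the sign of a sample indicates
# which mode it belongs to.
frac_right = float(jnp.mean(hmc_samples > 0))
print(f"right mode: {frac_right:.2f}, left mode: {1 - frac_right:.2f}")
# A well-mixed chain should report roughly 0.50 / 0.50.
```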
Example 2: Bayesian Linear Regression¤
The second example demonstrates Bayesian parameter estimation:
```python
import jax
import jax.numpy as jnp

from workshop.generative_models.core.sampling import nuts_sampling

# X is the (n, d) design matrix and y the (n,) response vector, generated
# earlier in the script from known true parameters. init_state is a dict
# matching the structure of params below.

# Define the Bayesian regression model
def log_prob_fn(params):
    beta = params["beta"]
    log_sigma = params["log_sigma"]
    sigma = jnp.exp(log_sigma)
    # Priors: standard normal on coefficients, normal on the log noise scale
    prior_beta = jnp.sum(jax.scipy.stats.norm.logpdf(beta, loc=0.0, scale=1.0))
    prior_sigma = jax.scipy.stats.norm.logpdf(log_sigma, loc=-2.0, scale=1.0)
    # Likelihood: Gaussian observation noise
    y_pred = X @ beta
    likelihood = jnp.sum(jax.scipy.stats.norm.logpdf(y, loc=y_pred, scale=sigma))
    return prior_beta + prior_sigma + likelihood

# Sample using NUTS
nuts_samples = nuts_sampling(
    log_prob_fn=log_prob_fn,
    init_state=init_state,
    key=key,
    n_samples=2000,
    n_burnin=1000,
)
```
Key Points:
- NUTS is ideal for Bayesian inference with multiple parameters
- The model includes priors on coefficients and noise scale
- Posterior distributions recover true parameter values
- No manual tuning of trajectory length needed
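As a sketch of how the posterior might be summarized afterwards (assuming `nuts_samples` mirrors the `params` dict with a leading sample axis, which may differ from the actual return structure):

```python
# Posterior summaries per parameter (structure of nuts_samples is assumed).
beta_mean = nuts_samples["beta"].mean(axis=0)
beta_std = nuts_samples["beta"].std(axis=0)
sigma_draws = jnp.exp(nuts_samples["log_sigma"])  # back-transform the noise scale
print("beta posterior mean:", beta_mean, "+/-", beta_std)
print("sigma posterior mean:", float(sigma_draws.mean()))
```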
Expected Output¤
Multimodal Distribution¤
The example generates comparison plots showing samples from each method:
- Metropolis-Hastings: Baseline performance
- HMC: Efficient exploration with gradient information
- NUTS: Similar to HMC but with automatic tuning
- MALA: Fast per-iteration sampling
You should see all methods successfully sample from both modes of the distribution.
Bayesian Regression¤
The regression example produces:
- Coefficient posteriors: Distributions for each \(\beta_i\) parameter
- Noise posterior: Distribution for noise scale \(\sigma\)
- Comparison with truth: True values marked on plots
Expected results:
- Posterior means close to true parameter values
- Reasonable posterior uncertainty
- Successful convergence after burn-in
Performance Considerations¤
Computational Cost¤
| Method | Time per Sample | Effective Sample Size (ESS) | Overall Efficiency |
|---|---|---|---|
| MH | Low | Low | Baseline |
| HMC | Medium | High | Good |
| NUTS | High | Very High | Excellent |
| MALA | Low-Medium | Medium | Good |
Memory Usage¤
- HMC/MALA: Moderate memory usage
- NUTS: Higher memory usage due to trajectory storage
- For memory-constrained systems, reduce `max_num_doublings` in NUTS
Tuning Guidelines¤
HMC:

- `step_size`: Start with 0.1, adjust based on acceptance rate (target: 0.6-0.8)
- `num_integration_steps`: Start with 10, increase for complex distributions

MALA:

- `step_size`: Start with 0.05, adjust based on acceptance rate (target: 0.5-0.7)

NUTS:

- `step_size`: Usually auto-tuned, but can be set manually
- `max_num_doublings`: Controls trajectory length and memory usage (default: 10)

A simple acceptance-rate heuristic for adjusting `step_size` is sketched below.
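As a rough illustration of acceptance-rate tuning (a hand-rolled heuristic, not a Workshop or BlackJAX API; real samplers use dual averaging or window adaptation):

```python
def adapt_step_size(step_size, accept_rate, target=0.7, factor=1.05):
    """Crude multiplicative step-size adaptation toward a target acceptance rate.

    Illustrative only; production samplers adapt step size more carefully.
    """
    if accept_rate > target:
        return step_size * factor  # accepting too often: take bigger steps
    return step_size / factor      # rejecting too often: take smaller steps
```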
Experiments to Try¤
- Change the distribution: Modify the bimodal distribution to have three modes or varying widths
- Tune hyperparameters: Experiment with different step sizes and integration steps, observing the effects on acceptance rate and mixing
- Compare convergence: Plot traces and autocorrelation to assess convergence and mixing for each method
- Higher dimensions: Extend the Bayesian regression to 20-50 features to see how the samplers scale
- Different priors: Try informative vs. uninformative priors on the regression coefficients
- Visualize traces: Add MCMC trace plots to check convergence (see the sketch below)
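For the trace-plot experiment, a minimal matplotlib sketch using the 1-D sample arrays from Example 1:

```python
import matplotlib.pyplot as plt

# Trace plots: sample value vs. iteration; good mixing looks like
# rapid, repeated switching between the two modes.
fig, axes = plt.subplots(2, 1, figsize=(8, 5), sharex=True)
axes[0].plot(mh_samples)
axes[0].set_ylabel("MH")
axes[1].plot(hmc_samples)
axes[1].set_ylabel("HMC")
axes[1].set_xlabel("iteration")
plt.tight_layout()
plt.show()
```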
Troubleshooting¤
Low Acceptance Rate¤
Symptom: Acceptance rate below 0.5 for HMC/MALA
Solution:

- Decrease `step_size` for HMC/MALA
- Check the gradient computation (it should not produce NaNs)
- Verify that the log probability function is correct
Poor Mixing¤
Symptom: Samples stay in one mode of a multimodal distribution
Solution:
- Increase burn-in period
- Try different initialization points
- Consider tempering or running multiple parallel chains (sketched below)
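For parallel chains, a simple sketch that reuses `hmc_sampling` from Example 1 with over-dispersed starting points (a plain Python loop, to avoid assuming anything about the wrapper's internals):

```python
import jax
import jax.numpy as jnp

# Run 4 chains from over-dispersed starting points.
keys = jax.random.split(jax.random.PRNGKey(42), 4)
inits = jnp.array([-4.0, -1.0, 1.0, 4.0])

chains = [
    hmc_sampling(
        log_prob_fn=log_prob_fn,
        init_state=init,
        key=k,
        n_samples=2000,
        n_burnin=500,
        step_size=0.1,
        num_integration_steps=10,
    )
    for init, k in zip(inits, keys)
]
# Agreement across chains (similar means and histograms) suggests good mixing.
```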
NUTS Memory Errors¤
Symptom: Out of memory errors with NUTS
Solution:
```python
# Reduce max_num_doublings to cap trajectory length (and memory)
nuts_samples = nuts_sampling(
    log_prob_fn=log_prob_fn,
    init_state=init_state,
    key=key,
    n_samples=1000,        # reduce number of samples
    n_burnin=500,
    max_num_doublings=5,   # lower from the default of 10
)
```
Divergent Transitions (NUTS)¤
Symptom: Warning about divergent transitions
Solution:

- Decrease `step_size`
- Reparameterize the model (e.g., use a non-centered parameterization, sketched below)
- Check for prior-likelihood conflicts
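To illustrate the reparameterization point: hierarchical coefficients sampled in "centered" form (\(\beta \sim \mathcal{N}(\mu, \tau)\)) can be rewritten in "non-centered" form so the sampler explores a better-conditioned space. A generic sketch, not tied to the regression example above:

```python
import jax.numpy as jnp
import jax.scipy.stats as stats

# Non-centered parameterization: sample z ~ N(0, 1) and set beta = mu + tau * z,
# so the geometry seen by the sampler no longer depends on tau.
def log_prob_fn(params):
    mu, log_tau, z = params["mu"], params["log_tau"], params["z"]  # z: shape (d,)
    tau = jnp.exp(log_tau)
    beta = mu + tau * z                           # implicitly beta ~ N(mu, tau)
    lp = jnp.sum(stats.norm.logpdf(z))            # standard normal on z
    lp += stats.norm.logpdf(mu, 0.0, 1.0)         # hyperpriors
    lp += stats.norm.logpdf(log_tau, -2.0, 1.0)
    y_pred = X @ beta                             # X, y as in Example 2
    lp += jnp.sum(stats.norm.logpdf(y, y_pred, 1.0))
    return lp
```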
Next Steps¤
Related Examples¤
- BlackJAX Sampling Examples: Explore more sampling algorithms and advanced usage patterns
- BlackJAX Integration Examples: Learn advanced integration with Workshop distributions
Further Learning¤
- BlackJAX Documentation
- MCMC Diagnostics
- HMC Tutorial
- Workshop Sampling Module Documentation
Additional Resources¤
Papers¤
- Neal, R. M. (2011). "MCMC using Hamiltonian dynamics"
- Hoffman, M. D., & Gelman, A. (2014). "The No-U-Turn Sampler"
- Roberts, G. O., & Tweedie, R. L. (1996). "Exponential convergence of Langevin distributions and their discrete approximations"
Code References¤
- Distribution creation: `workshop.generative_models.core.distributions`
- Sampling functions: `workshop.generative_models.core.sampling`
- BlackJAX integration: `workshop.generative_models.core.sampling.blackjax_samplers`
Support¤
If you encounter issues:
- Check that BlackJAX is installed: `pip install blackjax`
- Verify that your JAX GPU/CPU setup is correct
- Review error messages for parameter constraints
- Consult the Workshop documentation or open an issue
Tags: #mcmc #blackjax #hmc #nuts #mala #bayesian #sampling
Difficulty: Intermediate
Estimated Time: 15-20 minutes