BlackJAX Sampling Examples¤
Overview¤
This example provides a comprehensive exploration of BlackJAX samplers integrated with Workshop's distribution framework. It compares four different approaches to MCMC sampling: Workshop's HMC wrapper, Workshop's MALA wrapper, Workshop's NUTS wrapper, and direct BlackJAX API usage.
Files¤
- Python script: `examples/generative_models/sampling/blackjax_sampling_examples.py`
- Jupyter notebook: `examples/generative_models/sampling/blackjax_sampling_examples.ipynb`
Learning Objectives¤
After completing this example, you will:
- Understand different MCMC sampling algorithms (HMC, MALA, NUTS)
- Learn to use Workshop's BlackJAX integration API
- Compare Workshop's sampler wrappers with direct BlackJAX usage
- Apply MCMC sampling to mixture distributions
- Visualize and interpret sampling results
- Handle memory constraints in NUTS sampling
Prerequisites¤
- Understanding of MCMC sampling fundamentals
- Familiarity with probability distributions
- Basic knowledge of Hamiltonian Monte Carlo
- Completion of BlackJAX Integration Example
- Workshop core sampling module
MCMC Algorithms Overview¤
This example demonstrates three state-of-the-art MCMC algorithms from the BlackJAX library.
Hamiltonian Monte Carlo (HMC)¤
HMC uses gradient information to propose efficient moves in parameter space by simulating Hamiltonian dynamics.
Key Characteristics:
- Uses gradient information for exploration
- Requires tuning of step size and integration steps
- Excellent for smooth, continuous distributions
- Higher computational cost per iteration than random-walk Metropolis-Hastings
Mathematical Formulation:
The Hamiltonian system is defined as:

$$ H(q, p) = U(q) + K(p) $$
where:
- \(U(q) = -\log p(q)\) is the potential energy
- \(K(p) = \frac{1}{2}p^T M^{-1} p\) is the kinetic energy
- \(M\) is the mass matrix
The leapfrog integrator updates positions and momenta with step size \(\epsilon\):

$$ p_{t+\epsilon/2} = p_t - \frac{\epsilon}{2} \nabla U(q_t), \qquad q_{t+\epsilon} = q_t + \epsilon\, M^{-1} p_{t+\epsilon/2}, \qquad p_{t+\epsilon} = p_{t+\epsilon/2} - \frac{\epsilon}{2} \nabla U(q_{t+\epsilon}) $$
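To make the update concrete, here is a minimal sketch of one leapfrog step in JAX, assuming an identity mass matrix; it mirrors the equations above and is not the BlackJAX integrator itself.

```python
import jax
import jax.numpy as jnp


def leapfrog_step(q, p, log_prob_fn, step_size):
    """One leapfrog step for H(q, p) with M = I (illustrative sketch)."""
    grad_u = lambda x: -jax.grad(log_prob_fn)(x)  # U(q) = -log p(q)
    p = p - 0.5 * step_size * grad_u(q)           # half-step momentum update
    q = q + step_size * p                         # full-step position update
    p = p - 0.5 * step_size * grad_u(q)           # half-step momentum update
    return q, p
```

Repeating this step `num_integration_steps` times and applying a Metropolis correction to the endpoint is what an HMC kernel does internally.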
Metropolis-Adjusted Langevin Algorithm (MALA)¤
MALA is a gradient-based Metropolis method that uses Langevin dynamics for proposals.
Key Characteristics:
- Single step per iteration (faster than HMC)
- Gradient-based proposals for efficiency
- Good for smooth posteriors with strong gradients
- Typically a lower acceptance rate than HMC
Mathematical Formulation:
The proposal distribution is:

$$ \theta' = \theta + \frac{\epsilon^2}{2} \nabla \log p(\theta) + \epsilon\, \eta $$

where \(\eta \sim \mathcal{N}(0, I)\) is Gaussian noise and \(\epsilon\) is the step size.
The acceptance probability follows the Metropolis-Hastings rule:

$$ \alpha = \min\left(1, \frac{p(\theta')\, q(\theta \mid \theta')}{p(\theta)\, q(\theta' \mid \theta)}\right) $$

where \(q(\cdot \mid \cdot)\) is the Langevin proposal density above.
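The following is a minimal, self-contained sketch of one MALA transition that mirrors the two formulas above (illustrative only; not the BlackJAX implementation):

```python
import jax
import jax.numpy as jnp


def mala_step(key, theta, log_prob_fn, step_size):
    """One Langevin proposal plus Metropolis-Hastings correction (sketch)."""
    grad_fn = jax.grad(log_prob_fn)
    key_prop, key_accept = jax.random.split(key)

    # Langevin proposal: theta' = theta + (eps^2 / 2) grad log p(theta) + eps * eta
    noise = jax.random.normal(key_prop, theta.shape)
    theta_new = theta + 0.5 * step_size**2 * grad_fn(theta) + step_size * noise

    # Log of the (asymmetric) Gaussian proposal density q(x_to | x_from)
    def log_q(x_to, x_from):
        mean = x_from + 0.5 * step_size**2 * grad_fn(x_from)
        return -jnp.sum((x_to - mean) ** 2) / (2.0 * step_size**2)

    log_alpha = (log_prob_fn(theta_new) + log_q(theta, theta_new)
                 - log_prob_fn(theta) - log_q(theta_new, theta))
    accept = jnp.log(jax.random.uniform(key_accept)) < log_alpha
    return jnp.where(accept, theta_new, theta)
```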
No-U-Turn Sampler (NUTS)¤
NUTS automatically tunes the HMC trajectory length by building a tree of states until the trajectory makes a "U-turn".
Key Characteristics:
- No manual tuning of integration steps needed
- Adaptive step size selection
- State-of-the-art for Bayesian inference
- Higher memory usage due to trajectory storage
- Excellent for complex, high-dimensional posteriors
Algorithm Overview:
NUTS builds a balanced binary tree of trajectory states by recursively doubling until:
- The trajectory makes a U-turn (forward/backward directions oppose)
- Maximum tree depth is reached (`max_num_doublings`)

The U-turn criterion is:

$$ (\theta^+ - \theta^-) \cdot p^- < 0 \quad \text{or} \quad (\theta^+ - \theta^-) \cdot p^+ < 0 $$
where \(\theta^+, p^+\) are the forward endpoint and \(\theta^-, p^-\) are the backward endpoint.
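As an illustration, the criterion over the two trajectory endpoints can be written as a small JAX function (a conceptual sketch; BlackJAX applies a generalized version of this check at every level of the tree):

```python
import jax.numpy as jnp


def is_u_turn(theta_minus, theta_plus, p_minus, p_plus):
    """True if the forward and backward endpoints start moving toward each other."""
    delta = theta_plus - theta_minus
    return (jnp.dot(delta, p_minus) < 0.0) | (jnp.dot(delta, p_plus) < 0.0)
```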
Code Walkthrough¤
Example 1: Workshop HMC Sampling¤
This example uses Workshop's HMC wrapper to sample from a bimodal mixture of Gaussians:
```python
# Create a 2D mixture of Gaussians
def create_mixture_logprob():
    mean1 = jnp.array([3.0, 3.0])
    mean2 = jnp.array([-3.0, -3.0])

    def log_prob_fn(x):
        dist1 = Normal(loc=mean1, scale=jnp.array([1.0, 1.0]))
        dist2 = Normal(loc=mean2, scale=jnp.array([1.0, 1.0]))
        log_prob1 = jnp.sum(dist1.log_prob(x))
        log_prob2 = jnp.sum(dist2.log_prob(x))
        # Equal mixture weights
        return jnp.logaddexp(log_prob1, log_prob2) - jnp.log(2.0)

    return log_prob_fn


# Sample using Workshop's HMC wrapper
mixture_logprob = create_mixture_logprob()
init_state = jnp.zeros(2)
hmc_samples = hmc_sampling(
    mixture_logprob,
    init_state,
    key,
    n_samples=1000,
    n_burnin=500,
    step_size=0.1,
    num_integration_steps=10,
)
```
Key Points:
- The mixture has two well-separated modes at [3, 3] and [-3, -3]
- HMC explores both modes efficiently using gradient information (a quick check is sketched below)
- Workshop wrapper handles initialization and sampling loop
- Returns an array of samples with shape `[n_samples, 2]`
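As the quick check referenced above (an illustrative sketch, not part of the example script), the samples can be classified by their nearest mode to confirm that both are visited:

```python
import jax.numpy as jnp

# Samples with a positive coordinate sum are closer to [3, 3]; the rest to [-3, -3]
near_mode1 = jnp.sum(jnp.sum(hmc_samples, axis=-1) > 0)
near_mode2 = hmc_samples.shape[0] - near_mode1
print(f"Samples near [3, 3]: {near_mode1}, near [-3, -3]: {near_mode2}")
```

A healthy run should place a substantial fraction of the samples near each mode.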
Example 2: Workshop MALA Sampling¤
This example demonstrates MALA on the same bimodal distribution:
```python
mala_samples = mala_sampling(
    mixture_logprob,
    init_state,
    key,
    n_samples=1000,
    n_burnin=500,
    step_size=0.05,  # Smaller step size than HMC
)
```
Key Points:
- MALA uses smaller step sizes than HMC (typically 0.05 vs 0.1)
- Single Langevin step per iteration makes it faster per sample
- May need more samples to achieve the same effective sample size as HMC
- Good for problems where gradient evaluation is cheap
Example 3: Workshop NUTS Sampling¤
NUTS automatically tunes trajectory length, eliminating manual tuning:
```python
# Use simpler distribution to avoid memory issues
simple_logprob = create_normal_logprob()
nuts_samples = nuts_sampling(
    simple_logprob,
    init_state,
    key,
    n_samples=500,  # Fewer samples due to memory
    n_burnin=200,
    step_size=0.8,
    max_num_doublings=5,  # Control memory usage
)
```
Key Points:
- NUTS is memory-intensive due to trajectory tree storage
- Use `max_num_doublings` to control memory usage (default: 10)
- Excellent for complex posteriors where tuning is difficult
- This example uses a simpler distribution to demonstrate the API
Example 4: Direct BlackJAX HMC¤
This example shows how to use BlackJAX's API directly without Workshop wrappers:
```python
import blackjax
import jax
import jax.numpy as jnp
from flax import nnx

# Initialize the HMC algorithm
inverse_mass_matrix = jnp.eye(2)
hmc = blackjax.hmc(
    mixture_logprob,
    step_size=0.1,
    inverse_mass_matrix=inverse_mass_matrix,
    num_integration_steps=10,
)

# Initialize sampling state
initial_state = hmc.init(init_state)


# Define one step function
@nnx.jit
def one_step(state, key):
    state, _ = hmc.step(key, state)
    return state, state


# Burn-in phase
state = initial_state
for _ in range(n_burnin):
    key, subkey = jax.random.split(key)
    state, _ = one_step(state, subkey)

# Collect samples
key, subkey = jax.random.split(key)
state, samples = jax.lax.scan(
    one_step,
    state,
    jax.random.split(subkey, n_samples),
)
samples = samples.position
```
Key Points:
- Direct API provides fine-grained control over sampling
- Must manually manage state and random keys
- Use `jax.lax.scan` for efficient sample collection (see the variant below)
- JIT compilation improves performance significantly
- Useful when implementing custom sampling logic
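The variant referenced above (illustrative, not from the example script) runs the burn-in with `jax.lax.scan` as well, so the whole warmup is compiled once instead of stepping through a Python loop:

```python
import jax

# Run the burn-in as a single scanned, jitted loop and discard the per-step outputs
key, subkey = jax.random.split(key)
state, _ = jax.lax.scan(one_step, initial_state, jax.random.split(subkey, n_burnin))
```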
Expected Output¤
Sample Plots¤
Each example generates a scatter plot showing the samples in 2D space:
- HMC samples: Should show clear exploration of both modes
- MALA samples: Similar coverage but potentially more concentrated
- NUTS samples: For the normal distribution, centered at origin
- Direct API samples: Should match Workshop HMC results
Statistics¤
Each example prints summary statistics of the drawn samples (a minimal version of this computation is sketched after the list). For the bimodal mixture, expect:
- Mean near [0, 0] (average of two modes)
- Large standard deviation (reflecting mode separation)
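A minimal version of this computation, assuming the `hmc_samples` array from Example 1 (the script's exact output format may differ):

```python
import jax.numpy as jnp

print("mean:", jnp.mean(hmc_samples, axis=0))  # roughly [0, 0] for the mixture
print("std: ", jnp.std(hmc_samples, axis=0))   # large, reflecting the mode separation
```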
Performance Comparison¤
Computational Cost¤
| Method | Time per Sample | ESS per Sample | Memory Usage | Tuning Required |
|---|---|---|---|---|
| HMC | Medium | High | Low | Yes (step size, steps) |
| MALA | Low | Medium | Low | Yes (step size) |
| NUTS | High | Very High | High | Minimal (auto-tuning) |
| Direct API | Medium | High | Low | Yes (same as HMC) |
When to Use Each Method¤
Use HMC when:
- You have smooth, continuous target distributions
- You can afford moderate computational cost
- You want efficient exploration with gradients
Use MALA when:
- Gradient evaluation is cheap
- You need many samples quickly
- Target has strong gradients
Use NUTS when:
- You have complex, high-dimensional posteriors
- You can afford higher memory usage
- You want to avoid manual tuning
- You need robust inference
Use Direct API when:
- You need custom sampling logic
- You want fine-grained control
- You're implementing advanced algorithms
- Workshop wrappers don't fit your use case
Tuning Guidelines¤
HMC Tuning¤
Step Size (step_size):
- Start with 0.1
- Target acceptance rate: 0.6-0.8 (a rough way to estimate this from the samples is sketched below)
- Too large: low acceptance rate
- Too small: slow mixing
Integration Steps (num_integration_steps):
- Start with 10
- Increase for better exploration
- Higher values increase cost per sample
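The Workshop wrappers return only the sample array, so acceptance statistics are not reported directly. One rough, implementation-agnostic estimate (a sketch, not part of the Workshop API) is the fraction of iterations in which the chain actually moved; the same check works for MALA:

```python
import jax.numpy as jnp


def estimate_acceptance_rate(samples):
    """Fraction of iterations in which the chain moved (proxy for the acceptance rate)."""
    moved = jnp.any(samples[1:] != samples[:-1], axis=-1)
    return float(jnp.mean(moved))


print("HMC acceptance (approx.):", estimate_acceptance_rate(hmc_samples))
```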
MALA Tuning¤
Step Size (step_size):
- Start with 0.05 (smaller than HMC)
- Target acceptance rate: 0.5-0.7
- Adjust based on acceptance diagnostics
NUTS Tuning¤
Step Size (step_size):
- Often auto-tuned during warmup
- Can set manually if needed
- Usually between 0.1-1.0
Max Doublings (max_num_doublings):
- Controls trajectory length and memory
- Default: 10 (max trajectory length = 2^10 = 1024)
- Reduce if encountering memory errors
- Values 5-7 often sufficient
Experiments to Try¤
- Compare mixing: Plot trace plots and autocorrelation for each sampler to assess mixing quality
- Tune hyperparameters: Systematically vary step sizes and integration steps, tracking acceptance rates and ESS
- Higher dimensions: Extend the mixture to 10D or 20D to see how samplers scale
- Different targets: Try non-Gaussian distributions like:
    - Rosenbrock's banana-shaped distribution
    - Neal's funnel distribution
    - Mixture of many components
- Effective sample size: Compute ESS using `arviz` or similar tools to measure sampling efficiency (see the ArviZ sketch after this list)
- Warmup strategies: Experiment with different warmup lengths and adaptive schemes
- Parallel chains: Run multiple chains and assess convergence using R-hat
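For the effective-sample-size and parallel-chains experiments, a hedged sketch using ArviZ (reusing `hmc_sampling`, `mixture_logprob`, and `init_state` from Example 1) might look like this:

```python
import numpy as np
import arviz as az
import jax

# Run a few independent chains with different PRNG keys
keys = jax.random.split(jax.random.PRNGKey(0), 4)
chains = np.stack([
    np.asarray(hmc_sampling(mixture_logprob, init_state, k,
                            n_samples=1000, n_burnin=500,
                            step_size=0.1, num_integration_steps=10))
    for k in keys
])  # shape: (n_chains, n_samples, 2)

idata = az.convert_to_inference_data(chains)
print(az.ess(idata))   # effective sample size per dimension
print(az.rhat(idata))  # R-hat; values close to 1.0 indicate convergence
```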
Troubleshooting¤
Low Acceptance Rate¤
Symptom: Acceptance rate below 0.5
Solution:
- Reduce `step_size` by a factor of 2
- Check the gradient computation for NaNs (a quick check is sketched below)
- Verify the log probability is correct
- Try a simpler test distribution first
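For the gradient check, a quick sketch (assuming `mixture_logprob` and `init_state` from Example 1) is to evaluate the log-probability gradient at the initial state before sampling:

```python
import jax
import jax.numpy as jnp

grad = jax.grad(mixture_logprob)(init_state)
assert jnp.all(jnp.isfinite(grad)), "log-prob gradient contains NaN/Inf"
```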
Poor Mixing¤
Symptom: Samples stuck in one mode of multimodal distribution
Solution:
- Increase burn-in period (try 2x-5x current)
- Try different initialization points
- Consider parallel tempering for multimodal targets
- Increase `num_integration_steps` for HMC
NUTS Memory Errors¤
Symptom: Out of memory errors with NUTS
Solution:
```python
# Reduce memory usage
nuts_samples = nuts_sampling(
    log_prob_fn,
    init_state,
    key,
    n_samples=500,  # Reduce sample count
    n_burnin=200,
    max_num_doublings=5,  # Lower from default 10
)
```
Divergent Transitions (NUTS)¤
Symptom: Warning about divergent transitions
Solution:
- Decrease `step_size` (try 0.5x the current value)
- Reparameterize the model (e.g., a non-centered parameterization; see the sketch below)
- Check for prior-likelihood conflicts
- Increase the warmup period
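As a generic illustration of the reparameterization advice (a sketch not tied to any Workshop model), consider a funnel-like target where a log-scale and values drawn at that scale are sampled jointly; switching to a unit-scale auxiliary variable makes the geometry much easier to sample:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Centered: theta is sampled directly at scale exp(log_s), which creates the
# funnel geometry that often produces divergences.
def centered_logprob(params):
    log_s, theta = params[0], params[1:]
    return norm.logpdf(log_s, 0.0, 3.0) + jnp.sum(norm.logpdf(theta, 0.0, jnp.exp(log_s)))

# Non-centered: sample a unit-scale auxiliary variable theta_raw instead and
# recover theta = theta_raw * exp(log_s) after sampling.
def noncentered_logprob(params):
    log_s, theta_raw = params[0], params[1:]
    return norm.logpdf(log_s, 0.0, 3.0) + jnp.sum(norm.logpdf(theta_raw, 0.0, 1.0))
```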
Slow Performance¤
Symptom: Sampling taking too long
Solution:
- Ensure JIT compilation is used (`@nnx.jit` or `@jax.jit`)
- Check if a GPU is available and being used
- Use the direct API with `jax.lax.scan` for efficient loops
- Reduce the sample count for testing
Next Steps¤
Related Examples¤
- BlackJAX Integration Example: Learn the basics of BlackJAX with Workshop
- BlackJAX Integration Examples: Advanced integration patterns and production use cases
Further Learning¤
- BlackJAX Documentation
- BlackJAX Sampling Book
- MCMC Diagnostics Guide
- HMC Tutorial by Betancourt
- NUTS Paper
- Workshop Sampling Module Documentation
Additional Resources¤
Papers¤
- HMC: Neal, R. M. (2011). "MCMC using Hamiltonian dynamics". Handbook of Markov Chain Monte Carlo.
- NUTS: Hoffman, M. D., & Gelman, A. (2014). "The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo". Journal of Machine Learning Research.
- MALA: Roberts, G. O., & Tweedie, R. L. (1996). "Exponential convergence of Langevin distributions and their discrete approximations". Bernoulli.
- Convergence Diagnostics: Vehtari, A., et al. (2021). "Rank-Normalization, Folding, and Localization: An Improved R-hat for Assessing Convergence of MCMC". Bayesian Analysis.
Code References¤
- Distribution creation: `workshop.generative_models.core.distributions`
- Sampling functions: `workshop.generative_models.core.sampling`
- BlackJAX wrappers: `workshop.generative_models.core.sampling.blackjax_samplers`
- Direct BlackJAX API: `blackjax.hmc`, `blackjax.nuts`, `blackjax.mala`
Diagnostic Tools¤
- ArviZ: Python package for MCMC diagnostics and visualization
- PyStan: Stan interface with excellent diagnostics
- PyMC: Bayesian modeling with built-in diagnostics
Support¤
If you encounter issues:
- Check that BlackJAX is installed: `pip install blackjax`
- Verify the JAX GPU/CPU setup is correct
- Review error messages for parameter constraints
- Check BlackJAX documentation for API changes
- Consult Workshop documentation or open an issue
Tags: #mcmc #blackjax #hmc #nuts #mala #sampling #comparison #advanced
Difficulty: Advanced
Estimated Time: 20-30 minutes