
Simple Energy-Based Model (EBM) Example¤

Level: Intermediate | Runtime: ~2 minutes (CPU) / ~30 seconds (GPU) | Format: Python + Jupyter

Overview¤

This example demonstrates Energy-Based Models (EBMs) with MCMC sampling. It covers the fundamentals of energy functions, Langevin dynamics sampling, and contrastive divergence training, including advanced techniques such as persistent contrastive divergence and deep EBM architectures.

What You'll Learn¤

  • Energy function computation and interpretation
  • Langevin dynamics for MCMC sampling
  • Contrastive divergence (CD) training
  • Persistent contrastive divergence with sample buffers
  • Deep EBM architectures with residual connections
  • Score function estimation

Files¤

Quick Start¤

Run the Python Script¤

# Activate environment
source activate.sh

# Run the example
python examples/generative_models/energy/simple_ebm_example.py

Run the Jupyter Notebook¤

# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/energy/simple_ebm_example.ipynb

Key Concepts¤

Energy-Based Models¤

EBMs define a probability distribution through an energy function:

\[p(x) = \frac{\exp(-E(x))}{Z}\]

Where:

  • \(E(x)\) is the energy function (lower energy = higher probability)
  • \(Z\) is the partition function (normalization constant)
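
Because \(Z\) is the same for every input, it cancels when comparing two points: \(p(x_1)/p(x_2) = \exp(E(x_2) - E(x_1))\), so relative probabilities can be computed from energies alone. A minimal sketch with a hand-written quadratic energy (plain JAX, not the workshop API) illustrates this:

import jax.numpy as jnp

# Toy energy function: a quadratic bowl centered at the origin
def toy_energy(x):
    return 0.5 * jnp.sum(x**2, axis=-1)

x1 = jnp.array([0.0, 0.0])   # low energy -> high probability
x2 = jnp.array([2.0, 2.0])   # high energy -> low probability

# Z cancels in the ratio, so no normalization constant is needed
ratio = jnp.exp(toy_energy(x2) - toy_energy(x1))
print(f"p(x1) / p(x2) = {ratio:.1f}")  # exp(4.0 - 0.0), about 54.6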

Langevin Dynamics¤

Sampling from the model using Langevin dynamics:

\[x_{t+1} = x_t - \frac{\epsilon}{2} \nabla_x E(x_t) + \sqrt{\epsilon} \cdot \text{noise}\]

Where \(\epsilon\) is the step size and noise is Gaussian.
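
As a reference point for the library's langevin_dynamics helper used later, here is a minimal from-scratch sketch of the update rule above in plain JAX (the function and variable names are illustrative, not part of the workshop API):

import jax
import jax.numpy as jnp

def langevin_step(energy_fn, x, key, step_size=0.01):
    """One Langevin update: gradient descent on the energy plus Gaussian noise."""
    grad_e = jax.grad(lambda v: jnp.sum(energy_fn(v)))(x)
    noise = jax.random.normal(key, x.shape)
    return x - 0.5 * step_size * grad_e + jnp.sqrt(step_size) * noise

# Example: 100 steps on a toy quadratic energy
energy_fn = lambda v: 0.5 * jnp.sum(v**2, axis=-1)
x = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
for key in jax.random.split(jax.random.PRNGKey(1), 100):
    x = langevin_step(energy_fn, x, key)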

Contrastive Divergence Loss¤

Training objective that contrasts real data with model samples:

\[\mathcal{L}_{\text{CD}} = \mathbb{E}_{x \sim p_{\text{data}}}[E(x)] - \mathbb{E}_{x \sim p_{\text{model}}}[E(x)]\]

The goal is to:

  • Lower energy for real data samples
  • Raise energy for generated samples
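
The workshop's contrastive_divergence_loss is used in the examples below; conceptually it reduces to a difference of mean energies, as in this minimal sketch (a simplification that omits any regularization terms a full implementation may add):

import jax.numpy as jnp

def cd_loss(energy_fn, real_data, fake_data):
    """Contrastive divergence: push real energies down, fake energies up."""
    real_energy = energy_fn(real_data).mean()
    fake_energy = energy_fn(fake_data).mean()
    return real_energy - fake_energy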

Persistent Contrastive Divergence¤

A variant of CD that maintains a buffer of persistent MCMC chains across training iterations: each sampling run starts from a previous sample rather than from fresh noise, which yields better gradient estimates with fewer MCMC steps per update.
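
The SampleBuffer used in the example code below comes from the workshop library; conceptually it is just a bounded store of past MCMC samples. The following toy sketch is an illustration only (the class name and the reinitialization ratio are assumptions, not the library's implementation):

import numpy as np

class ToySampleBuffer:
    """Keeps past MCMC samples so new chains start from them instead of noise."""

    def __init__(self, capacity, sample_shape):
        self.capacity = capacity
        self.sample_shape = sample_shape
        self.samples = []

    def add(self, batch):
        self.samples.extend(np.asarray(batch))
        self.samples = self.samples[-self.capacity:]  # drop oldest when full

    def sample(self, batch_size, reinit_prob=0.05):
        # Occasionally restart a chain from noise to keep the buffer diverse
        out = []
        for _ in range(batch_size):
            if not self.samples or np.random.rand() < reinit_prob:
                out.append(np.random.randn(*self.sample_shape))
            else:
                out.append(self.samples[np.random.randint(len(self.samples))])
        return np.stack(out)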

Code Structure¤

The example walks through nine major sections:

  1. Simple EBM Creation: Basic energy-based model for MNIST
  2. Energy Computation: Computing and interpreting energy values
  3. MCMC Sampling: Generating samples using Langevin dynamics
  4. Configuration System: Declarative model creation
  5. Contrastive Divergence: Training with CD loss
  6. Persistent CD: Advanced training with sample buffers
  7. Deep EBM: Complex architectures with residual connections
  8. Visualization: Analyzing samples and energy landscapes
  9. Summary: Key takeaways and experiments

Example Code¤

Creating a Simple EBM¤

from flax import nnx
import jax.numpy as jnp
from workshop.generative_models.models.energy import EnergyBasedModel

# Initialize RNG
rngs = nnx.Rngs(42)

# Create simple EBM
ebm = EnergyBasedModel(
    input_dim=(28, 28, 1),
    hidden_dims=[256, 128],
    rngs=rngs
)

# Compute energy for data
data = jnp.ones((32, 28, 28, 1))
energy = ebm.energy(data)
print(f"Energy shape: {energy.shape}")  # (32,)

Langevin Dynamics Sampling¤

import jax

from workshop.generative_models.core.sampling import langevin_dynamics

# RNG key for the Gaussian initialization of the chains
key = jax.random.PRNGKey(0)

# Generate samples using MCMC
samples = langevin_dynamics(
    energy_fn=ebm.energy,
    init_samples=jax.random.normal(key, (16, 28, 28, 1)),
    step_size=0.01,
    n_steps=100,
    temperature=1.0
)

Contrastive Divergence Training¤

from workshop.generative_models.core.losses import contrastive_divergence_loss

# Compute CD loss
real_data = data_batch  # Your training data
fake_data = samples      # MCMC samples

loss_dict = contrastive_divergence_loss(
    ebm=ebm,
    real_data=real_data,
    fake_data=fake_data
)

print(f"CD Loss: {loss_dict['loss']:.4f}")
print(f"Real Energy: {loss_dict['real_energy_mean']:.4f}")
print(f"Fake Energy: {loss_dict['fake_energy_mean']:.4f}")

Persistent Contrastive Divergence¤

from workshop.generative_models.core.sampling import SampleBuffer

# Create persistent sample buffer
buffer = SampleBuffer(capacity=10000, sample_shape=(28, 28, 1))

# Training loop with persistent CD
for epoch in range(num_epochs):
    # Get samples from buffer (or initialize if empty)
    init_samples = buffer.sample(batch_size=32) if not buffer.is_empty() else None

    # Run MCMC from buffer samples
    samples = langevin_dynamics(
        energy_fn=ebm.energy,
        init_samples=init_samples,
        step_size=0.01,
        n_steps=20  # Fewer steps with persistent buffer
    )

    # Update buffer
    buffer.add(samples)

    # Compute loss and update model (real_data is a batch from your dataset)
    loss = contrastive_divergence_loss(ebm, real_data, samples)

Deep EBM Architecture¤

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.energy import DeepEBM

# Create deep EBM configuration
deep_config = ModelConfiguration(
    name="deep_ebm",
    model_class="workshop.generative_models.models.energy.DeepEBM",
    input_dim=(32, 32, 3),
    hidden_dims=[32, 64, 128],
    output_dim=1,
    activation="silu",
    parameters={
        "base_channels": 32,
        "num_blocks": 3,
        "use_spectral_norm": True,
    },
)

# Create model
deep_ebm = DeepEBM(deep_config, rngs=rngs)

Features Demonstrated¤

Energy Function Design¤

  • MLP-based energy functions
  • CNN-based energy functions
  • Deep architectures with residual connections
  • Spectral normalization for stability

MCMC Sampling¤

  • Langevin dynamics implementation
  • Step size tuning
  • Temperature control
  • Burn-in periods

Training Techniques¤

  • Standard contrastive divergence
  • Persistent contrastive divergence
  • Sample buffer management
  • Gradient estimation

Advanced Architectures¤

  • Deep convolutional EBMs
  • Residual connections
  • Batch normalization
  • Spectral normalization

Experiments to Try¤

  1. Modify energy architecture: Try different hidden dimensions or activation functions
  2. Tune MCMC parameters: Experiment with step sizes, number of steps, temperature
  3. Compare CD vs Persistent CD: Observe training stability and sample quality
  4. Add noise annealing: Gradually reduce noise during sampling (a minimal sketch follows this list)
  5. Conditional EBMs: Extend to conditional generation with labels
  6. Hybrid models: Combine EBMs with other generative models
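
For experiment 4, a simple form of noise annealing is to decay the Langevin step size (and therefore the noise scale) over the course of sampling. A minimal sketch on a toy energy, assuming a linear decay schedule; in practice you would use ebm.energy and tune the start/end values:

import jax
import jax.numpy as jnp

def energy_fn(v):
    # Toy quadratic energy; replace with ebm.energy in practice
    return 0.5 * jnp.sum(v**2, axis=-1)

n_steps = 100
step_start, step_end = 0.05, 0.005  # decay the step size (and thus the noise scale)

x = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
for t, key in enumerate(jax.random.split(jax.random.PRNGKey(1), n_steps)):
    # Linear annealing: large, noisy moves early; small, precise moves late
    step_size = step_start + (step_end - step_start) * t / (n_steps - 1)
    grad_e = jax.grad(lambda v: jnp.sum(energy_fn(v)))(x)
    noise = jax.random.normal(key, x.shape)
    x = x - 0.5 * step_size * grad_e + jnp.sqrt(step_size) * noise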

Next Steps¤

After understanding this example:

  1. Training Loop: Implement full training on real datasets (MNIST, CIFAR-10)
  2. Score Matching: Explore score matching as an alternative to CD
  3. Conditional Generation: Add class or attribute conditioning
  4. Energy Landscape Analysis: Visualize and analyze learned energy functions
  5. Compositional Generation: Combine multiple EBMs for complex generation

Troubleshooting¤

MCMC Not Converging¤

  • Increase number of sampling steps
  • Reduce step size
  • Add noise annealing schedule
  • Check energy function gradients

Training Instability¤

  • Use spectral normalization
  • Reduce learning rate
  • Increase batch size
  • Use persistent CD with larger buffer

Slow Sampling¤

  • Use GPU acceleration
  • Reduce number of MCMC steps
  • Use persistent buffers to start from better initializations
  • Consider faster sampling methods (e.g., HMC)

Memory Issues¤

  • Reduce buffer size for persistent CD
  • Use smaller batch sizes
  • Reduce model size (fewer hidden dims)

Additional Resources¤