
Energy-Based Models User Guide¤

Complete guide to building, training, and using Energy-Based Models with Workshop.

Overview¤

This guide covers practical usage of EBMs in Workshop, from basic setup to advanced techniques. You'll learn how to:

  • Configure EBMs: set up energy functions and MCMC sampling parameters
  • Train models: train with persistent contrastive divergence and monitor stability
  • Generate samples: sample using Langevin dynamics and MCMC methods
  • Tune and debug: optimize hyperparameters and troubleshoot common issues


Quick Start¤

Basic EBM Example¤

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.energy import EBM

# Initialize RNGs
rngs = nnx.Rngs(params=0, noise=1, sample=2)

# Configuration for MNIST
config = ModelConfiguration(
    name="mnist_ebm",
    model_class="workshop.generative_models.models.energy.ebm.EBM",
    input_dim=(28, 28, 1),
    hidden_dims=[128, 256, 512],
    output_dim=1,
    activation="silu",
    parameters={
        "energy_type": "cnn",
        "mcmc_steps": 60,
        "mcmc_step_size": 0.01,
        "mcmc_noise_scale": 0.005,
        "sample_buffer_capacity": 8192,
        "sample_buffer_reinit_prob": 0.05,
        "alpha": 0.01,  # Regularization strength
    }
)

# Create EBM
model = EBM(config, rngs=rngs)

# Training step
batch = {"x": jnp.ones((32, 28, 28, 1))}
loss_dict = model.train_step(batch, rngs=rngs)

print(f"Loss: {loss_dict['loss']:.4f}")
print(f"Real energy: {loss_dict['real_energy']:.4f}")
print(f"Fake energy: {loss_dict['fake_energy']:.4f}")

Creating EBM Models¤

1. Standard EBM (MLP Energy Function)¤

For tabular or low-dimensional data:

from workshop.generative_models.models.energy import EBM

# MLP energy function configuration
config = ModelConfiguration(
    name="tabular_ebm",
    model_class="workshop.generative_models.models.energy.ebm.EBM",
    input_dim=100,  # Input features
    hidden_dims=[256, 256, 128],
    output_dim=1,
    activation="gelu",
    dropout_rate=0.1,
    parameters={
        "energy_type": "mlp",
        "mcmc_steps": 60,
        "mcmc_step_size": 0.01,
        "mcmc_noise_scale": 0.005,
        "sample_buffer_capacity": 4096,
        "alpha": 0.01,
    }
)

model = EBM(config, rngs=rngs)

Key Parameters:

Parameter           Default   Description
energy_type         "mlp"     Energy function architecture ("mlp" or "cnn")
mcmc_steps          60        Number of Langevin dynamics steps per sampling run
mcmc_step_size      0.01      Step size for the gradient part of each Langevin update
mcmc_noise_scale    0.005     Scale of the Gaussian noise added for exploration
alpha               0.01      Regularization strength on energy magnitudes
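
For intuition, the three MCMC parameters control a Langevin update loop of roughly the following shape. This is an illustrative sketch, not the EBM class's internal sampler; energy_fn stands in for the model's energy network:

import jax
import jax.numpy as jnp

def langevin_sample(energy_fn, x_init, key,
                    mcmc_steps=60, mcmc_step_size=0.01, mcmc_noise_scale=0.005):
    """Run Langevin dynamics from x_init toward low-energy regions (illustrative)."""
    # Gradient of the summed energy with respect to the samples
    grad_fn = jax.grad(lambda x: jnp.sum(energy_fn(x)))
    x = x_init
    for _ in range(mcmc_steps):
        key, subkey = jax.random.split(key)
        noise = mcmc_noise_scale * jax.random.normal(subkey, x.shape)
        # Gradient step downhill on the energy surface, plus exploration noise
        x = x - mcmc_step_size * grad_fn(x) + noise
    return x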

2. CNN Energy Function (for Images)¤

For image data:

config = ModelConfiguration(
    name="image_ebm",
    model_class="workshop.generative_models.models.energy.ebm.EBM",
    input_dim=(32, 32, 3),  # CIFAR-10 dimensions
    hidden_dims=[64, 128, 256],
    output_dim=1,
    activation="silu",
    parameters={
        "energy_type": "cnn",
        "input_channels": 3,
        "mcmc_steps": 100,
        "mcmc_step_size": 0.005,
        "mcmc_noise_scale": 0.001,
        "sample_buffer_capacity": 8192,
        "sample_buffer_reinit_prob": 0.05,
        "alpha": 0.001,
    }
)

model = EBM(config, rngs=rngs)

3. Deep EBM (Complex Data)¤

For complex datasets requiring deeper architectures:

from workshop.generative_models.models.energy import DeepEBM

config = ModelConfiguration(
    name="deep_ebm",
    model_class="workshop.generative_models.models.energy.ebm.DeepEBM",
    input_dim=(32, 32, 3),
    hidden_dims=[32, 64, 128, 256],
    output_dim=1,
    activation="silu",
    parameters={
        "use_residual": True,
        "use_spectral_norm": True,
        "mcmc_steps": 100,
        "mcmc_step_size": 0.005,
        "mcmc_noise_scale": 0.001,
        "sample_buffer_capacity": 8192,
        "alpha": 0.001,
    }
)

model = DeepEBM(config, rngs=rngs)

Deep EBM Features:

  • Residual connections: Enable deeper networks (10+ layers)
  • Spectral normalization: Stabilizes training
  • GroupNorm: Better than BatchNorm for MCMC sampling
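
To make the residual + GroupNorm combination concrete, here is a hypothetical sketch of the kind of block a deep CNN energy function might stack. ResidualEnergyBlock is not part of the Workshop API, and spectral normalization is omitted for brevity:

import jax
import jax.numpy as jnp
from flax import nnx

class ResidualEnergyBlock(nnx.Module):
    """Illustrative residual block with GroupNorm for a CNN energy function."""

    def __init__(self, features: int, *, rngs: nnx.Rngs):
        self.norm1 = nnx.GroupNorm(features, num_groups=8, rngs=rngs)
        self.conv1 = nnx.Conv(features, features, kernel_size=(3, 3), rngs=rngs)
        self.norm2 = nnx.GroupNorm(features, num_groups=8, rngs=rngs)
        self.conv2 = nnx.Conv(features, features, kernel_size=(3, 3), rngs=rngs)

    def __call__(self, x):
        # GroupNorm avoids the batch statistics that make BatchNorm unreliable
        # when MCMC chains and real data are mixed in one batch
        h = self.conv1(jax.nn.silu(self.norm1(x)))
        h = self.conv2(jax.nn.silu(self.norm2(h)))
        # Residual connection keeps gradients well-behaved in deep stacks
        return x + h

# Quick shape check with NHWC inputs
block = ResidualEnergyBlock(64, rngs=nnx.Rngs(0))
y = block(jnp.ones((2, 32, 32, 64)))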

Training EBMs¤

Basic Training Loop¤

from workshop.generative_models.training.trainers import EnergyTrainer
from workshop.generative_models.core.configuration import TrainingConfiguration

# Training configuration
train_config = TrainingConfiguration(
    num_epochs=100,
    batch_size=128,
    learning_rate=1e-4,
    optimizer="adam",
    save_dir="./checkpoints/ebm",
    log_every=10,
    save_every=1000,
)

# Create trainer
trainer = EnergyTrainer(
    model=model,
    config=train_config,
    rngs=rngs,
)

# Train
history = trainer.train(train_loader, val_loader=None)

Training with Monitoring¤

Monitor key metrics during training:

def train_step_with_monitoring(model, batch, rngs):
    """Training step with detailed monitoring."""
    loss_dict = model.train_step(batch, rngs=rngs)

    # Log metrics
    print(f"Step metrics:")
    print(f"  Loss: {loss_dict['loss']:.4f}")
    print(f"  Real energy: {loss_dict['real_energy']:.4f}")
    print(f"  Fake energy: {loss_dict['fake_energy']:.4f}")
    print(f"  Energy gap: {loss_dict['energy_gap']:.4f}")

    # Check for issues
    if loss_dict['energy_gap'] < 0:
        print("WARNING: Negative energy gap - real data has higher energy!")

    if abs(loss_dict['real_energy']) > 100:
        print("WARNING: Energy explosion detected!")

    return loss_dict

# Training loop
for epoch in range(num_epochs):
    for batch in train_loader:
        loss_dict = train_step_with_monitoring(model, batch, rngs)

Hyperparameter Guidelines¤

MCMC Sampling:

# Quick sampling (less accurate)
quick_config = {
    "mcmc_steps": 20,
    "mcmc_step_size": 0.02,
    "mcmc_noise_scale": 0.01,
}

# Standard sampling (balanced)
standard_config = {
    "mcmc_steps": 60,
    "mcmc_step_size": 0.01,
    "mcmc_noise_scale": 0.005,
}

# High-quality sampling (slower)
quality_config = {
    "mcmc_steps": 200,
    "mcmc_step_size": 0.005,
    "mcmc_noise_scale": 0.001,
}
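
Assuming config.parameters is a plain dict, as in the examples above, a preset can be merged into an existing configuration before the model is built:

# Merge the "standard" sampling preset into an existing configuration
config.parameters.update(standard_config)
model = EBM(config, rngs=rngs)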

Learning Rates:

# EBMs typically need lower learning rates than supervised models
learning_rates = {
    "small_model": 1e-4,
    "medium_model": 5e-5,
    "large_model": 1e-5,
}

Generating Samples¤

Sampling from the Model¤

# Generate samples using MCMC
n_samples = 16
samples = model.generate(
    n_samples=n_samples,
    n_mcmc_steps=100,  # More steps = better quality
    step_size=0.01,
    noise_scale=0.005,
    rngs=rngs,
)

print(f"Generated samples shape: {samples.shape}")

Sampling with Different Initializations¤

# Random initialization
random_samples = model.generate(
    n_samples=16,
    init_strategy="random",
    rngs=rngs,
)

# Initialize from data
data_init_samples = model.generate(
    n_samples=16,
    init_strategy="data",
    init_data=train_batch,
    rngs=rngs,
)

# Initialize from buffer
buffer_samples = model.sample_from_buffer(
    n_samples=16,
    rngs=rngs,
)

Conditional Generation¤

For conditional EBMs (e.g., class-conditional):

# Generate samples for specific class
class_label = 3
conditional_samples = model.generate_conditional(
    n_samples=16,
    condition=class_label,
    rngs=rngs,
)

Advanced Techniques¤

1. Sample Buffer Management¤

The sample buffer is critical for stable training:

# Access buffer statistics
buffer_size = len(model.sample_buffer.buffer)
print(f"Buffer contains {buffer_size} samples")

# Manually populate buffer
for batch in train_loader:
    # Run MCMC to generate samples
    samples = model.generate(
        n_samples=batch['x'].shape[0],
        init_strategy="data",
        init_data=batch['x'],
        rngs=rngs,
    )
    # Samples automatically added to buffer

# Clear buffer (for reinitialization)
model.sample_buffer.buffer = []
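
For intuition about why sample_buffer_reinit_prob matters, here is a generic sketch of how a persistent sample buffer is typically used to initialize MCMC chains. This is not Workshop's buffer implementation; the helper name and shapes are illustrative:

import jax
import jax.numpy as jnp

def draw_mcmc_init(buffer, key, n_samples, sample_shape, reinit_prob=0.05):
    """Mostly reuse stored chains; occasionally restart from noise (illustrative)."""
    key_reinit, key_idx, key_noise = jax.random.split(key, 3)
    if len(buffer) < n_samples:
        # Buffer not warmed up yet: start every chain from noise
        return jax.random.normal(key_noise, (n_samples, *sample_shape))
    idx = jax.random.randint(key_idx, (n_samples,), 0, len(buffer))
    stored = jnp.stack([buffer[int(i)] for i in idx])
    fresh = jax.random.normal(key_noise, stored.shape)
    # With probability reinit_prob, replace a chain with fresh noise so the
    # sampler keeps exploring instead of collapsing onto a few modes
    reinit = jax.random.bernoulli(
        key_reinit, reinit_prob, (n_samples,) + (1,) * len(sample_shape)
    )
    return jnp.where(reinit, fresh, stored)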

2. Energy Landscape Visualization¤

Visualize the energy landscape:

import matplotlib.pyplot as plt

def visualize_energy_landscape(model, data_range=(-3, 3), resolution=100):
    """Visualize 2D energy landscape."""
    x = jnp.linspace(data_range[0], data_range[1], resolution)
    y = jnp.linspace(data_range[0], data_range[1], resolution)
    X, Y = jnp.meshgrid(x, y)

    # Compute energy for each point
    points = jnp.stack([X.ravel(), Y.ravel()], axis=1)
    energies = model.energy(points)
    energies = energies.reshape(resolution, resolution)

    # Plot
    plt.figure(figsize=(10, 8))
    plt.contourf(X, Y, energies, levels=50, cmap='viridis')
    plt.colorbar(label='Energy')
    plt.title('Energy Landscape')
    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.show()

# For 2D data
visualize_energy_landscape(model)

3. Annealed Importance Sampling¤

For better sampling quality:

def annealed_sampling(model, n_samples, n_steps=1000, rngs=None):
    """Annealed importance sampling for high-quality samples."""
    # Start with high temperature
    temperatures = jnp.linspace(10.0, 1.0, n_steps)

    # Initialize samples
    samples = jax.random.normal(rngs.sample(), (n_samples, *model.input_shape))

    for i, temp in enumerate(temperatures):
        # Compute energy gradient
        energy_grad = jax.grad(lambda x: jnp.sum(model.energy(x)))(samples)

        # Langevin step with temperature
        step_size = 0.01 * temp
        noise_scale = jnp.sqrt(2 * step_size * temp)

        samples = samples - step_size * energy_grad
        samples = samples + noise_scale * jax.random.normal(
            rngs.sample(), samples.shape
        )

    return samples

# Use annealed sampling
high_quality_samples = annealed_sampling(model, n_samples=16, rngs=rngs)

Troubleshooting¤

Common Issues and Solutions¤

  • Energy Explosion

    Symptoms: energy values grow unbounded, NaN losses

    Solutions:
    - Reduce the learning rate (try 1e-5)
    - Add or increase regularization (alpha=0.01 to 0.1)
    - Use spectral normalization
    - Clip gradients (e.g. max_grad_norm=1.0; see the optax sketch after this list)

    config.parameters["alpha"] = 0.1  # Stronger regularization

  • Poor Sample Quality

    Symptoms: samples look like noise or are blurry

    Solutions:
    - Increase MCMC steps (60 → 100+)
    - Tune the step size more carefully
    - Increase the buffer capacity
    - Use a deeper energy function

    config.parameters["mcmc_steps"] = 100
    config.parameters["sample_buffer_capacity"] = 16384

  • Mode Collapse

    Symptoms: all generated samples look similar

    Solutions:
    - Increase the buffer reinitialization probability
    - Use data augmentation
    - Run longer MCMC chains
    - Use a larger buffer

    config.parameters["sample_buffer_reinit_prob"] = 0.1

  • Training Instability

    Symptoms: oscillating losses, sudden divergence

    Solutions:
    - Lower the learning rate
    - Use a persistent buffer
    - Add gradient clipping
    - Monitor the energy gap

    # Ensure the persistent buffer is enabled
    trainer.use_persistent_buffer = True
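Gradient clipping comes up twice in the remedies above. Here is a sketch of one way to set it up with optax, assuming your trainer or training loop accepts a pre-built optax optimizer (max_grad_norm is just an illustrative name):

import optax

# Clip the global gradient norm before the Adam update
max_grad_norm = 1.0
optimizer = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adam(learning_rate=1e-4),
)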

Debugging Checklist¤

def diagnose_ebm(model, batch, rngs):
    """Diagnostic checks for EBM training."""

    # 1. Check energy values
    real_energy = model.energy(batch['x']).mean()
    print(f"Real data energy: {real_energy:.3f}")

    # Generate samples
    fake_samples = model.generate(n_samples=16, rngs=rngs)
    fake_energy = model.energy(fake_samples).mean()
    print(f"Generated samples energy: {fake_energy:.3f}")

    # Energy gap should be positive
    gap = fake_energy - real_energy
    print(f"Energy gap: {gap:.3f} {'✓' if gap > 0 else '✗'}")

    # 2. Check MCMC convergence
    init_samples = jax.random.normal(rngs.sample(), (16, *model.input_shape))
    init_energy = model.energy(init_samples).mean()

    final_samples = model.generate(
        n_samples=16,
        init_strategy="custom",
        init_samples=init_samples,
        n_mcmc_steps=100,
        rngs=rngs,
    )
    final_energy = model.energy(final_samples).mean()

    energy_decrease = init_energy - final_energy
    print(f"MCMC energy decrease: {energy_decrease:.3f}")

    # 3. Check buffer health
    buffer_size = len(model.sample_buffer.buffer)
    print(f"Buffer size: {buffer_size}/{model.sample_buffer.capacity}")

    # 4. Check sample validity (values should lie in the data range)
    sample_min, sample_max = fake_samples.min(), fake_samples.max()
    print(f"Sample range: [{sample_min:.3f}, {sample_max:.3f}]")

    return {
        "real_energy": real_energy,
        "fake_energy": fake_energy,
        "energy_gap": gap,
        "mcmc_decrease": energy_decrease,
        "buffer_usage": buffer_size / model.sample_buffer.capacity,
    }

# Run diagnostics
diagnostics = diagnose_ebm(model, batch, rngs)

Best Practices¤

1. Start Simple¤

# Begin with a small model and simple data
simple_config = ModelConfiguration(
    name="simple_ebm",
    model_class="workshop.generative_models.models.energy.ebm.EBM",
    input_dim=2,  # 2D for visualization
    hidden_dims=[64, 64],
    output_dim=1,
    activation="relu",
    parameters={
        "energy_type": "mlp",
        "mcmc_steps": 30,
        "mcmc_step_size": 0.02,
        "sample_buffer_capacity": 1024,
    }
)

2. Gradually Increase Complexity¤

# Once stable, increase capacity
medium_config = ModelConfiguration(
    name="medium_ebm",
    input_dim=(28, 28, 1),
    hidden_dims=[128, 256],
    parameters={
        "energy_type": "cnn",
        "mcmc_steps": 60,
        "sample_buffer_capacity": 4096,
    }
)

# For complex data
complex_config = ModelConfiguration(
    name="complex_ebm",
    model_class="workshop.generative_models.models.energy.ebm.DeepEBM",
    input_dim=(32, 32, 3),
    hidden_dims=[64, 128, 256, 512],
    parameters={
        "use_residual": True,
        "use_spectral_norm": True,
        "mcmc_steps": 100,
        "sample_buffer_capacity": 8192,
    }
)

3. Monitor Training Carefully¤

# Log detailed metrics
def detailed_training_step(model, batch, rngs, step):
    loss_dict = model.train_step(batch, rngs=rngs)

    if step % 100 == 0:
        # Detailed logging
        print(f"\nStep {step}:")
        print(f"  Loss: {loss_dict['loss']:.4f}")
        print(f"  Real energy: {loss_dict['real_energy']:.4f}")
        print(f"  Fake energy: {loss_dict['fake_energy']:.4f}")
        print(f"  Gap: {loss_dict['energy_gap']:.4f}")

        # Generate samples for visual inspection
        if step % 1000 == 0:
            samples = model.generate(n_samples=64, rngs=rngs)
            visualize_samples(samples, f"step_{step}.png")

    return loss_dict

4. Use Proper Preprocessing¤

def preprocess_for_ebm(images, rng_key=None, training=False):
    """Proper preprocessing for image EBMs."""
    # Normalize to [-1, 1]
    images = (images - 127.5) / 127.5

    # Add small noise during training to smooth the data distribution
    if training:
        noise = jax.random.normal(rng_key, images.shape) * 0.005
        images = jnp.clip(images + noise, -1.0, 1.0)

    return images

Performance Optimization¤

GPU Acceleration¤

# EBMs benefit significantly from GPU
from workshop.generative_models.core.device_manager import DeviceManager

device_manager = DeviceManager()
device = device_manager.get_device()
print(f"Using device: {device}")

# Move data to GPU
batch_gpu = jax.device_put(batch, device)

Batch Size Tuning¤

# Larger batches = more stable gradients
# But: limited by GPU memory

batch_sizes = {
    "small_model": 256,
    "medium_model": 128,
    "large_model": 64,
}

JIT Compilation¤

# Compile the training step for speed. Use nnx.jit (not jax.jit) so the
# nnx model and its mutable state are handled correctly as arguments.
@nnx.jit
def compiled_train_step(model, batch, rngs):
    return model.train_step(batch, rngs=rngs)

# Much faster after the first call, which traces and compiles the step
loss_dict = compiled_train_step(model, batch, rngs)

Example: Complete MNIST Training¤

from workshop.generative_models.models.energy import EBM
from workshop.generative_models.core.configuration import ModelConfiguration
import tensorflow_datasets as tfds

# Load MNIST
train_ds = tfds.load('mnist', split='train', as_supervised=True)

def preprocess(image, label):
    image = jnp.array(image, dtype=jnp.float32) / 255.0
    image = (image - 0.5) / 0.5  # Normalize to [-1, 1]
    return {"x": image}

# Create model
config = ModelConfiguration(
    name="mnist_ebm",
    model_class="workshop.generative_models.models.energy.ebm.EBM",
    input_dim=(28, 28, 1),
    hidden_dims=[128, 256, 512],
    output_dim=1,
    activation="silu",
    parameters={
        "energy_type": "cnn",
        "mcmc_steps": 60,
        "mcmc_step_size": 0.01,
        "mcmc_noise_scale": 0.005,
        "sample_buffer_capacity": 8192,
        "alpha": 0.01,
    }
)

model = EBM(config, rngs=rngs)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    for step, (images, labels) in enumerate(tfds.as_numpy(train_ds.batch(128))):
        batch = preprocess(images, labels)
        loss_dict = model.train_step(batch, rngs=rngs)

        if step % 100 == 0:
            print(f"  Step {step}: Loss={loss_dict['loss']:.4f}, "
                  f"Gap={loss_dict['energy_gap']:.4f}")

    # Generate samples
    if (epoch + 1) % 10 == 0:
        samples = model.generate(n_samples=64, rngs=rngs)
        save_image_grid(samples, f"epoch_{epoch+1}.png")

print("Training complete!")

Summary¤

Key Takeaways:

  • EBMs learn by assigning low energy to data, high energy to non-data
  • Persistent Contrastive Divergence (PCD) with MCMC sampling is the standard training method (see the loss sketch after this list)
  • Sample buffer management is critical for stable training
  • Monitor energy gap: fake_energy should be > real_energy
  • Start simple, increase complexity gradually
  • Use spectral normalization and regularization for stability
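
As a reminder of what the objective behind these points looks like, here is a minimal sketch of the contrastive-divergence loss with energy regularization, written against a generic energy_fn; the exact form used by Workshop's EBM may differ:

import jax.numpy as jnp

def contrastive_divergence_loss(energy_fn, x_real, x_fake, alpha=0.01):
    """Push real-data energies down and MCMC-sample energies up (illustrative)."""
    real_energy = energy_fn(x_real).mean()
    fake_energy = energy_fn(x_fake).mean()
    # Regularizer keeps energy magnitudes bounded and helps prevent energy explosion
    reg = alpha * (jnp.square(energy_fn(x_real)).mean()
                   + jnp.square(energy_fn(x_fake)).mean())
    loss = real_energy - fake_energy + reg
    # A healthy model shows fake_energy > real_energy, i.e. a positive energy gap
    return loss, {"energy_gap": fake_energy - real_energy}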

Recommended Workflow:

  1. Start with simple 2D data to verify training works (see the toy-data sketch after this list)
  2. Use MLP energy for tabular, CNN for images
  3. Monitor energy gap and buffer health
  4. Tune MCMC steps and step size for your data
  5. Use DeepEBM for complex distributions
  6. Visualize samples frequently during training

For theoretical understanding, see the EBM Explained guide.