Energy-Based Models User Guide¤
Complete guide to building, training, and using Energy-Based Models with Workshop.
Overview¤
This guide covers practical usage of EBMs in Workshop, from basic setup to advanced techniques. You'll learn how to:
- Configure EBMs: Set up energy functions and MCMC sampling parameters
- Train Models: Train with persistent contrastive divergence and monitor stability
- Generate Samples: Sample using Langevin dynamics and MCMC methods
- Tune & Debug: Optimize hyperparameters and troubleshoot common issues
Quick Start¤
Basic EBM Example¤
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.energy import EBM
# Initialize RNGs
rngs = nnx.Rngs(params=0, noise=1, sample=2)
# Configuration for MNIST
config = ModelConfiguration(
name="mnist_ebm",
model_class="workshop.generative_models.models.energy.ebm.EBM",
input_dim=(28, 28, 1),
hidden_dims=[128, 256, 512],
output_dim=1,
activation="silu",
parameters={
"energy_type": "cnn",
"mcmc_steps": 60,
"mcmc_step_size": 0.01,
"mcmc_noise_scale": 0.005,
"sample_buffer_capacity": 8192,
"sample_buffer_reinit_prob": 0.05,
"alpha": 0.01, # Regularization strength
}
)
# Create EBM
model = EBM(config, rngs=rngs)
# Training step
batch = {"x": jnp.ones((32, 28, 28, 1))}
loss_dict = model.train_step(batch, rngs=rngs)
print(f"Loss: {loss_dict['loss']:.4f}")
print(f"Real energy: {loss_dict['real_energy']:.4f}")
print(f"Fake energy: {loss_dict['fake_energy']:.4f}")
Creating EBM Models¤
1. Standard EBM (MLP Energy Function)¤
For tabular or low-dimensional data:
from workshop.generative_models.models.energy import EBM
# MLP energy function configuration
config = ModelConfiguration(
name="tabular_ebm",
model_class="workshop.generative_models.models.energy.ebm.EBM",
input_dim=100, # Input features
hidden_dims=[256, 256, 128],
output_dim=1,
activation="gelu",
dropout_rate=0.1,
parameters={
"energy_type": "mlp",
"mcmc_steps": 60,
"mcmc_step_size": 0.01,
"mcmc_noise_scale": 0.005,
"sample_buffer_capacity": 4096,
"alpha": 0.01,
}
)
model = EBM(config, rngs=rngs)
Key Parameters:
| Parameter | Default | Description |
|---|---|---|
| energy_type | "mlp" | Energy function architecture (mlp/cnn) |
| mcmc_steps | 60 | Number of Langevin dynamics steps |
| mcmc_step_size | 0.01 | Step size for gradient descent |
| mcmc_noise_scale | 0.005 | Noise scale for exploration |
| alpha | 0.01 | Regularization strength |
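For intuition, one Langevin update used to draw negative samples looks roughly like this (an illustrative sketch; Workshop's internal sampler may add gradient clipping and buffer handling):

import jax
import jax.numpy as jnp

def langevin_step(x, energy_fn, key, step_size=0.01, noise_scale=0.005):
    # Gradient of the summed energy with respect to the current samples
    grad = jax.grad(lambda v: jnp.sum(energy_fn(v)))(x)
    noise = jax.random.normal(key, x.shape) * noise_scale
    # Move downhill on the energy surface while injecting exploration noise
    return x - step_size * grad + noise

Repeating this update mcmc_steps times, starting from noise or from the sample buffer, produces the negative samples used during training.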
2. CNN Energy Function (for Images)¤
For image data:
config = ModelConfiguration(
name="image_ebm",
model_class="workshop.generative_models.models.energy.ebm.EBM",
input_dim=(32, 32, 3), # CIFAR-10 dimensions
hidden_dims=[64, 128, 256],
output_dim=1,
activation="silu",
parameters={
"energy_type": "cnn",
"input_channels": 3,
"mcmc_steps": 100,
"mcmc_step_size": 0.005,
"mcmc_noise_scale": 0.001,
"sample_buffer_capacity": 8192,
"sample_buffer_reinit_prob": 0.05,
"alpha": 0.001,
}
)
model = EBM(config, rngs=rngs)
3. Deep EBM (Complex Data)¤
For complex datasets requiring deeper architectures:
from workshop.generative_models.models.energy import DeepEBM
config = ModelConfiguration(
name="deep_ebm",
model_class="workshop.generative_models.models.energy.ebm.DeepEBM",
input_dim=(32, 32, 3),
hidden_dims=[32, 64, 128, 256],
output_dim=1,
activation="silu",
parameters={
"use_residual": True,
"use_spectral_norm": True,
"mcmc_steps": 100,
"mcmc_step_size": 0.005,
"mcmc_noise_scale": 0.001,
"sample_buffer_capacity": 8192,
"alpha": 0.001,
}
)
model = DeepEBM(config, rngs=rngs)
Deep EBM Features:
- Residual connections: Enable deeper networks (10+ layers)
- Spectral normalization: Stabilizes training
- GroupNorm: Better suited than BatchNorm for MCMC sampling (a minimal block sketch follows)
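As a rough illustration of these features, here is a minimal residual block with GroupNorm in Flax NNX. This is a sketch, not DeepEBM's actual layers, and it assumes features is divisible by num_groups:

import jax
from flax import nnx

class ResidualEnergyBlock(nnx.Module):
    """Residual conv block with GroupNorm (illustrative sketch only)."""

    def __init__(self, features: int, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(features, features, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(features, features, kernel_size=(3, 3), rngs=rngs)
        self.norm1 = nnx.GroupNorm(num_features=features, num_groups=8, rngs=rngs)
        self.norm2 = nnx.GroupNorm(num_features=features, num_groups=8, rngs=rngs)

    def __call__(self, x):
        h = jax.nn.silu(self.norm1(self.conv1(x)))
        h = self.norm2(self.conv2(h))
        # The skip connection keeps gradients well-behaved over long MCMC chains
        return jax.nn.silu(x + h)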
Training EBMs¤
Basic Training Loop¤
from workshop.generative_models.training.trainers import EnergyTrainer
from workshop.generative_models.core.configuration import TrainingConfiguration
# Training configuration
train_config = TrainingConfiguration(
num_epochs=100,
batch_size=128,
learning_rate=1e-4,
optimizer="adam",
save_dir="./checkpoints/ebm",
log_every=10,
save_every=1000,
)
# Create trainer
trainer = EnergyTrainer(
model=model,
config=train_config,
rngs=rngs,
)
# Train
history = trainer.train(train_loader, val_loader=None)
Training with Monitoring¤
Monitor key metrics during training:
def train_step_with_monitoring(model, batch, rngs):
"""Training step with detailed monitoring."""
loss_dict = model.train_step(batch, rngs=rngs)
# Log metrics
print(f"Step metrics:")
print(f" Loss: {loss_dict['loss']:.4f}")
print(f" Real energy: {loss_dict['real_energy']:.4f}")
print(f" Fake energy: {loss_dict['fake_energy']:.4f}")
print(f" Energy gap: {loss_dict['energy_gap']:.4f}")
# Check for issues
if loss_dict['energy_gap'] < 0:
print("WARNING: Negative energy gap - real data has higher energy!")
if abs(loss_dict['real_energy']) > 100:
print("WARNING: Energy explosion detected!")
return loss_dict
# Training loop
for epoch in range(num_epochs):
for batch in train_loader:
loss_dict = train_step_with_monitoring(model, batch, rngs)
Hyperparameter Guidelines¤
MCMC Sampling:
# Quick sampling (less accurate)
quick_config = {
"mcmc_steps": 20,
"mcmc_step_size": 0.02,
"mcmc_noise_scale": 0.01,
}
# Standard sampling (balanced)
standard_config = {
"mcmc_steps": 60,
"mcmc_step_size": 0.01,
"mcmc_noise_scale": 0.005,
}
# High-quality sampling (slower)
quality_config = {
"mcmc_steps": 200,
"mcmc_step_size": 0.005,
"mcmc_noise_scale": 0.001,
}
Learning Rates:
# EBMs typically need lower learning rates than supervised models
learning_rates = {
"small_model": 1e-4,
"medium_model": 5e-5,
"large_model": 1e-5,
}
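Low learning rates pair well with gradient clipping (also recommended under Troubleshooting below). A minimal optax sketch; how a custom optimizer is passed to Workshop's trainer is an assumption here:

import optax

# Clip the global gradient norm before Adam to guard against energy-gradient spikes
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-4),  # "small_model" setting from the dict above
)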
Generating Samples¤
Sampling from the Model¤
# Generate samples using MCMC
n_samples = 16
samples = model.generate(
n_samples=n_samples,
n_mcmc_steps=100, # More steps = better quality
step_size=0.01,
noise_scale=0.005,
rngs=rngs,
)
print(f"Generated samples shape: {samples.shape}")
Sampling with Different Initializations¤
# Random initialization
random_samples = model.generate(
n_samples=16,
init_strategy="random",
rngs=rngs,
)
# Initialize from data
data_init_samples = model.generate(
n_samples=16,
init_strategy="data",
init_data=train_batch,
rngs=rngs,
)
# Initialize from buffer
buffer_samples = model.sample_from_buffer(
n_samples=16,
rngs=rngs,
)
Conditional Generation¤
For conditional EBMs (e.g., class-conditional):
# Generate samples for specific class
class_label = 3
conditional_samples = model.generate_conditional(
n_samples=16,
condition=class_label,
rngs=rngs,
)
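If you need to build the conditional energy function yourself, a minimal class-conditional sketch looks like the following. The ConditionalEnergyMLP name and structure are illustrative assumptions, not Workshop's built-in API:

import jax
import jax.numpy as jnp
from flax import nnx

class ConditionalEnergyMLP(nnx.Module):
    """E(x, y): embed the label and score it together with the input."""

    def __init__(self, input_dim: int, num_classes: int, hidden: int, *, rngs: nnx.Rngs):
        self.embed = nnx.Embed(num_classes, hidden, rngs=rngs)
        self.dense1 = nnx.Linear(input_dim + hidden, hidden, rngs=rngs)
        self.dense2 = nnx.Linear(hidden, 1, rngs=rngs)

    def __call__(self, x, y):
        h = jnp.concatenate([x.reshape(x.shape[0], -1), self.embed(y)], axis=-1)
        h = jax.nn.silu(self.dense1(h))
        return self.dense2(h).squeeze(-1)  # one scalar energy per example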
Advanced Techniques¤
1. Sample Buffer Management¤
The sample buffer is critical for stable training:
# Access buffer statistics
buffer_size = len(model.sample_buffer.buffer)
print(f"Buffer contains {buffer_size} samples")
# Manually populate buffer
for batch in train_loader:
# Run MCMC to generate samples
samples = model.generate(
n_samples=batch['x'].shape[0],
init_strategy="data",
init_data=batch['x'],
rngs=rngs,
)
# Samples automatically added to buffer
# Clear buffer (for reinitialization)
model.sample_buffer.buffer = []
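For intuition, sample_buffer_reinit_prob works roughly like this (a sketch of persistent-buffer chain initialization, not Workshop's SampleBuffer implementation):

import jax
import jax.numpy as jnp

def init_mcmc_chains(buffer, n_chains, sample_shape, reinit_prob, key):
    """Start chains from stored samples, restarting a fraction from noise."""
    key_restart, key_noise, key_idx = jax.random.split(key, 3)
    noise = jax.random.uniform(key_noise, (n_chains, *sample_shape), minval=-1.0, maxval=1.0)
    if len(buffer) == 0:
        return noise  # empty buffer: every chain starts from noise
    idx = jax.random.randint(key_idx, (n_chains,), 0, len(buffer))
    from_buffer = jnp.stack([buffer[int(i)] for i in idx])
    # With probability reinit_prob a chain is re-initialized from noise
    restart = jax.random.bernoulli(key_restart, reinit_prob, (n_chains,))
    restart = restart.reshape(-1, *([1] * len(sample_shape)))
    return jnp.where(restart, noise, from_buffer)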
2. Energy Landscape Visualization¤
Visualize the energy landscape:
import matplotlib.pyplot as plt
def visualize_energy_landscape(model, data_range=(-3, 3), resolution=100):
"""Visualize 2D energy landscape."""
x = jnp.linspace(data_range[0], data_range[1], resolution)
y = jnp.linspace(data_range[0], data_range[1], resolution)
X, Y = jnp.meshgrid(x, y)
# Compute energy for each point
points = jnp.stack([X.ravel(), Y.ravel()], axis=1)
energies = model.energy(points)
energies = energies.reshape(resolution, resolution)
# Plot
plt.figure(figsize=(10, 8))
plt.contourf(X, Y, energies, levels=50, cmap='viridis')
plt.colorbar(label='Energy')
plt.title('Energy Landscape')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()
# For 2D data
visualize_energy_landscape(model)
3. Annealed Langevin Sampling¤
For better sampling quality:
def annealed_sampling(model, n_samples, n_steps=1000, rngs=None):
"""Annealed importance sampling for high-quality samples."""
# Start with high temperature
temperatures = jnp.linspace(10.0, 1.0, n_steps)
# Initialize samples
samples = jax.random.normal(rngs.sample(), (n_samples, *model.input_shape))
for i, temp in enumerate(temperatures):
# Compute energy gradient
energy_grad = jax.grad(lambda x: jnp.sum(model.energy(x)))(samples)
# Langevin step with temperature
step_size = 0.01 * temp
noise_scale = jnp.sqrt(2 * step_size * temp)
samples = samples - step_size * energy_grad
samples = samples + noise_scale * jax.random.normal(
rngs.sample(), samples.shape
)
return samples
# Use annealed sampling
high_quality_samples = annealed_sampling(model, n_samples=16, rngs=rngs)
Troubleshooting¤
Common Issues and Solutions¤
- Energy Explosion
  - Symptoms: Energy values grow unbounded, NaN losses
  - Solutions: Reduce the learning rate (try 1e-5); add or increase regularization (alpha=0.01 to 0.1); use spectral normalization; clip gradients (max_grad_norm=1.0)
- Poor Sample Quality
  - Symptoms: Samples look like noise or are blurry
  - Solutions: Increase MCMC steps (60 → 100+); tune the step size; use a larger buffer capacity; use a deeper energy function
- Mode Collapse
  - Symptoms: All samples look similar
  - Solutions: Increase the buffer reinit probability; use data augmentation; run longer MCMC chains; use a larger buffer
- Training Instability
  - Symptoms: Oscillating losses, sudden divergence
  - Solutions: Lower the learning rate; use a persistent buffer; add gradient clipping; monitor the energy gap
Debugging Checklist¤
def diagnose_ebm(model, batch, rngs):
"""Diagnostic checks for EBM training."""
# 1. Check energy values
real_energy = model.energy(batch['x']).mean()
print(f"Real data energy: {real_energy:.3f}")
# Generate samples
fake_samples = model.generate(n_samples=16, rngs=rngs)
fake_energy = model.energy(fake_samples).mean()
print(f"Generated samples energy: {fake_energy:.3f}")
# Energy gap should be positive
gap = fake_energy - real_energy
print(f"Energy gap: {gap:.3f} {'✓' if gap > 0 else '✗'}")
# 2. Check MCMC convergence
init_samples = jax.random.normal(rngs.sample(), (16, *model.input_shape))
init_energy = model.energy(init_samples).mean()
final_samples = model.generate(
n_samples=16,
init_strategy="custom",
init_samples=init_samples,
n_mcmc_steps=100,
rngs=rngs,
)
final_energy = model.energy(final_samples).mean()
energy_decrease = init_energy - final_energy
print(f"MCMC energy decrease: {energy_decrease:.3f}")
# 3. Check buffer health
buffer_size = len(model.sample_buffer.buffer)
print(f"Buffer size: {buffer_size}/{model.sample_buffer.capacity}")
    # 4. Check sample range (generated values should stay within the data range)
sample_min, sample_max = fake_samples.min(), fake_samples.max()
print(f"Sample range: [{sample_min:.3f}, {sample_max:.3f}]")
return {
"real_energy": real_energy,
"fake_energy": fake_energy,
"energy_gap": gap,
"mcmc_decrease": energy_decrease,
"buffer_usage": buffer_size / model.sample_buffer.capacity,
}
# Run diagnostics
diagnostics = diagnose_ebm(model, batch, rngs)
Best Practices¤
1. Start Simple¤
# Begin with a small model and simple data
simple_config = ModelConfiguration(
name="simple_ebm",
model_class="workshop.generative_models.models.energy.ebm.EBM",
input_dim=2, # 2D for visualization
hidden_dims=[64, 64],
output_dim=1,
activation="relu",
parameters={
"energy_type": "mlp",
"mcmc_steps": 30,
"mcmc_step_size": 0.02,
"sample_buffer_capacity": 1024,
}
)
2. Gradually Increase Complexity¤
# Once stable, increase capacity
medium_config = ModelConfiguration(
name="medium_ebm",
input_dim=(28, 28, 1),
hidden_dims=[128, 256],
parameters={
"energy_type": "cnn",
"mcmc_steps": 60,
"sample_buffer_capacity": 4096,
}
)
# For complex data
complex_config = ModelConfiguration(
name="complex_ebm",
model_class="workshop.generative_models.models.energy.ebm.DeepEBM",
input_dim=(32, 32, 3),
hidden_dims=[64, 128, 256, 512],
parameters={
"use_residual": True,
"use_spectral_norm": True,
"mcmc_steps": 100,
"sample_buffer_capacity": 8192,
}
)
3. Monitor Training Carefully¤
# Log detailed metrics
def detailed_training_step(model, batch, rngs, step):
loss_dict = model.train_step(batch, rngs=rngs)
if step % 100 == 0:
# Detailed logging
print(f"\nStep {step}:")
print(f" Loss: {loss_dict['loss']:.4f}")
print(f" Real energy: {loss_dict['real_energy']:.4f}")
print(f" Fake energy: {loss_dict['fake_energy']:.4f}")
print(f" Gap: {loss_dict['energy_gap']:.4f}")
# Generate samples for visual inspection
if step % 1000 == 0:
samples = model.generate(n_samples=64, rngs=rngs)
        visualize_samples(samples, f"step_{step}.png")  # visualize_samples: user-provided plotting helper
return loss_dict
4. Use Proper Preprocessing¤
def preprocess_for_ebm(images, rng_key=None, training=False):
    """Proper preprocessing for image EBMs."""
    # Normalize uint8 pixel values to [-1, 1]
    images = (images - 127.5) / 127.5
    # Add small noise during training to smooth the data distribution
    if training:
        noise = jax.random.normal(rng_key, images.shape) * 0.005
        images = jnp.clip(images + noise, -1.0, 1.0)
    return images
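Typical usage during training (raw_images stands in for a batch of uint8 images, and the noise stream name on the Rngs object is an assumption):

images = preprocess_for_ebm(raw_images, rng_key=rngs.noise(), training=True)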
Performance Optimization¤
GPU Acceleration¤
# EBMs benefit significantly from GPU
from workshop.generative_models.core.device_manager import DeviceManager
device_manager = DeviceManager()
device = device_manager.get_device()
print(f"Using device: {device}")
# Move data to GPU
batch_gpu = jax.device_put(batch, device)
Batch Size Tuning¤
# Larger batches = more stable gradients
# But: limited by GPU memory
batch_sizes = {
"small_model": 256,
"medium_model": 128,
"large_model": 64,
}
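If memory rules out the batch size you want, gradient accumulation can recover a larger effective batch. A hedged optax sketch; how Workshop's trainer consumes a custom optimizer may differ:

import optax

# Accumulate gradients over 4 micro-batches before each parameter update,
# giving an effective batch size of 4x the per-step batch
optimizer = optax.MultiSteps(optax.adam(1e-4), every_k_schedule=4)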
JIT Compilation¤
# Compile the training step for speed; use nnx.jit so the Flax NNX model and
# Rngs objects can be passed as arguments (plain jax.jit cannot trace them directly)
@nnx.jit
def compiled_train_step(model, batch, rngs):
    return model.train_step(batch, rngs=rngs)

# Much faster after the first call, which triggers compilation
loss_dict = compiled_train_step(model, batch, rngs)
Example: Complete MNIST Training¤
from workshop.generative_models.models.energy import EBM
from workshop.generative_models.core.configuration import ModelConfiguration
import tensorflow_datasets as tfds
# Load MNIST
train_ds = tfds.load('mnist', split='train', as_supervised=True)
def preprocess(image, label):
image = jnp.array(image, dtype=jnp.float32) / 255.0
image = (image - 0.5) / 0.5 # Normalize to [-1, 1]
return {"x": image}
# Create model
config = ModelConfiguration(
name="mnist_ebm",
model_class="workshop.generative_models.models.energy.ebm.EBM",
input_dim=(28, 28, 1),
hidden_dims=[128, 256, 512],
output_dim=1,
activation="silu",
parameters={
"energy_type": "cnn",
"mcmc_steps": 60,
"mcmc_step_size": 0.01,
"mcmc_noise_scale": 0.005,
"sample_buffer_capacity": 8192,
"alpha": 0.01,
}
)
model = EBM(config, rngs=rngs)
# Training loop
num_epochs = 50
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
    for step, (images, labels) in enumerate(train_ds.batch(128).as_numpy_iterator()):
        batch = preprocess(images, labels)
loss_dict = model.train_step(batch, rngs=rngs)
if step % 100 == 0:
print(f" Step {step}: Loss={loss_dict['loss']:.4f}, "
f"Gap={loss_dict['energy_gap']:.4f}")
# Generate samples
if (epoch + 1) % 10 == 0:
samples = model.generate(n_samples=64, rngs=rngs)
save_image_grid(samples, f"epoch_{epoch+1}.png")
print("Training complete!")
Further Reading¤
- EBM Explained - Theoretical foundations
- EBM API Reference - Complete API documentation
- Training Guide - General training workflows
- Examples - More EBM examples
Summary¤
Key Takeaways:
- EBMs learn by assigning low energy to data, high energy to non-data
- Persistent Contrastive Divergence (PCD) with MCMC sampling is the standard training method
- Sample buffer management is critical for stable training
- Monitor energy gap: fake_energy should be > real_energy
- Start simple, increase complexity gradually
- Use spectral normalization and regularization for stability
Recommended Workflow:
- Start with simple 2D data to verify training works
- Use MLP energy for tabular, CNN for images
- Monitor energy gap and buffer health
- Tune MCMC steps and step size for your data
- Use DeepEBM for complex distributions
- Visualize samples frequently during training
For theoretical understanding, see the EBM Explained guide.