Energy-Based Models (EBMs) Explained¤
- Energy Functions - Learn distributions by assigning low energy to data and high energy to non-data, capturing complex dependencies
- Flexible Modeling - No architectural constraints—any neural network can define the energy landscape
- Unified Framework - Simultaneously perform classification, generation, and out-of-distribution detection
- MCMC Sampling - Generate samples by traversing the energy landscape using Markov Chain Monte Carlo methods
Overview¤
Energy-Based Models (EBMs) are a class of deep generative models that learn data distributions by assigning scalar energy values to every possible configuration. Unlike models that explicitly parameterize probability distributions, EBMs define probabilities implicitly through an energy function: lower energy corresponds to higher probability.
What makes EBMs special?
EBMs solve generative modeling from a fundamentally different perspective. Rather than learning to directly generate samples (GANs) or compress data through bottlenecks (VAEs), EBMs learn an energy landscape where:
- Data regions have low energy - the model assigns favorable (low) energy to realistic configurations
- Non-data regions have high energy - implausible configurations receive unfavorable (high) energy
- Sampling traverses the landscape - MCMC methods move from high to low energy regions
- No architectural constraints - any neural network architecture can define the energy function
The Intuition: Probability Through Energy¤
Think of EBMs like a physical landscape with hills and valleys:
- The Energy Function defines the terrain—each point (data configuration) has an elevation (energy value). Data points sit in deep valleys (low energy), while meaningless noise occupies high peaks (high energy).
- Training Shapes the Landscape - the model learns to carve valleys where training data exists and raise peaks everywhere else, creating a terrain that naturally guides samples toward realistic outputs.
- Sampling Rolls Downhill - like a ball released on a hillside, MCMC sampling iteratively moves toward lower energy, eventually settling in valleys that correspond to realistic data.
- Probability is Energy-Based - the probability of any configuration depends on its energy through the Boltzmann distribution: \(p(x) \propto e^{-E(x)}\). Lower energy means exponentially higher probability.
The critical insight: by learning to distinguish data from non-data through energy assignment, EBMs capture complex distributions without explicit density modeling or generation mechanisms.
Mathematical Foundation¤
The Boltzmann Distribution¤
EBMs define probability distributions through the Boltzmann distribution:

\[ p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta} \]
where:
- \(E_\theta(x)\) is the energy function parameterized by neural network with weights \(\theta\)
- \(Z_\theta = \int e^{-E_\theta(x')} dx'\) is the partition function (normalizing constant)
- Lower energy \(E_\theta(x)\) corresponds to higher probability \(p_\theta(x)\)
graph TD
X["Data x"] --> E["Energy Function<br/>E_θ(x)"]
E --> NegE["-E_θ(x)"]
NegE --> Exp["exp(-E_θ(x))"]
Exp --> Norm["÷ Z_θ"]
Z["Partition Function<br/>Z_θ = ∫ exp(-E_θ(x')) dx'"] --> Norm
Norm --> P["Probability p_θ(x)"]
style X fill:#e1f5ff
style E fill:#fff3cd
style P fill:#c8e6c9
style Z fill:#ffccbc
Why the exponential? The exponential function ensures probabilities are positive and creates sharp distinctions between energy levels—a small energy difference produces a large probability ratio.
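As a worked illustration of that sharpness (an example added here for concreteness): if two configurations differ in energy by \(E(x_2) - E(x_1) = 5\), the partition function cancels in the ratio and

\[ \frac{p(x_1)}{p(x_2)} = \frac{e^{-E(x_1)}/Z_\theta}{e^{-E(x_2)}/Z_\theta} = e^{E(x_2) - E(x_1)} = e^{5} \approx 148, \]

so the lower-energy configuration is roughly 148 times more likely.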
The Intractable Partition Function¤
The partition function \(Z_\theta\) poses the fundamental challenge in EBM training:

\[ Z_\theta = \int e^{-E_\theta(x)} \, dx \]
This integral sums over all possible configurations of \(x\), which becomes intractable for high-dimensional data:
- Images (256×256×3): \(Z_\theta\) requires integration over \(196{,}608\)-dimensional space
- Text (512 tokens, vocab 50k): \(Z_\theta\) requires summing over \((50{,}000)^{512} \approx 10^{2406}\) configurations
The intractability of \(Z_\theta\) makes maximum likelihood training directly infeasible and necessitates alternative training methods.
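To make the contrast concrete, here is a minimal sketch (illustrative only; `toy_energy` is a made-up function, not part of any library) that computes \(Z_\theta\) exactly for a one-dimensional energy on a discretized grid, which is only feasible because the space is tiny:

import jax.numpy as jnp

def toy_energy(x):
    """A hand-written 1-D energy with two wells at x = -2 and x = +2."""
    return 0.05 * (x**2 - 4.0)**2

# Discretize a 1-D domain so the partition function becomes a finite sum.
grid = jnp.linspace(-5.0, 5.0, 1001)
dx = grid[1] - grid[0]

energies = toy_energy(grid)
unnormalized = jnp.exp(-energies)
Z = jnp.sum(unnormalized) * dx          # tractable only because the grid is tiny
probs = unnormalized / Z                # p(x) = exp(-E(x)) / Z

print(float(Z), float(probs.max()))
# In D dimensions the same grid would need 1001**D points, which is exactly
# why Z_theta is intractable for images or text.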
Energy Functions and Neural Networks¤
The energy function \(E_\theta(x)\) can be implemented using any neural network architecture:
Multi-Layer Perceptron (MLP) Energy:
class MLPEnergyFunction(nnx.Module):
    def __init__(self, hidden_dims, input_dim, *, rngs):
        self.layers = []
        for dim in hidden_dims:
            self.layers.append(nnx.Linear(input_dim, dim, rngs=rngs))
            input_dim = dim
        self.output = nnx.Linear(input_dim, 1, rngs=rngs)

    def __call__(self, x):
        for layer in self.layers:
            x = nnx.relu(layer(x))
        return self.output(x).squeeze(-1)  # Scalar energy
Convolutional Energy (for images):
class CNNEnergyFunction(nnx.Module):
    def __init__(self, hidden_dims, input_channels, *, rngs):
        self.conv_layers = []
        for out_channels in hidden_dims:
            self.conv_layers.append(
                nnx.Conv(input_channels, out_channels,
                         kernel_size=(3, 3), rngs=rngs)
            )
            input_channels = out_channels
        self.global_pool = lambda x: jnp.mean(x, axis=(1, 2))
        self.output = nnx.Linear(hidden_dims[-1], 1, rngs=rngs)

    def __call__(self, x):
        for conv in self.conv_layers:
            x = nnx.relu(conv(x))
        x = self.global_pool(x)
        return self.output(x).squeeze(-1)
Key properties:
- Scalar output - energy function outputs single real number per input
- Unrestricted architecture - no invertibility, dimension matching, or structural requirements
- Learned landscape - network parameters define the energy surface topology
Training Energy-Based Models¤
Maximum Likelihood: The Ideal but Intractable Approach¤
The maximum likelihood objective seeks to maximize the log-probability of observed data:

\[ \mathcal{L}_{\text{MLE}}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log p_\theta(x)\right] \]

Substituting the Boltzmann distribution:

\[ \log p_\theta(x) = -E_\theta(x) - \log Z_\theta \]

Taking gradients with respect to \(\theta\):

\[ \nabla_\theta \log p_\theta(x) = -\nabla_\theta E_\theta(x) + \mathbb{E}_{x' \sim p_\theta}\left[\nabla_\theta E_\theta(x')\right] \]

The first term \(-\nabla_\theta E_\theta(x)\) is tractable—just backpropagation through the energy function evaluated at data \(x\).

The second term is problematic:

\[ \mathbb{E}_{x' \sim p_\theta}\left[\nabla_\theta E_\theta(x')\right] = \int p_\theta(x') \, \nabla_\theta E_\theta(x') \, dx' \]
This requires sampling from the model distribution \(p_\theta(x)\), which is exactly what we're trying to learn! This chicken-and-egg problem necessitates approximate training methods.
Contrastive Divergence: Practical Approximation¤
Contrastive Divergence (CD) (Hinton, 2002) approximates the intractable expectation using short MCMC chains:
Algorithm:
- Positive phase: Sample data \(x^+ \sim p_{\text{data}}\), compute \(\nabla_\theta E_\theta(x^+)\)
- Negative phase: Initialize \(x^- = x^+\), run \(k\) steps of MCMC to get \(x^-\), compute \(\nabla_\theta E_\theta(x^-)\)
- Gradient step:
\[ \nabla_\theta \mathcal{L}_{\text{CD}} \approx -\nabla_\theta E_\theta(x^+) + \nabla_\theta E_\theta(x^-) \]
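In code, these three steps reduce to a simple surrogate loss whose parameter gradient matches the expression above. This is a sketch: `energy_fn` is assumed to return one scalar energy per example, and the \(\alpha\) term is the energy-magnitude regularizer discussed later on this page.

import jax.numpy as jnp

def contrastive_divergence_loss(energy_fn, x_pos, x_neg, alpha=0.01):
    """CD surrogate loss: push energy down on data, up on model samples."""
    e_pos = energy_fn(x_pos)    # energies of real samples, shape (batch,)
    e_neg = energy_fn(x_neg)    # energies of MCMC samples, shape (batch,)
    cd_term = jnp.mean(e_pos) - jnp.mean(e_neg)
    reg_term = alpha * jnp.mean(e_pos**2 + e_neg**2)   # keeps energies bounded
    return cd_term + reg_term

Differentiating this loss with respect to the parameters reproduces \(-\nabla_\theta E_\theta(x^+) + \nabla_\theta E_\theta(x^-)\); the negative samples are treated as constants, so no gradient flows back through the MCMC chain.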
graph LR
Data["Data x⁺<br/>(real sample)"] --> PosGrad["Positive Gradient<br/>-∇E(x⁺)"]
Data --> Init["Initialize<br/>x⁻ = x⁺"]
Init --> MCMC["k MCMC Steps<br/>(typically k=1-60)"]
MCMC --> Neg["Model Sample<br/>x⁻"]
Neg --> NegGrad["Negative Gradient<br/>+∇E(x⁻)"]
PosGrad --> Update["Parameter Update"]
NegGrad --> Update
style Data fill:#c8e6c9
style Neg fill:#ffccbc
style Update fill:#e1f5ff
Why it works:
- Positive gradient pushes down energy at data points
- Negative gradient pushes up energy at model samples
- Short chains approximate the full expectation with \(k \ll \infty\) steps
- Persistent chains (PCD) reuse samples across iterations for better mixing
CD-k Limitations
CD with very short chains (\(k=1\)) can be unstable. Persistent Contrastive Divergence (PCD) maintains a persistent pool of samples across training iterations, providing better gradient estimates.
Persistent Contrastive Divergence (PCD)¤
PCD (Tieleman, 2008) maintains a persistent sample buffer across training iterations:
Algorithm:
- Initialize buffer: \(\mathcal{B} = \{x_1, x_2, \ldots, x_M\}\) with random samples
- Each training iteration:
  - Sample data batch \(\{x^+_i\}\)
  - Sample buffer batch \(\{x^-_j\} \sim \mathcal{B}\)
  - Run \(k\) MCMC steps from each \(x^-_j\) to get updated \(\tilde{x}^-_j\)
  - Update buffer: \(\mathcal{B} \leftarrow \mathcal{B} \cup \{\tilde{x}^-_j\}\)
  - Compute gradient: \(\nabla_\theta \mathcal{L} \approx -\frac{1}{N}\sum_i \nabla E(x^+_i) + \frac{1}{N}\sum_j \nabla E(\tilde{x}^-_j)\)
  - Update parameters
def persistent_contrastive_divergence(
    energy_fn, real_samples, sample_buffer, rng_key,
    n_mcmc_steps=60, step_size=0.01, noise_scale=0.005
):
    """PCD negative phase: continue persistent chains and refresh the buffer."""
    batch_size = real_samples.shape[0]
    sample_shape = real_samples.shape[1:]

    # Sample initial points from the persistent buffer
    init_samples = sample_buffer.sample_initial(batch_size, rng_key, sample_shape)

    # Run MCMC chain (Langevin dynamics)
    final_samples = langevin_dynamics(
        energy_fn, init_samples, n_mcmc_steps, step_size, noise_scale, rng_key
    )

    # Update buffer with new samples
    sample_buffer.push(final_samples)
    return init_samples, final_samples
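Putting the buffer, the Langevin sampler, and the surrogate loss together, one PCD iteration could look like the sketch below. This is illustrative rather than a fixed API: it assumes the `contrastive_divergence_loss` helper sketched earlier on this page, and the Flax NNX `nnx.value_and_grad` / `nnx.Optimizer` pattern, whose exact signatures vary across Flax versions.

import jax
from flax import nnx

def pcd_train_step(model, optimizer, buffer, real_batch, rng_key,
                   n_mcmc_steps=60, step_size=0.01, noise_scale=0.005, alpha=0.01):
    """One PCD iteration: continue persistent chains, then update parameters."""
    init_key, mcmc_key = jax.random.split(rng_key)

    # Negative phase: continue persistent chains with Langevin dynamics.
    init_samples = buffer.sample_initial(
        real_batch.shape[0], init_key, real_batch.shape[1:]
    )
    fake_batch = langevin_dynamics(
        model, init_samples, n_mcmc_steps, step_size, noise_scale, mcmc_key
    )
    fake_batch = jax.lax.stop_gradient(fake_batch)  # do not backprop through the chain
    buffer.push(fake_batch)

    # Positive and negative phases combined in the CD surrogate loss.
    def loss_fn(m):
        return contrastive_divergence_loss(m, real_batch, fake_batch, alpha=alpha)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # exact signature depends on the Flax version
    return loss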
Advantages over CD:
- Better mixing: Chains continue across iterations
- More accurate gradients: Samples closer to true model distribution
- Stable training: Reduces oscillations and mode collapse
Langevin Dynamics: MCMC Sampling¤
Langevin dynamics provides the MCMC sampler for EBMs:

\[ x_{t+1} = x_t - \frac{\epsilon}{2} \nabla_x E_\theta(x_t) + \sqrt{\epsilon}\, \xi_t \]
where \(\xi_t \sim \mathcal{N}(0, I)\) and \(\epsilon\) is the step size.
Two components:
- Gradient descent \(-\frac{\epsilon}{2} \nabla_x E_\theta(x_t)\) moves toward lower energy
- Brownian motion \(\sqrt{\epsilon} \, \xi_t\) adds noise to explore the space
def langevin_dynamics(
    energy_fn, init_samples, n_steps, step_size, noise_scale, rng_key
):
    """Sample from EBM using Langevin dynamics."""
    x = init_samples
    for step in range(n_steps):
        # Compute energy gradient
        energy_grad = jax.grad(lambda x: jnp.sum(energy_fn(x)))(x)
        # Gradient descent step
        x = x - (step_size / 2) * energy_grad
        # Add noise
        rng_key, noise_key = jax.random.split(rng_key)
        noise = jax.random.normal(noise_key, x.shape) * noise_scale
        x = x + noise
        # Clip to valid range
        x = jnp.clip(x, -1.0, 1.0)
    return x
Convergence: As \(n_{\text{steps}} \to \infty\) and \(\epsilon \to 0\), Langevin dynamics samples from \(p_\theta(x)\) exactly. In practice, finite steps and step sizes introduce bias.
Score Matching: Avoiding Sampling¤
Score Matching (Hyvärinen, 2005) bypasses MCMC sampling by matching the score function (gradient of log-density):
Objective - Minimize the Fisher divergence between the model score and the data score:

\[ \mathcal{L}_{\text{SM}}(\theta) = \tfrac{1}{2}\, \mathbb{E}_{x \sim p_{\text{data}}}\!\left[ \left\| \nabla_x \log p_\theta(x) - \nabla_x \log p_{\text{data}}(x) \right\|^2 \right] \]

For an EBM the model score is \(\nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x)\), so the partition function drops out. Since \(\nabla_x \log p_{\text{data}}(x)\) is unknown, Hyvärinen showed this is equivalent (up to an additive constant) to:

\[ \mathcal{L}_{\text{SM}}(\theta) = \mathbb{E}_{x \sim p_{\text{data}}}\!\left[ -\operatorname{tr}\!\left(\nabla_x^2 E_\theta(x)\right) + \tfrac{1}{2}\left\| \nabla_x E_\theta(x) \right\|^2 \right] \]
Advantages:
- No sampling required during training
- No partition function computation needed
- Fast training compared to CD/PCD
Disadvantages:
- Second derivatives \(\nabla_x^2 E_\theta(x)\) expensive to compute
- Generation still requires MCMC at test time
- Less effective for high-dimensional data
Denoising Score Matching: A practical variant adds noise to the data and matches the score of the corrupted distribution:

\[ \mathcal{L}_{\text{DSM}}(\theta) = \tfrac{1}{2}\, \mathbb{E}_{x \sim p_{\text{data}},\, \epsilon \sim \mathcal{N}(0, I)}\!\left[ \left\| -\nabla_{\tilde{x}} E_\theta(\tilde{x}) - \frac{x - \tilde{x}}{\sigma^2} \right\|^2 \right] \]
where \(\tilde{x} = x + \sigma \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\).
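A denoising score matching loss for an EBM can be written directly in JAX. The sketch below is illustrative and assumes a single fixed noise level \(\sigma\) and an `energy_fn` that returns per-example scalar energies:

import jax
import jax.numpy as jnp

def denoising_score_matching_loss(energy_fn, x, rng_key, sigma=0.1):
    """Match the model score -∇E(x̃) to the score of the Gaussian corruption."""
    eps = jax.random.normal(rng_key, x.shape)
    x_tilde = x + sigma * eps                      # noisy data x̃ = x + σε

    # Model score at the noisy point: s_θ(x̃) = -∇_x̃ E_θ(x̃)
    score = -jax.grad(lambda z: jnp.sum(energy_fn(z)))(x_tilde)

    # Score of the corruption kernel: ∇_x̃ log q(x̃|x) = (x - x̃) / σ²
    target = (x - x_tilde) / sigma**2

    per_example = jnp.sum((score - target) ** 2, axis=tuple(range(1, x.ndim)))
    return 0.5 * jnp.mean(per_example)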
Noise Contrastive Estimation (NCE)¤
NCE (Gutmann & Hyvärinen, 2010) frames training as binary classification:
Setup: Given data samples \(\{x_i\}\) and noise samples \(\{\tilde{x}_j\}\) from distribution \(p_n\), train a classifier to distinguish them.
Decision rule: Sample \(x\) comes from data if \(\frac{p_\theta(x)}{p_n(x)} > 1\), equivalently if \(-E_\theta(x) - \log Z_\theta > \log p_n(x)\), where \(\log Z_\theta\) can be treated as a learnable scalar.

NCE objective - binary cross-entropy with classifier \(D_\theta(x) = \sigma\!\left(\log p_\theta(x) - \log p_n(x)\right)\):

\[ \mathcal{L}_{\text{NCE}}(\theta) = -\,\mathbb{E}_{x \sim p_{\text{data}}}\!\left[\log D_\theta(x)\right] - \mathbb{E}_{\tilde{x} \sim p_n}\!\left[\log\!\left(1 - D_\theta(\tilde{x})\right)\right] \]
Advantages:
- No partition function in gradient
- Simple implementation like standard classification
- Flexible noise distribution \(p_n\) (typically Gaussian)
When to use:
- High-dimensional problems where MCMC is slow
- When generation quality is secondary to density estimation
- Contrastive learning scenarios
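As a concrete (illustrative) sketch with a standard Gaussian noise distribution, treating \(\log Z_\theta\) as a learnable scalar as in the original NCE formulation:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def nce_loss(energy_fn, log_z, x_data, x_noise, noise_std=1.0):
    """Binary-classification NCE loss with a factorized Gaussian noise distribution.

    log_z is a learnable scalar standing in for log Z_theta.
    """
    def log_model(x):
        return -energy_fn(x) - log_z               # log p_θ(x) up to the learned constant

    def log_noise(x):
        # Gaussian noise log-density, summed over feature dimensions.
        return jnp.sum(norm.logpdf(x, loc=0.0, scale=noise_std),
                       axis=tuple(range(1, x.ndim)))

    logit_data = log_model(x_data) - log_noise(x_data)      # classifier logit for real samples
    logit_noise = log_model(x_noise) - log_noise(x_noise)   # classifier logit for noise samples

    # -log σ(z) = softplus(-z);  -log(1 - σ(z)) = softplus(z)
    loss_data = jax.nn.softplus(-logit_data).mean()
    loss_noise = jax.nn.softplus(logit_noise).mean()
    return loss_data + loss_noise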
Joint Energy-Based Models (JEM)¤
Joint Energy-Based Models (Grathwohl et al., 2020) unify generative and discriminative modeling in a single framework.
The Key Insight¤
A standard classifier \(p(y|x)\) can be reinterpreted as an energy model:

\[ p_\theta(y|x) = \frac{e^{f_\theta(x)[y]}}{\sum_{y'} e^{f_\theta(x)[y']}} \]

where \(f_\theta(x) \in \mathbb{R}^C\) are logits.

Define energy as:

\[ E_\theta(x) = -\log \sum_{y} e^{f_\theta(x)[y]} = -\operatorname{logsumexp}_y f_\theta(x)[y] \]

Then the joint distribution becomes:

\[ p_\theta(x, y) = \frac{e^{f_\theta(x)[y]}}{Z_\theta} \]

And the marginal data distribution:

\[ p_\theta(x) = \sum_y p_\theta(x, y) = \frac{e^{-E_\theta(x)}}{Z_\theta} \]
graph TD
X["Input x"] --> F["Classifier f_θ(x)<br/>(C logits)"]
F --> Cond["Conditional p(y|x)<br/>(softmax)"]
F --> LogSumExp["log-sum-exp"]
LogSumExp --> Energy["E_θ(x) = -logsumexp(f_θ(x))"]
Energy --> Joint["Joint p(x,y)<br/>∝ exp(f_θ(x)[y])"]
Joint --> Marg["Marginal p(x)<br/>∝ exp(-E_θ(x))"]
style F fill:#fff3cd
style Cond fill:#e1f5ff
style Marg fill:#c8e6c9
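The reinterpretation takes only a few lines of JAX once the classifier's logits are available (a sketch; `logits` is assumed to have shape `(batch, C)`):

import jax

def jem_quantities(logits):
    """Derive JEM's energies and conditionals from classifier logits f_θ(x)."""
    # Marginal energy: E_θ(x) = -logsumexp_y f_θ(x)[y]
    energy_x = -jax.nn.logsumexp(logits, axis=-1)

    # Conditional p(y|x): the usual softmax; the partition function cancels.
    log_p_y_given_x = jax.nn.log_softmax(logits, axis=-1)

    # Unnormalized joint: log p(x, y) + log Z_θ = f_θ(x)[y]
    return energy_x, log_p_y_given_x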
Training JEM¤
Hybrid objective combining classification and generation (from \(\log p_\theta(x, y) = \log p_\theta(y|x) + \log p_\theta(x)\)):

\[ \mathcal{L}(\theta) = \mathcal{L}_{\text{class}}(\theta) + \mathcal{L}_{\text{gen}}(\theta) \]
where:
- \(\mathcal{L}_{\text{class}} = -\mathbb{E}_{(x, y) \sim p_{\text{data}}}[\log p_\theta(y|x)]\) (cross-entropy)
- \(\mathcal{L}_{\text{gen}} = -\mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)]\) (MLE via SGLD)
Training algorithm:
- Sample batch \(\{(x_i, y_i)\}\) from labeled data
- Classification update: Compute cross-entropy loss, backprop
- Generative update:
  - Generate samples via SGLD
  - Compute CD gradient
  - Update parameters
- Alternate or combine both losses
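A minimal sketch of the combined loss, assuming `classifier(x)` returns logits, `y` holds integer labels, and `x_sgld` holds SGLD samples produced elsewhere (the \(\alpha\) energy-magnitude regularizer is an optional but common addition):

import jax
import jax.numpy as jnp

def jem_hybrid_loss(classifier, x, y, x_sgld, alpha=0.01):
    """Hybrid JEM objective: cross-entropy plus a contrastive-divergence term."""
    logits = classifier(x)

    # Classification term: standard cross-entropy on labeled data.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    ce_loss = -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=-1))

    # Generative term: CD estimate of -log p(x) using E_θ(x) = -logsumexp(f_θ(x)).
    energy_real = -jax.nn.logsumexp(logits, axis=-1)
    energy_fake = -jax.nn.logsumexp(classifier(x_sgld), axis=-1)
    gen_loss = energy_real.mean() - energy_fake.mean()
    reg = alpha * (energy_real**2 + energy_fake**2).mean()

    return ce_loss + gen_loss + reg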
Capabilities of JEM:
- Classification - Achieves competitive accuracy on CIFAR-10, SVHN, ImageNet
- Generation - Generates realistic samples comparable to GANs
- Out-of-Distribution Detection - Uses \(p(x)\) to detect anomalies with AUROC > 95%
- Adversarial Robustness - More robust to adversarial examples than standard classifiers
- Calibration - Better uncertainty estimates than standard softmax
- Hybrid Discriminative-Generative - Leverages both labeled and unlabeled data
Challenges:
- Training instability: SGLD can fail to produce quality samples
- Computational cost: Requires MCMC sampling during training
- Hyperparameter sensitivity: Step size, noise scale, buffer size critical
Recent work (2023-2024) has proposed improvements including better initialization, curriculum learning, and stabilized SGLD variants.
Architecture Design¤
Energy Function Design Principles¤
Key considerations when designing energy functions:
- Scalar output: \(E_\theta: \mathcal{X} \to \mathbb{R}\) must map inputs to single energy value
- Sufficient capacity: Deep networks capture complex dependencies
- Spectral normalization: Stabilizes training by constraining Lipschitz constant
- Residual connections: Enable deep architectures (10+ layers)
- Normalization layers: GroupNorm or LayerNorm (BatchNorm can interfere with MCMC)
MLP Energy Function (Tabular Data)¤
class MLPEnergyFunction(nnx.Module):
    def __init__(
        self, hidden_dims, input_dim, activation=nnx.gelu,
        dropout_rate=0.0, *, rngs
    ):
        self.layers = []
        for dim in hidden_dims:
            self.layers.append(nnx.Linear(input_dim, dim, rngs=rngs))
            if dropout_rate > 0:
                self.layers.append(nnx.Dropout(dropout_rate, rngs=rngs))
            input_dim = dim
        self.output_layer = nnx.Linear(input_dim, 1, rngs=rngs)
        self.activation = activation

    def __call__(self, x, *, deterministic=True):
        for layer in self.layers:
            if isinstance(layer, nnx.Linear):
                x = self.activation(layer(x))
            elif isinstance(layer, nnx.Dropout):
                x = layer(x, deterministic=deterministic)
        return self.output_layer(x).squeeze(-1)
Use cases: Tabular data, low-to-moderate dimensional problems (< 1000 dimensions)
CNN Energy Function (Images)¤
class CNNEnergyFunction(nnx.Module):
    def __init__(
        self, hidden_dims, input_channels=3,
        activation=nnx.silu, *, rngs
    ):
        self.conv_blocks = []
        in_channels = input_channels
        for out_channels in hidden_dims:
            self.conv_blocks.append(EnergyBlock(
                in_channels, out_channels, activation=activation, rngs=rngs
            ))
            in_channels = out_channels
        self.global_pool = lambda x: jnp.mean(x, axis=(1, 2))
        self.fc_layers = [
            nnx.Linear(hidden_dims[-1], hidden_dims[-1] // 2, rngs=rngs),
            nnx.Linear(hidden_dims[-1] // 2, 1, rngs=rngs),
        ]
        self.activation = activation

    def __call__(self, x, *, deterministic=True):
        for block in self.conv_blocks:
            x = block(x)
        x = self.global_pool(x)
        for i, fc in enumerate(self.fc_layers):
            x = fc(x)
            if i < len(self.fc_layers) - 1:
                x = self.activation(x)
        return x.squeeze(-1)
class EnergyBlock(nnx.Module):
    """Convolutional block with GroupNorm and an optional residual connection."""

    def __init__(
        self, in_channels, out_channels,
        kernel_size=3, stride=2,
        use_residual=False, activation=nnx.silu, *, rngs
    ):
        self.conv = nnx.Conv(
            in_channels, out_channels,
            kernel_size=(kernel_size, kernel_size),
            strides=(stride, stride), padding="SAME", rngs=rngs
        )
        self.norm = nnx.GroupNorm(
            num_features=out_channels,
            num_groups=min(32, out_channels),
            rngs=rngs
        )
        self.activation = activation
        self.use_residual = use_residual and (in_channels == out_channels) and (stride == 1)

    def __call__(self, x):
        residual = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        if self.use_residual:
            x = x + residual
        return x
Use cases: Image data (MNIST, CIFAR-10, CelebA), spatial data
Deep Energy Functions with Spectral Normalization¤
For complex distributions, deeper architectures with spectral normalization improve stability:
class DeepCNNEnergyFunction(nnx.Module):
    def __init__(
        self, hidden_dims, input_channels=3,
        use_spectral_norm=True, use_residual=True, *, rngs
    ):
        # Build deeper architecture (8-12 layers)
        # Apply spectral normalization to conv layers
        # Use residual connections where dimensions match
        ...
Benefits:
- Spectral norm: Constrains Lipschitz constant, prevents gradient explosion
- Residual connections: Enable training of 10+ layer networks
- Better sample quality: Captures finer details in distribution
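For reference, the core of spectral normalization is a power-iteration estimate of the largest singular value. The sketch below is a generic illustration rather than a library API; in practice the vector `u` is carried as persistent state across steps, and your Flax version may offer a built-in wrapper.

import jax
import jax.numpy as jnp

def spectral_norm(weight, rng_key, n_power_iterations=5, eps=1e-12):
    """Estimate the largest singular value of a weight matrix via power iteration.

    Convolution kernels can be reshaped to (fan_in, fan_out) before calling this.
    """
    w = weight.reshape(-1, weight.shape[-1])          # (fan_in, fan_out)
    u = jax.random.normal(rng_key, (w.shape[0],))
    for _ in range(n_power_iterations):
        v = w.T @ u
        v = v / (jnp.linalg.norm(v) + eps)
        u = w @ v
        u = u / (jnp.linalg.norm(u) + eps)
    sigma = u @ (w @ v)                               # ≈ largest singular value
    return weight / sigma                             # spectrally normalized weight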
Training Dynamics and Best Practices¤
Sample Buffer Management¤
Persistent sample buffers are critical for stable EBM training:
Buffer initialization:
- Random noise (Gaussian or uniform)
- Real data with added noise
- Pre-trained model samples
Buffer operations:
class SampleBuffer:
    def __init__(self, capacity=8192, reinit_prob=0.05):
        self.capacity = capacity
        self.reinit_prob = reinit_prob
        self.buffer = []

    def sample_initial(self, batch_size, rng_key, sample_shape):
        """Sample starting points for MCMC."""
        if len(self.buffer) < batch_size:
            # Not enough samples - return random noise
            return jax.random.normal(rng_key, (batch_size, *sample_shape))
        # Split the key so indices, reinit mask, and noise are independent
        idx_key, mask_key, noise_key = jax.random.split(rng_key, 3)
        # Sample from buffer
        indices = jax.random.choice(idx_key, len(self.buffer), (batch_size,))
        samples = jnp.stack([self.buffer[int(i)] for i in indices])
        # Reinitialize some samples with noise (mask broadcasts over H, W, C)
        reinit_mask = jax.random.uniform(mask_key, (batch_size,)) < self.reinit_prob
        noise = jax.random.normal(noise_key, samples.shape)
        samples = jnp.where(reinit_mask[:, None, None, None], noise, samples)
        return samples

    def push(self, samples):
        """Add new samples to buffer."""
        for sample in samples:
            if len(self.buffer) >= self.capacity:
                self.buffer.pop(0)  # Remove oldest
            self.buffer.append(sample)
Reinitialization probability: Setting \(p_{\text{reinit}} = 0.05\) restarts about 5% of chains from noise each iteration, which keeps the buffer diverse and helps chains escape local minima.
Hyperparameter Guidelines¤
MCMC sampling:
| Parameter | Value Range | Typical | Notes |
|---|---|---|---|
| MCMC steps | 20-200 | 60 | More steps = better samples, slower |
| Step size | 0.001-0.05 | 0.01 | Too large: instability, too small: slow |
| Noise scale | 0.001-0.01 | 0.005 | Exploration vs exploitation |
Training:
| Parameter | Value Range | Typical | Notes |
|---|---|---|---|
| Learning rate | 1e-5 to 1e-3 | 1e-4 | Lower than standard supervised |
| Batch size | 32-256 | 128 | Larger batches stabilize |
| Buffer capacity | 2048-16384 | 8192 | More = better mixing |
| Alpha (regularization) | 0.001-0.1 | 0.01 | Prevents energy collapse |
Data preprocessing:
- Normalize images to \([-1, 1]\) or \([0, 1]\)
- Add noise to real data during training (noise scale ~0.005)
- Clip samples after MCMC to valid range
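For convenience, the typical values above can be collected into a single configuration object (a sketch, not a class provided by the library):

from dataclasses import dataclass

@dataclass
class EBMTrainingConfig:
    """Typical starting values taken from the tables above (tune per dataset)."""
    # MCMC sampling
    n_mcmc_steps: int = 60
    step_size: float = 0.01
    noise_scale: float = 0.005
    # Optimization
    learning_rate: float = 1e-4
    batch_size: int = 128
    buffer_capacity: int = 8192
    alpha: float = 0.01          # energy-magnitude regularization
    reinit_prob: float = 0.05    # fraction of chains restarted from noise
    # Data preprocessing
    data_range: tuple = (-1.0, 1.0)
    data_noise: float = 0.005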
Common Training Issues¤
- Mode Collapse
  - Symptom: Generated samples lack diversity, all similar
  - Solutions: Increase buffer size, use reinitialization probability, longer MCMC chains, reduce learning rate
- Energy Explosion
  - Symptom: Energy values grow unbounded, NaN losses
  - Solutions: Add regularization term \(\alpha \mathbb{E}[E(x)^2]\), spectral normalization, gradient clipping, smaller learning rate
- Poor Sample Quality
  - Symptom: Samples look like noise or blurry averages
  - Solutions: More MCMC steps (100+), better step size tuning, deeper energy function, larger model capacity
- Training Instability
  - Symptom: Oscillating losses, sudden divergence
  - Solutions: Persistent buffer, KL annealing on regularization, spectral norm, lower learning rate, gradient clipping
Monitoring and Diagnostics¤
Essential metrics:
- Energy values: Monitor \(\mathbb{E}[E(x_{\text{data}})]\) and \(\mathbb{E}[E(x_{\text{gen}})]\)
- Should satisfy \(E(x_{\text{data}}) < E(x_{\text{gen}})\)
-
Gap should be positive and stable
-
MCMC diagnostics:
- Acceptance rate: Track how many proposed moves are accepted
-
Energy trajectory: Plot energy over MCMC steps (should decrease)
-
Sample quality: Visual inspection, FID/IS scores
-
Gradient norms: Should be stable, not exploding
def training_step_with_diagnostics(model, batch, rngs, alpha=0.01):
    # Forward pass on real data
    real_data = batch['x']
    real_energies = model.energy(real_data)          # per-sample energies

    # Generate samples via MCMC (negative phase)
    fake_data = generate_samples_via_mcmc(model, real_data.shape[0], rngs)
    fake_energies = model.energy(fake_data)

    real_energy = real_energies.mean()
    fake_energy = fake_energies.mean()

    # Contrastive loss with energy-magnitude regularization alpha * E[E(x)^2]
    loss = real_energy - fake_energy + alpha * (real_energies**2 + fake_energies**2).mean()

    # Diagnostics
    energy_gap = fake_energy - real_energy
    print(f"Real Energy: {real_energy:.3f}, "
          f"Fake Energy: {fake_energy:.3f}, "
          f"Gap: {energy_gap:.3f}, "
          f"Loss: {loss:.3f}")
    return loss
Comparing EBMs with Other Generative Models¤
EBMs vs GANs: Stable Training vs Sharp Samples¤
| Aspect | Energy-Based Models | GANs |
|---|---|---|
| Training Stability | No adversarial game, but sensitive to MCMC tuning | Notorious instability, mode collapse |
| Sample Quality | Good (improving with recent methods) | Excellent (sharp, realistic) |
| Sampling Speed | Slow (50-200 MCMC steps) | Fast (1 forward pass) |
| Mode Coverage | Excellent (explicit density) | Poor (mode collapse common) |
| Likelihood | Exact (modulo partition function) | None (no explicit density) |
| Architecture | Any neural network | Generator-discriminator pair |
| Use Cases | Density estimation, OOD detection | High-quality image synthesis |
When to use EBMs over GANs:
- Exact likelihood computation needed
- Out-of-distribution detection
- Training stability priority
- Avoiding mode collapse essential
EBMs vs VAEs: Flexibility vs Efficiency¤
| Aspect | Energy-Based Models | VAEs |
|---|---|---|
| Architecture | Unrestricted | Encoder-decoder with bottleneck |
| Likelihood | Explicit density (up to partition function) | Lower bound (ELBO) |
| Sample Quality | Good | Often blurry (reconstruction bias) |
| Sampling Speed | Slow (MCMC) | Fast (single pass) |
| Training | CD/PCD (requires MCMC) | Stable (reparameterization trick) |
| Latent Space | No explicit latent | Structured latent representations |
| Use Cases | Density estimation | Representation learning, compression |
When to use EBMs over VAEs:
- No need for learned latent representations
- Exact density more important than speed
- Want architectural flexibility
- Avoid reconstruction-based blur
EBMs vs Diffusion Models: Principled Energy vs Iterative Denoising¤
| Aspect | Energy-Based Models | Diffusion Models |
|---|---|---|
| Conceptual Framework | Energy landscape | Iterative denoising |
| Training | CD/PCD (challenging) | Stable (MSE denoising) |
| Sampling | MCMC (50-200 steps) | Denoising (50-1000 steps) |
| Sample Quality | Good | State-of-the-art |
| Likelihood | Explicit up to \(Z_\theta\) | Tractable via ODE |
| Flexibility | Any architecture | Flexible but typically U-Net |
| Maturity | Older field, renewed interest | Rapidly advancing (2020-2025) |
Relationship: Diffusion models can be viewed as energy-based models with time-dependent energy functions and specific noise schedules. Score-based models unify the two perspectives.
Recent Advances (2023-2025)¤
Improved Contrastive Divergence¤
Improved CD (Du et al., ICML 2021) addresses gradient bias in standard CD:
Key innovations:
- KL divergence term: Add a KL term \(D_{\text{KL}}(q_\theta \,\|\, p_\theta)\) between the finite-step sampler distribution and the model, correcting the gradient bias of truncated chains
- Data augmentation: Augment both positive and negative samples
- Multi-scale processing: Process images at multiple resolutions
- Reservoir sampling: Maintain diverse buffer through stratified sampling
Results: CIFAR-10 Inception Score of 8.30, substantial improvement over baseline CD.
Diffusion Contrastive Divergence (DCD)¤
DCD (2023) replaces Langevin dynamics with diffusion processes:
Instead of Langevin MCMC: $$ x_{t+1} = x_t - \epsilon \nabla E(x_t) + \sqrt{2\epsilon} \xi_t $$
Use general diffusion: $$ dx_t = f(x_t, t)dt + g(t)dw_t $$
Benefits:
- More flexible: Not limited to Langevin
- More efficient: Better exploration of space
- Simulation-free: Train without explicit integration
Applications: Higher quality image generation, faster convergence.
Energy Matching: Unifying Flows and EBMs¤
Energy Matching (2024) connects flow matching and energy-based modeling:
Framework: Learn energy functions via flow matching objectives, avoiding MCMC during training entirely.
Training: Match energy gradients to optimal transport flows between data and noise.
Results: State-of-the-art likelihood on tabular datasets, competitive image generation.
Stable JEM Training¤
Recent work addresses JEM training instability:
- Better initialization: Initialize from pre-trained classifier
- Curriculum learning: Gradually increase MCMC steps during training
- Improved SGLD: Adaptive step sizes, better noise schedules
- Separate buffer per class: Class-conditional buffers improve mixing
Stabilized JEM (2023): Achieves stable training on ImageNet-scale datasets with competitive classification and generation.
Summary and Key Takeaways¤
Energy-Based Models provide a principled framework for learning probability distributions through energy functions. While training challenges have historically limited their adoption, recent advances in contrastive divergence and hybrid methods are revitalizing the field.
Core Principles¤
- Energy Landscape - Model assigns scalar energy to configurations—low energy = high probability
- Boltzmann Distribution - \(p(x) \propto e^{-E(x)}\) connects energy to probability via exponential
- MCMC Sampling - Generate samples by traversing energy landscape with Langevin dynamics
- Flexible Architecture - Any neural network can define energy function—no structural constraints
Training Methods¤
| Method | Sampling During Training | Pros | Cons |
|---|---|---|---|
| Contrastive Divergence | Short MCMC chains (k=1-60) | Practical, widely used | Biased gradient estimates |
| Persistent CD | Persistent buffer + MCMC | Better gradients, stable | Requires careful buffer management |
| Score Matching | None | No sampling needed | Expensive second derivatives |
| NCE | Noise samples only | Simple, no MCMC | Requires good noise distribution |
When to Use EBMs¤
Best suited for:
- Density estimation where exact likelihood matters
- Out-of-distribution detection leveraging \(p(x)\)
- Anomaly detection in high-stakes applications
- Hybrid tasks combining generation and classification (JEM)
- Structured prediction with complex dependencies
Avoid when:
- Real-time generation required (use GANs or fast flows)
- Training data limited (VAEs/diffusion more sample-efficient)
- Computational resources constrained at training time
Future Directions¤
- Faster sampling: Learned samplers, amortized MCMC
- Better training: Improved CD variants, score matching hybrids
- Scalability: Efficient large-scale training, distributed MCMC
- Applications: Scientific computing, protein design, molecular generation
- Unification: Bridging EBMs, diffusion, and score-based models
Next Steps¤
- Practical usage guide with implementation examples and training workflows
- Complete API documentation for EBM, DeepEBM, and JEM classes
- Step-by-step hands-on tutorial: train an EBM on MNIST from scratch
- Explore Joint Energy-Based Models, hybrid training, and applications
Further Reading¤
Seminal Papers (Must Read)¤
LeCun, Y., Chopra, S., Hadsell, R., Ranzato, M., & Huang, F. (2006). "A Tutorial on Energy-Based Learning"
Technical Report
Foundational tutorial establishing the energy-based learning framework
Hinton, G. E. (2002). "Training Products of Experts by Minimizing Contrastive Divergence"
Neural Computation 14(8)
Introduced Contrastive Divergence, making EBM training practical
Tieleman, T. (2008). "Training Restricted Boltzmann Machines using Approximations to the Likelihood Gradient"
ICML 2008
Persistent Contrastive Divergence for improved training
Hyvärinen, A. (2005). "Estimation of Non-Normalized Statistical Models by Score Matching"
JMLR 6
Score matching method avoiding partition function
Gutmann, M., & Hyvärinen, A. (2010). "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models"
AISTATS 2010
NCE training via binary classification
Modern EBMs and Applications¤
Du, Y., & Mordatch, I. (2019). "Implicit Generation and Modeling with Energy Based Models"
arXiv:1903.08689 | NeurIPS 2019
Modern EBMs for image generation with improved training
Grathwohl, W., Wang, K.-C., Jacobsen, J.-H., Duvenaud, D., Norouzi, M., & Swersky, K. (2020). "Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One"
arXiv:1912.03263 | ICLR 2020
Joint Energy-Based Models unifying classification and generation
Du, Y., Li, S., Tenenbaum, J., & Mordatch, I. (2020). "Improved Contrastive Divergence Training of Energy Based Models"
arXiv:2012.01316 | ICML 2021
Addresses gradient bias in CD with KL regularization
Gao, R., Song, Y., Poole, B., Wu, Y. N., & Kingma, D. P. (2021). "Learning Energy-Based Models by Diffusion Recovery Likelihood"
arXiv:2012.08125 | ICLR 2021
Connects EBMs with diffusion models via recovery likelihood
Recent Advances (2023-2025)¤
Wu, D., Wang, L., & Hong, P. (2023). "Training Energy-Based Models with Diffusion Contrastive Divergences"
arXiv:2307.01668
Replaces Langevin with general diffusion processes
Gao, Y., Liu, X., & Zha, H. (2024). "Energy Matching: Unifying Flow Matching and Energy-Based Models for Generative Modeling"
arXiv:2504.10612
Unifies flow matching and EBMs (2024 cutting-edge)
Zhang, Y., et al. (2023). "Stabilized Training of Joint Energy-based Models and their Practical Applications"
arXiv:2303.04187
Addresses JEM training instability
Tutorial Resources¤
NYU Deep Learning Course: Energy-Based Models
atcold.github.io/NYU-DLSP20/en/week07/07-1
Comprehensive lectures by Yann LeCun and Alfredo Canziani
UvA Deep Learning Tutorial 8: Deep Energy-Based Generative Models
uvadlc-notebooks.readthedocs.io
Hands-on Colab notebooks with implementations
Stanford CS236: Deep Generative Models (EBM Lecture)
Lecture Slides
Academic course materials covering EBM theory
Energy-Based Models GitHub Repositories
github.com/yilundu/improved_contrastive_divergence
Official implementation of Improved CD
Ready to explore Energy-Based Models? Start with the EBM User Guide for practical implementations, check the API Reference for complete documentation, or dive into the MNIST Tutorial to train your first EBM!