VAE User Guide¤
Complete guide to building, training, and using Variational Autoencoders with Workshop.
Overview¤
This guide covers practical usage of VAEs in Workshop, from basic setup to advanced techniques. You'll learn how to:
- Configure VAEs: set up encoder/decoder architectures and configure hyperparameters
- Train Models: train VAEs with proper loss functions and monitoring
- Generate Samples: sample from the prior and manipulate latent representations
- Tune & Debug: optimize hyperparameters and troubleshoot common issues
Quick Start¤
Basic VAE Example¤
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.vae import VAE
from workshop.generative_models.models.vae.encoders import MLPEncoder
from workshop.generative_models.models.vae.decoders import MLPDecoder
# Initialize RNGs
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
# Configuration
input_dim = 784 # 28x28 MNIST
latent_dim = 20
hidden_dims = [512, 256]
# Create encoder and decoder
encoder = MLPEncoder(
    hidden_dims=hidden_dims,
    latent_dim=latent_dim,
    activation="relu",
    input_dim=(input_dim,),
    rngs=rngs,
)
decoder = MLPDecoder(
    hidden_dims=list(reversed(hidden_dims)),
    output_dim=(input_dim,),
    latent_dim=latent_dim,
    activation="relu",
    rngs=rngs,
)

# Create VAE
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=latent_dim,
    rngs=rngs,
    kl_weight=1.0,
)
# Forward pass
x = jnp.ones((32, input_dim))
outputs = vae(x, rngs=rngs)
print(f"Reconstruction shape: {outputs['reconstructed'].shape}")
print(f"Latent shape: {outputs['z'].shape}")
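Under the hood, the forward pass draws `z` with the reparameterization trick, which keeps the sampling step differentiable. A minimal pure-JAX sketch of that step (independent of Workshop's internal implementation):

```python
import jax
import jax.numpy as jnp

def reparameterize(key, mean, log_var):
    """Sample z = mean + std * eps with eps ~ N(0, I), differentiable w.r.t. mean and log_var."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * log_var) * eps

key = jax.random.PRNGKey(0)
mean = jnp.zeros((32, 20))
log_var = jnp.zeros((32, 20))  # log_var = 0 -> std = 1
z = reparameterize(key, mean, log_var)
print(z.shape)  # (32, 20)
```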
Creating VAE Models¤
1. Encoder Architectures¤
MLP Encoder (Fully-Connected)¤
Best for tabular data and flattened images:
from workshop.generative_models.models.vae.encoders import MLPEncoder
encoder = MLPEncoder(
    hidden_dims=[512, 256, 128],  # Network depth
    latent_dim=32,                # Latent space dimension
    activation="relu",            # Activation function
    input_dim=(784,),             # Flattened input size
    rngs=rngs,
)
# Forward pass returns (mean, log_var)
mean, log_var = encoder(x, rngs=rngs)
CNN Encoder (Convolutional)¤
Best for image data with spatial structure:
from workshop.generative_models.models.vae.encoders import CNNEncoder
encoder = CNNEncoder(
    hidden_dims=[32, 64, 128, 256],  # Channel progression
    latent_dim=64,
    activation="relu",
    input_dim=(28, 28, 1),  # (H, W, C)
    rngs=rngs,
)
# Preserves spatial information through convolutions
mean, log_var = encoder(x, rngs=rngs)
Conditional Encoder¤
Add class conditioning to any encoder:
from workshop.generative_models.models.vae.encoders import ConditionalEncoder
base_encoder = MLPEncoder(
    hidden_dims=[512, 256],
    latent_dim=32,
    input_dim=(784,),
    rngs=rngs,
)
conditional_encoder = ConditionalEncoder(
    encoder=base_encoder,
    num_classes=10,  # Number of classes
    embed_dim=128,   # Embedding dimension for labels
    rngs=rngs,
)
# Pass class labels as condition
mean, log_var = conditional_encoder(x, condition=labels, rngs=rngs)
2. Decoder Architectures¤
MLP Decoder¤
from workshop.generative_models.models.vae.decoders import MLPDecoder
decoder = MLPDecoder(
    hidden_dims=[128, 256, 512],  # Reversed from encoder
    output_dim=(784,),            # Reconstruction size
    latent_dim=32,
    activation="relu",
    rngs=rngs,
)
reconstructed = decoder(z) # Returns JAX array
CNN Decoder (Transposed Convolutions)¤
from workshop.generative_models.models.vae.decoders import CNNDecoder
decoder = CNNDecoder(
    hidden_dims=[256, 128, 64, 32],  # Reversed channel progression
    output_dim=(28, 28, 1),          # Output image shape
    latent_dim=64,
    activation="relu",
    rngs=rngs,
)
reconstructed = decoder(z) # Returns (batch, 28, 28, 1)
Conditional Decoder¤
from workshop.generative_models.models.vae.decoders import ConditionalDecoder
base_decoder = MLPDecoder(
    hidden_dims=[128, 256, 512],
    output_dim=(784,),
    latent_dim=32,
    rngs=rngs,
)
conditional_decoder = ConditionalDecoder(
    decoder=base_decoder,
    num_classes=10,
    embed_dim=128,
    rngs=rngs,
)
# Condition on class labels
reconstructed = conditional_decoder(z, condition=labels, rngs=rngs)
3. Complete VAE Models¤
Standard VAE¤
from workshop.generative_models.models.vae.base import VAE
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    rngs=rngs,
    kl_weight=1.0,   # Beta parameter (1.0 = standard VAE)
    precision=None,  # Numerical precision
)
β-VAE (Disentangled Representations)¤
from workshop.generative_models.models.vae.beta_vae import BetaVAE
beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,                # Higher beta = more disentanglement
    beta_warmup_steps=10000,         # Gradual beta annealing
    reconstruction_loss_type="mse",  # "mse" or "bce"
    rngs=rngs,
)
β-VAE with Capacity Control¤
from workshop.generative_models.models.vae.beta_vae import BetaVAEWithCapacity
capacity_vae = BetaVAEWithCapacity(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,
    beta_warmup_steps=10000,
    reconstruction_loss_type="mse",
    use_capacity_control=True,
    capacity_max=25.0,        # Maximum capacity in nats
    capacity_num_iter=25000,  # Steps to reach max capacity
    gamma=1000.0,             # Capacity loss weight
    rngs=rngs,
)
Conditional VAE¤
from workshop.generative_models.models.vae.conditional import ConditionalVAE
cvae = ConditionalVAE(
    encoder=conditional_encoder,
    decoder=conditional_decoder,
    latent_dim=32,
    condition_dim=10,         # Dimension of conditioning info
    condition_type="concat",  # Concatenation strategy
    rngs=rngs,
)
# Forward pass with condition
outputs = cvae(x, y=labels, rngs=rngs)
VQ-VAE (Discrete Latents)¤
from workshop.generative_models.models.vae.vq_vae import VQVAE
vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,    # Codebook size
    embedding_dim=64,      # Embedding dimension
    commitment_cost=0.25,  # Commitment loss weight
    rngs=rngs,
)
Training VAEs¤
Basic Training Loop¤
import optax
# Initialize model and optimizer
vae = VAE(encoder, decoder, latent_dim=32, rngs=rngs)
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=1e-3))

# Training step
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        # Forward pass
        outputs = model(batch, rngs=rngs)
        # Compute loss
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["total_loss"], losses

    # Compute gradients and update
    (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(grads)
    return losses

# Training loop
for epoch in range(num_epochs):
    for batch in train_loader:
        losses = train_step(vae, optimizer, batch)
        # Log metrics
        print(f"Recon: {losses['reconstruction_loss']:.4f}, "
              f"KL: {losses['kl_loss']:.4f}, "
              f"Total: {losses['total_loss']:.4f}")
Training β-VAE with Annealing¤
beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,
    beta_warmup_steps=10000,
    rngs=rngs,
)
optimizer = nnx.Optimizer(beta_vae, optax.adam(learning_rate=1e-3))

step = 0
for epoch in range(num_epochs):
    for batch in train_loader:
        def loss_fn(model):
            outputs = model(batch, rngs=rngs)
            # Pass the current step for beta annealing
            losses = model.loss_fn(x=batch, outputs=outputs, step=step)
            return losses["total_loss"], losses

        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(beta_vae)
        optimizer.update(grads)
        # Monitor beta value
        print(f"Step {step}, Beta: {losses['beta']:.4f}")
        step += 1
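The annealing schedule itself is internal to `BetaVAE`, but warmup of this kind is typically a linear ramp from 0 to `beta_default`. A sketch under that assumption (the hypothetical helper below is not Workshop API, and the library's exact schedule may differ):

```python
def linear_beta_schedule(step, beta_default=4.0, warmup_steps=10000):
    """Linearly ramp beta from 0 up to beta_default over warmup_steps, then hold."""
    return beta_default * min(1.0, step / warmup_steps)

print(linear_beta_schedule(0))      # 0.0
print(linear_beta_schedule(5000))   # 2.0
print(linear_beta_schedule(20000))  # 4.0
```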
Training Conditional VAE¤
cvae = ConditionalVAE(
    encoder=conditional_encoder,
    decoder=conditional_decoder,
    latent_dim=32,
    condition_dim=10,
    rngs=rngs,
)
optimizer = nnx.Optimizer(cvae, optax.adam(learning_rate=1e-3))

for epoch in range(num_epochs):
    for batch_x, batch_y in train_loader:
        def loss_fn(model):
            # Forward with conditioning
            outputs = model(batch_x, y=batch_y, rngs=rngs)
            losses = model.loss_fn(x=batch_x, outputs=outputs)
            return losses["total_loss"], losses

        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(cvae)
        optimizer.update(grads)
Training VQ-VAE¤
vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,
    commitment_cost=0.25,
    rngs=rngs,
)
optimizer = nnx.Optimizer(vqvae, optax.adam(learning_rate=1e-3))

for epoch in range(num_epochs):
    for batch in train_loader:
        def loss_fn(model):
            outputs = model(batch, rngs=rngs)
            losses = model.loss_fn(x=batch, outputs=outputs)
            return losses["total_loss"], losses

        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(vqvae)
        optimizer.update(grads)

        # VQ-VAE specific metrics
        print(f"Recon: {losses['reconstruction_loss']:.4f}, "
              f"Codebook: {losses['codebook_loss']:.4f}, "
              f"Commitment: {losses['commitment_loss']:.4f}")
Generating and Sampling¤
Generate New Samples¤
# Sample from prior distribution
n_samples = 16
samples = vae.sample(n_samples, temperature=1.0, rngs=rngs)
# Temperature controls diversity
hot_samples = vae.sample(n_samples, temperature=2.0, rngs=rngs) # More diverse
cold_samples = vae.sample(n_samples, temperature=0.5, rngs=rngs) # More focused
# Using generate() method (alias for sample)
samples = vae.generate(n_samples, temperature=1.0, rngs=rngs)
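Temperature works by scaling the standard deviation of the prior noise before it is decoded. A pure-JAX sketch of just the sampling step (Workshop's `sample()` may differ in details):

```python
import jax
import jax.numpy as jnp

def sample_prior(key, n_samples, latent_dim, temperature=1.0):
    """Draw latent codes from the standard-normal prior, scaled by temperature."""
    return temperature * jax.random.normal(key, (n_samples, latent_dim))

key = jax.random.PRNGKey(0)
z_hot = sample_prior(key, 4096, 20, temperature=2.0)   # wider spread -> more diverse decodes
z_cold = sample_prior(key, 4096, 20, temperature=0.5)  # tighter spread -> more typical decodes
```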
Conditional Generation¤
# Generate samples for specific classes
target_classes = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # One of each digit
labels = jax.nn.one_hot(target_classes, num_classes=10)
samples = cvae.sample(n_samples=10, y=labels, temperature=1.0, rngs=rngs)
Reconstruction¤
# Stochastic reconstruction
reconstructed = vae.reconstruct(x, deterministic=False, rngs=rngs)
# Deterministic reconstruction (use mean of latent distribution)
deterministic_recon = vae.reconstruct(x, deterministic=True, rngs=rngs)
Latent Space Manipulation¤
Interpolation Between Images¤
# Linear interpolation in latent space
x1 = test_images[0] # First image
x2 = test_images[1] # Second image
interpolated = vae.interpolate(
    x1=x1,
    x2=x2,
    steps=10,  # Number of interpolation steps
    rngs=rngs,
)
# interpolated.shape == (10, *input_shape)
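If you ever need this without the helper, the latent-space part is just an evenly spaced blend of two codes, each row of which would then be passed to the decoder. A minimal sketch:

```python
import jax.numpy as jnp

def lerp_latents(z1, z2, steps=10):
    """Evenly spaced linear interpolation between two latent vectors."""
    alphas = jnp.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - alphas) * z1[None, :] + alphas * z2[None, :]

z1 = jnp.zeros(32)
z2 = jnp.ones(32)
path = lerp_latents(z1, z2, steps=10)  # shape (steps, latent_dim); decode each row
```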
Latent Traversal (Disentanglement Analysis)¤
# Traverse a single latent dimension
x = test_images[0]
dim_to_traverse = 3 # Which latent dimension to vary
traversal = vae.latent_traversal(
    x=x,
    dim=dim_to_traverse,
    range_vals=(-3.0, 3.0),  # Range of values
    steps=10,                # Number of steps
    rngs=rngs,
)
# traversal.shape == (10, *input_shape)
Manual Latent Manipulation¤
# Encode image to latent space
mean, log_var = vae.encode(x, rngs=rngs)
# Manipulate specific dimensions
modified_mean = mean.copy()
modified_mean = modified_mean.at[:, 5].set(2.0) # Increase dimension 5
modified_mean = modified_mean.at[:, 10].set(-1.5) # Decrease dimension 10
# Decode modified latent
modified_image = vae.decode(modified_mean, rngs=rngs)
Evaluation and Analysis¤
Reconstruction Quality¤
# Calculate reconstruction error
test_batch = test_images[:100]
reconstructed = vae.reconstruct(test_batch, deterministic=True, rngs=rngs)
mse = jnp.mean((test_batch - reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.4f}")
ELBO (Evidence Lower Bound)¤
# Full ELBO calculation
outputs = vae(test_batch, rngs=rngs)
losses = vae.loss_fn(x=test_batch, outputs=outputs)
elbo = -(losses['reconstruction_loss'] + losses['kl_loss'])
print(f"ELBO: {elbo:.4f}")
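For reference, the quantity this computes is the evidence lower bound; with `kl_weight=1.0` the two loss terms correspond to its two components (the reconstruction loss standing in for the negative log-likelihood term):

```latex
\mathrm{ELBO}(x) = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] - \mathrm{KL}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right)
```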
Latent Space Statistics¤
# Encode test set
all_means = []
all_logvars = []
for batch in test_loader:
    mean, log_var = vae.encode(batch, rngs=rngs)
    all_means.append(mean)
    all_logvars.append(log_var)
all_means = jnp.concatenate(all_means, axis=0)
all_logvars = jnp.concatenate(all_logvars, axis=0)
# Statistics per dimension
mean_per_dim = jnp.mean(all_means, axis=0)
std_per_dim = jnp.std(all_means, axis=0)
variance_per_dim = jnp.exp(jnp.mean(all_logvars, axis=0))
print(f"Latent mean: {mean_per_dim}")
print(f"Latent std: {std_per_dim}")
print(f"Average variance: {variance_per_dim}")
Disentanglement Metrics¤
# Per-dimension KL divergence (detect posterior collapse)
def per_dim_kl(mean, log_var):
    """Calculate the KL divergence contributed by each latent dimension."""
    kl_per_dim = -0.5 * (1 + log_var - mean**2 - jnp.exp(log_var))
    return jnp.mean(kl_per_dim, axis=0)
kl_per_dimension = per_dim_kl(all_means, all_logvars)
# Dimensions with very low KL likely collapsed
inactive_dims = jnp.sum(kl_per_dimension < 0.01)
print(f"Inactive dimensions: {inactive_dims}/{vae.latent_dim}")
Hyperparameter Tuning¤
Key Hyperparameters¤
# Typical starting configuration
config = {
    # Network architecture
    "latent_dim": 64,                # 10-100 for images, 2-20 for simple data
    "hidden_dims": [512, 256, 128],  # Deeper for complex data
    "activation": "relu",            # or "gelu", "swish"
    # Training
    "learning_rate": 1e-3,  # 1e-4 to 1e-3 typical
    "batch_size": 128,      # Larger is more stable
    "num_epochs": 100,
    # VAE-specific
    "kl_weight": 1.0,              # Beta parameter
    "reconstruction_loss": "mse",  # "mse" or "bce"
}
Beta Tuning for β-VAE¤
# Grid search over beta values
beta_values = [0.5, 1.0, 2.0, 4.0, 8.0]
results = {}
for beta in beta_values:
    model = BetaVAE(
        encoder=encoder,
        decoder=decoder,
        latent_dim=32,
        beta_default=beta,
        rngs=rngs,
    )
    # Train model
    trained_model, metrics = train(model, train_loader, num_epochs=50)
    # Evaluate
    recon_error = evaluate_reconstruction(trained_model, test_loader)
    disentanglement = measure_disentanglement(trained_model, test_loader)
    results[beta] = {
        "recon_error": recon_error,
        "disentanglement": disentanglement,
    }
# Find best trade-off
print(results)
Learning Rate Scheduling¤
import optax
# Cosine decay schedule
schedule = optax.cosine_decay_schedule(
    init_value=1e-3,
    decay_steps=num_train_steps,
    alpha=0.1,  # Final learning rate = 0.1 * init_value
)
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=schedule))
Common Issues and Solutions¤
Problem 1: Posterior Collapse¤
Symptoms: KL divergence near zero, poor generation quality
Solutions:
# Solution 1: Beta annealing
beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=1.0,
    beta_warmup_steps=10000,  # Start with β=0, gradually increase
    rngs=rngs,
)

# Solution 2: Increase latent capacity gradually
capacity_vae = BetaVAEWithCapacity(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    use_capacity_control=True,
    capacity_max=25.0,
    capacity_num_iter=25000,
    gamma=1000.0,
    rngs=rngs,
)

# Solution 3: Weaker decoder (make it harder to ignore the latent code)
weak_decoder = MLPDecoder(
    hidden_dims=[128, 256],  # Smaller than encoder
    output_dim=(784,),
    latent_dim=32,
    rngs=rngs,
)
Problem 2: Blurry Reconstructions¤
Symptoms: Overly smooth outputs, lack of detail
Solutions:
# Solution 1: Use perceptual loss (for images)
from workshop.generative_models.core.losses import PerceptualLoss
perceptual_loss = PerceptualLoss(feature_layers=[3, 8, 15])

def custom_reconstruction_loss(x_true, x_pred):
    # Combine MSE with perceptual loss
    mse = jnp.mean((x_true - x_pred) ** 2)
    perceptual = perceptual_loss(x_true, x_pred)
    return mse + 0.1 * perceptual

losses = vae.loss_fn(
    x=batch,
    outputs=outputs,
    reconstruction_loss_fn=custom_reconstruction_loss,
)
# Solution 2: Lower beta (emphasize reconstruction)
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    kl_weight=0.5,  # Lower than the default 1.0
    rngs=rngs,
)

# Solution 3: Use VQ-VAE (discrete latents are often sharper)
vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,
    rngs=rngs,
)
Problem 3: Unstable Training¤
Symptoms: Loss oscillations, NaN values
Solutions:
# Solution 1: Gradient clipping
import optax
optimizer = nnx.Optimizer(
    vae,
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients
        optax.adam(learning_rate=1e-3),
    ),
)
# Solution 2: Lower learning rate
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=1e-4))
# Solution 3: Batch normalization in encoder/decoder
# (implement custom encoder/decoder with normalization)
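The built-in encoders shown above don't expose a normalization option, so this means writing a custom module (flax provides `nnx.BatchNorm` for exactly this). The computation a batch-norm layer inserted between hidden layers would perform is simple; a pure-JAX sketch, with `scale` and `bias` as the layer's learnable parameters:

```python
import jax.numpy as jnp

def batch_norm(h, scale, bias, eps=1e-5):
    """Normalize activations per feature over the batch, then rescale and shift."""
    mean = jnp.mean(h, axis=0)
    var = jnp.var(h, axis=0)
    return scale * (h - mean) / jnp.sqrt(var + eps) + bias

h = jnp.array([[1.0, 2.0], [3.0, 4.0]])
out = batch_norm(h, scale=jnp.ones(2), bias=jnp.zeros(2))
```

Normalizing activations this way keeps their scale stable across layers, which is what damps the loss oscillations.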
Problem 4: Poor Disentanglement¤
Symptoms: Latent dimensions don't correspond to interpretable factors
Solutions:
# Solution 1: Increase beta
beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,  # Higher beta encourages disentanglement
    rngs=rngs,
)

# Solution 2: More latent dimensions
# Give the model more capacity to separate factors
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=128,  # Increased from 32
    rngs=rngs,
)
# Solution 3: Total Correlation penalty (Factor VAE style)
# Implement custom TC loss
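Estimating total correlation exactly needs a discriminator (FactorVAE) or minibatch importance weighting (β-TCVAE). A cheap related proxy, in the spirit of DIP-VAE rather than FactorVAE proper, penalizes the off-diagonal entries of the covariance of the encoder means (sketch, not Workshop API):

```python
import jax.numpy as jnp

def off_diagonal_cov_penalty(means):
    """Penalize correlations between latent dimensions (DIP-VAE-style proxy)."""
    centered = means - jnp.mean(means, axis=0, keepdims=True)
    cov = centered.T @ centered / means.shape[0]
    off_diag = cov - jnp.diag(jnp.diag(cov))
    return jnp.sum(off_diag**2)

# Uncorrelated dimensions incur (near) zero penalty
means = jnp.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
penalty = off_diagonal_cov_penalty(means)
```

Add the penalty (times a weight) to `total_loss` in a custom `loss_fn` to discourage entangled dimensions.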
Advanced Techniques¤
Custom Loss Functions¤
def custom_loss_fn(x_true, x_pred):
    """Custom reconstruction loss combining multiple terms."""
    # L1 loss for sparsity
    l1_loss = jnp.mean(jnp.abs(x_true - x_pred))
    # L2 loss for overall quality
    l2_loss = jnp.mean((x_true - x_pred) ** 2)
    # Combine
    return 0.5 * l1_loss + 0.5 * l2_loss

# Use in training
losses = vae.loss_fn(
    x=batch,
    outputs=outputs,
    reconstruction_loss_fn=custom_loss_fn,
)
Multi-GPU Training¤
import jax
from flax import nnx

# Sketch of data-parallel training: split the model into a static graph
# definition and a pytree of state, then replicate the state across devices.
graphdef, state = nnx.split(vae)
replicated_state = jax.device_put_replicated(state, jax.devices())

@jax.pmap
def parallel_train_step(state, batch):
    model = nnx.merge(graphdef, state)
    def loss_fn(model):
        outputs = model(batch, rngs=rngs)
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["total_loss"]
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    return loss

# Training with multiple GPUs: stack per-device shards along a leading axis
for batch in train_loader:
    batch_shards = jnp.stack(jnp.array_split(batch, jax.device_count()))
    losses = parallel_train_step(replicated_state, batch_shards)
Checkpointing¤
import pickle
from flax import nnx

# Save model state
state = nnx.state(vae)
with open("vae_checkpoint.pkl", "wb") as f:
    pickle.dump(state, f)

# Load model state
with open("vae_checkpoint.pkl", "rb") as f:
    loaded_state = pickle.load(f)

# Restore into a new model instance
new_vae = VAE(encoder, decoder, latent_dim=32, rngs=rngs)
nnx.update(new_vae, loaded_state)
Best Practices¤
DO ✅¤
- Start simple: Begin with standard VAE before trying variants
- Monitor both losses: Track reconstruction AND KL divergence
- Use appropriate loss: MSE for continuous, BCE for binary data
- Visualize latent space: Plot 2D projections to check structure
- Test interpolation: Smooth interpolation indicates good latent space
- Check per-dim KL: Detect posterior collapse early
- Use beta annealing: Helps avoid posterior collapse
- Larger batch size: More stable training (128+ recommended)
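The "use appropriate loss" guideline above, in code form (hedged sketch with per-batch means; Workshop's internal reductions may differ):

```python
import jax.numpy as jnp

def mse_recon_loss(x, x_hat):
    """For continuous data, e.g. normalized pixel intensities."""
    return jnp.mean((x - x_hat) ** 2)

def bce_recon_loss(x, x_hat, eps=1e-7):
    """For binary data; x_hat must lie in (0, 1)."""
    x_hat = jnp.clip(x_hat, eps, 1.0 - eps)
    return -jnp.mean(x * jnp.log(x_hat) + (1.0 - x) * jnp.log(1.0 - x_hat))

x = jnp.array([0.0, 1.0, 1.0, 0.0])
x_hat = jnp.array([0.1, 0.9, 0.8, 0.2])
```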
DON'T ❌¤
- Don't ignore KL: Zero KL means model ignores latent code
- Don't use too small latent: Leads to underfitting
- Don't overtrain: Can lead to posterior collapse
- Don't skip validation: Regular evaluation prevents surprises
- Don't forget temperature: Use temperature for diverse sampling
- Don't compare different betas directly: Higher beta trades reconstruction for disentanglement
Performance Tips¤
Memory Optimization¤
# Use gradient checkpointing (rematerialization) for large models
from flax import nnx

@nnx.remat
def encoder_forward(encoder, x):
    return encoder(x)

# Control matmul precision for a speed/accuracy trade-off
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    precision=jax.lax.Precision.DEFAULT,  # DEFAULT is fastest; HIGH/HIGHEST are more accurate
    rngs=rngs,
)
Speed Optimization¤
# JIT compile the training step
@nnx.jit
def fast_train_step(model, optimizer, batch):
    def loss_fn(model):
        outputs = model(batch, rngs=rngs)
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["total_loss"], losses

    (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(grads)
    return losses

# Vectorize decoding over a batch of latent vectors
vmapped_decode = jax.vmap(lambda z: vae.decode(z, rngs=rngs))
samples = vmapped_decode(latent_vectors)
Summary¤
This guide covered:
- ✅ Creating encoders, decoders, and VAE models
- ✅ Training standard VAE, β-VAE, CVAE, and VQ-VAE
- ✅ Generating samples and manipulating latent space
- ✅ Evaluation metrics and diagnostics
- ✅ Hyperparameter tuning strategies
- ✅ Troubleshooting common issues
- ✅ Advanced techniques and optimizations
Next Steps¤
- VAE Concepts — Deep dive into theory
- VAE API Reference — Complete API documentation
- VAE MNIST Example — Hands-on tutorial
- Training Guide — Advanced training techniques
- Benchmarking — Evaluate your models