VAE User Guide¤

Complete guide to building, training, and using Variational Autoencoders with Workshop.

Overview¤

This guide covers practical usage of VAEs in Workshop, from basic setup to advanced techniques. You'll learn how to:

  • Configure VAEs: Set up encoder/decoder architectures and configure hyperparameters
  • Train models: Train VAEs with proper loss functions and monitoring
  • Generate samples: Sample from the prior and manipulate latent representations
  • Tune and debug: Optimize hyperparameters and troubleshoot common issues


Quick Start¤

Basic VAE Example¤

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.vae import VAE
from workshop.generative_models.models.vae.encoders import MLPEncoder
from workshop.generative_models.models.vae.decoders import MLPDecoder

# Initialize RNGs
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Configuration
input_dim = 784  # 28x28 MNIST
latent_dim = 20
hidden_dims = [512, 256]

# Create encoder and decoder
encoder = MLPEncoder(
    hidden_dims=hidden_dims,
    latent_dim=latent_dim,
    activation="relu",
    input_dim=(input_dim,),
    rngs=rngs,
)

decoder = MLPDecoder(
    hidden_dims=list(reversed(hidden_dims)),
    output_dim=(input_dim,),
    latent_dim=latent_dim,
    activation="relu",
    rngs=rngs,
)

# Create VAE
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=latent_dim,
    rngs=rngs,
    kl_weight=1.0,
)

# Forward pass
x = jnp.ones((32, input_dim))
outputs = vae(x, rngs=rngs)

print(f"Reconstruction shape: {outputs['reconstructed'].shape}")
print(f"Latent shape: {outputs['z'].shape}")

Creating VAE Models¤

1. Encoder Architectures¤

MLP Encoder (Fully-Connected)¤

Best for tabular data and flattened images:

from workshop.generative_models.models.vae.encoders import MLPEncoder

encoder = MLPEncoder(
    hidden_dims=[512, 256, 128],  # Network depth
    latent_dim=32,                 # Latent space dimension
    activation="relu",             # Activation function
    input_dim=(784,),              # Flattened input size
    rngs=rngs,
)

# Forward pass returns (mean, log_var)
mean, log_var = encoder(x, rngs=rngs)
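
The `(mean, log_var)` pair describes a diagonal Gaussian over the latent code. As a rough sketch of how a latent sample is typically drawn from such a pair (the reparameterization trick), assuming nothing about Workshop's internals; `reparameterize` is a hypothetical helper, not a library function:

```python
import jax
import jax.numpy as jnp

def reparameterize(key, mean, log_var):
    """Draw z = mean + std * eps with eps ~ N(0, I) (hypothetical helper)."""
    std = jnp.exp(0.5 * log_var)               # log-variance -> standard deviation
    eps = jax.random.normal(key, mean.shape)   # noise independent of the parameters
    return mean + std * eps                    # differentiable w.r.t. mean and log_var

key = jax.random.PRNGKey(0)
mean = jnp.zeros((4, 32))
log_var = jnp.zeros((4, 32))
z = reparameterize(key, mean, log_var)         # shape (4, 32)
```

Because the randomness lives in `eps`, gradients can flow through `mean` and `log_var` during training.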

CNN Encoder (Convolutional)¤

Best for image data with spatial structure:

from workshop.generative_models.models.vae.encoders import CNNEncoder

encoder = CNNEncoder(
    hidden_dims=[32, 64, 128, 256],  # Channel progression
    latent_dim=64,
    activation="relu",
    input_dim=(28, 28, 1),           # (H, W, C)
    rngs=rngs,
)

# Preserves spatial information through convolutions
# Note: x must be image-shaped here, i.e. (batch, 28, 28, 1), to match input_dim
mean, log_var = encoder(x, rngs=rngs)

Conditional Encoder¤

Add class conditioning to any encoder:

from workshop.generative_models.models.vae.encoders import ConditionalEncoder

base_encoder = MLPEncoder(
    hidden_dims=[512, 256],
    latent_dim=32,
    input_dim=(784,),
    rngs=rngs,
)

conditional_encoder = ConditionalEncoder(
    encoder=base_encoder,
    num_classes=10,      # Number of classes
    embed_dim=128,       # Embedding dimension for labels
    rngs=rngs,
)

# Pass class labels as condition
mean, log_var = conditional_encoder(x, condition=labels, rngs=rngs)

2. Decoder Architectures¤

MLP Decoder¤

from workshop.generative_models.models.vae.decoders import MLPDecoder

decoder = MLPDecoder(
    hidden_dims=[128, 256, 512],  # Reversed from encoder
    output_dim=(784,),             # Reconstruction size
    latent_dim=32,
    activation="relu",
    rngs=rngs,
)

reconstructed = decoder(z)  # Returns JAX array

CNN Decoder (Transposed Convolutions)¤

from workshop.generative_models.models.vae.decoders import CNNDecoder

decoder = CNNDecoder(
    hidden_dims=[256, 128, 64, 32],  # Reversed channel progression
    output_dim=(28, 28, 1),           # Output image shape
    latent_dim=64,
    activation="relu",
    rngs=rngs,
)

reconstructed = decoder(z)  # Returns (batch, 28, 28, 1)

Conditional Decoder¤

from workshop.generative_models.models.vae.decoders import ConditionalDecoder

base_decoder = MLPDecoder(
    hidden_dims=[128, 256, 512],
    output_dim=(784,),
    latent_dim=32,
    rngs=rngs,
)

conditional_decoder = ConditionalDecoder(
    decoder=base_decoder,
    num_classes=10,
    embed_dim=128,
    rngs=rngs,
)

# Condition on class labels
reconstructed = conditional_decoder(z, condition=labels, rngs=rngs)

3. Complete VAE Models¤

Standard VAE¤

from workshop.generative_models.models.vae.base import VAE

vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    rngs=rngs,
    kl_weight=1.0,          # Beta parameter (1.0 = standard VAE)
    precision=None,          # Numerical precision
)
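
To make the role of `kl_weight` concrete, here is a minimal sketch of a weighted VAE objective with an MSE reconstruction term and the closed-form Gaussian KL. The function name `vae_loss` and the reduction (sum over features, mean over batch) are assumptions for illustration; Workshop's `loss_fn` may reduce differently.

```python
import jax.numpy as jnp

def vae_loss(x, x_recon, mean, log_var, kl_weight=1.0):
    """Hypothetical sketch: reconstruction + kl_weight * KL(q(z|x) || N(0, I))."""
    # Squared error summed over features, averaged over the batch
    recon = jnp.mean(jnp.sum((x - x_recon) ** 2, axis=-1))
    # Closed-form KL between a diagonal Gaussian and the standard normal
    kl = jnp.mean(-0.5 * jnp.sum(1 + log_var - mean**2 - jnp.exp(log_var), axis=-1))
    return {
        "reconstruction_loss": recon,
        "kl_loss": kl,
        "total_loss": recon + kl_weight * kl,
    }

x = jnp.ones((2, 3))
losses = vae_loss(x, jnp.zeros((2, 3)), jnp.ones((2, 4)), jnp.zeros((2, 4)), kl_weight=0.5)
```

With `kl_weight=1.0` the total is the negative ELBO (up to constants); other values trade reconstruction against regularization.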

β-VAE (Disentangled Representations)¤

from workshop.generative_models.models.vae.beta_vae import BetaVAE

beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,                    # Higher beta = more disentanglement
    beta_warmup_steps=10000,             # Gradual beta annealing
    reconstruction_loss_type="mse",      # "mse" or "bce"
    rngs=rngs,
)
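
One plausible reading of `beta_warmup_steps` is a linear ramp from 0 up to `beta_default`. The sketch below assumes that shape; the actual schedule used by `BetaVAE` is not shown in this guide, and `linear_beta` is a hypothetical helper.

```python
import jax.numpy as jnp

def linear_beta(step, beta_max=4.0, warmup_steps=10_000):
    """Linearly ramp beta from 0 to beta_max over warmup_steps, then hold (assumed shape)."""
    frac = jnp.clip(step / warmup_steps, 0.0, 1.0)
    return beta_max * frac

betas = [float(linear_beta(s)) for s in (0, 2_500, 10_000, 20_000)]
# betas == [0.0, 1.0, 4.0, 4.0]
```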

β-VAE with Capacity Control¤

from workshop.generative_models.models.vae.beta_vae import BetaVAEWithCapacity

capacity_vae = BetaVAEWithCapacity(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,
    beta_warmup_steps=10000,
    reconstruction_loss_type="mse",
    use_capacity_control=True,
    capacity_max=25.0,                   # Maximum capacity in nats
    capacity_num_iter=25000,             # Steps to reach max capacity
    gamma=1000.0,                        # Capacity loss weight
    rngs=rngs,
)
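
Capacity control is commonly formulated as penalizing the distance between the KL term and a target capacity C that grows during training. The sketch below assumes a linear increase of C up to `capacity_max` over `capacity_num_iter` steps and a penalty of `gamma * |KL - C|`; `capacity_penalty` is a hypothetical helper and may not match how `BetaVAEWithCapacity` combines its terms.

```python
import jax.numpy as jnp

def capacity_penalty(kl, step, capacity_max=25.0, capacity_num_iter=25_000, gamma=1000.0):
    """gamma * |KL - C(step)|, with C rising linearly to capacity_max (assumed form)."""
    capacity = capacity_max * jnp.clip(step / capacity_num_iter, 0.0, 1.0)
    return gamma * jnp.abs(kl - capacity)

# Halfway through the schedule the target capacity is 12.5 nats
penalty_mid = capacity_penalty(kl=10.0, step=12_500)
```

Early in training the target is near zero, so the model is pushed toward a low KL; as C grows, the latent code is allowed to carry more information.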

Conditional VAE¤

from workshop.generative_models.models.vae.conditional import ConditionalVAE

cvae = ConditionalVAE(
    encoder=conditional_encoder,
    decoder=conditional_decoder,
    latent_dim=32,
    condition_dim=10,                    # Dimension of conditioning info
    condition_type="concat",             # Concatenation strategy
    rngs=rngs,
)

# Forward pass with condition
outputs = cvae(x, y=labels, rngs=rngs)
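
As an illustration of what a `"concat"` conditioning strategy usually means, the sketch below appends a one-hot label to each latent code before decoding. This is an assumption about the general technique, not a description of `ConditionalVAE` internals; `concat_condition` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def concat_condition(z, labels, num_classes=10):
    """Append a one-hot label vector to each latent code (hypothetical helper)."""
    one_hot = jax.nn.one_hot(labels, num_classes)
    return jnp.concatenate([z, one_hot], axis=-1)

z = jnp.zeros((10, 32))
labels = jnp.arange(10)
decoder_input = concat_condition(z, labels)  # shape (10, 42)
```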

VQ-VAE (Discrete Latents)¤

from workshop.generative_models.models.vae.vq_vae import VQVAE

vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,                  # Codebook size
    embedding_dim=64,                    # Embedding dimension
    commitment_cost=0.25,                # Commitment loss weight
    rngs=rngs,
)
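
The codebook parameters above map onto the standard vector-quantization step: each encoder output is replaced by its nearest codebook entry, with a straight-through estimator for gradients. The following is a minimal sketch of that technique in plain JAX; `vector_quantize` is a hypothetical helper, and the `VQVAE` class may organize these terms differently.

```python
import jax
import jax.numpy as jnp

def vector_quantize(z_e, codebook, commitment_cost=0.25):
    """Nearest-neighbour quantization with straight-through gradients (sketch).

    z_e: (batch, embedding_dim), codebook: (num_embeddings, embedding_dim).
    """
    # Squared Euclidean distance from every encoder output to every code
    dists = jnp.sum((z_e[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    indices = jnp.argmin(dists, axis=-1)
    z_q = codebook[indices]

    sg = jax.lax.stop_gradient
    codebook_loss = jnp.mean((z_q - sg(z_e)) ** 2)                      # moves codes toward encoder outputs
    commitment_loss = commitment_cost * jnp.mean((sg(z_q) - z_e) ** 2)  # keeps encoder near its code

    # Straight-through: forward value is z_q, gradient flows to z_e unchanged
    z_q_st = z_e + sg(z_q - z_e)
    return z_q_st, indices, codebook_loss, commitment_loss

codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z_e = jnp.array([[0.9, 1.2], [-0.8, 1.7], [0.1, -0.1]])
z_q, indices, codebook_loss, commitment_loss = vector_quantize(z_e, codebook)
# indices -> [1, 2, 0]
```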

Training VAEs¤

Basic Training Loop¤

import optax

# Initialize model and optimizer
vae = VAE(encoder, decoder, latent_dim=32, rngs=rngs)
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=1e-3))

# Training step
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        # Forward pass
        outputs = model(batch, rngs=rngs)

        # Compute loss
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["total_loss"], losses

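    # Compute gradients and update.
    # With has_aux=True, value_and_grad returns ((loss, aux), grads).
    (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(grads)  # newer Flax releases expect optimizer.update(model, grads)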

    return losses

# Training loop
for epoch in range(num_epochs):
    for batch in train_loader:
        losses = train_step(vae, optimizer, batch)

        # Log metrics
        print(f"Recon: {losses['reconstruction_loss']:.4f}, "
              f"KL: {losses['kl_loss']:.4f}, "
              f"Total: {losses['total_loss']:.4f}")

Training β-VAE with Annealing¤

beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,
    beta_warmup_steps=10000,
    rngs=rngs,
)

optimizer = nnx.Optimizer(beta_vae, optax.adam(learning_rate=1e-3))

step = 0
for epoch in range(num_epochs):
    for batch in train_loader:
        def loss_fn(model):
            outputs = model(batch, rngs=rngs)
            # Pass current step for beta annealing
            losses = model.loss_fn(x=batch, outputs=outputs, step=step)
            return losses["total_loss"], losses

        # With has_aux=True, value_and_grad returns ((loss, aux), grads)
        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(beta_vae)
        optimizer.update(grads)

        # Monitor beta value
        print(f"Step {step}, Beta: {losses['beta']:.4f}")
        step += 1

Training Conditional VAE¤

cvae = ConditionalVAE(
    encoder=conditional_encoder,
    decoder=conditional_decoder,
    latent_dim=32,
    condition_dim=10,
    rngs=rngs,
)

optimizer = nnx.Optimizer(cvae, optax.adam(learning_rate=1e-3))

for epoch in range(num_epochs):
    for batch_x, batch_y in train_loader:
        def loss_fn(model):
            # Forward with conditioning
            outputs = model(batch_x, y=batch_y, rngs=rngs)
            losses = model.loss_fn(x=batch_x, outputs=outputs)
            return losses["total_loss"], losses

        # With has_aux=True, value_and_grad returns ((loss, aux), grads)
        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(cvae)
        optimizer.update(grads)

Training VQ-VAE¤

vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,
    commitment_cost=0.25,
    rngs=rngs,
)

optimizer = nnx.Optimizer(vqvae, optax.adam(learning_rate=1e-3))

for epoch in range(num_epochs):
    for batch in train_loader:
        def loss_fn(model):
            outputs = model(batch, rngs=rngs)
            losses = model.loss_fn(x=batch, outputs=outputs)
            return losses["total_loss"], losses

        # With has_aux=True, value_and_grad returns ((loss, aux), grads)
        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(vqvae)
        optimizer.update(grads)

        # VQ-VAE specific metrics
        print(f"Recon: {losses['reconstruction_loss']:.4f}, "
              f"Codebook: {losses['codebook_loss']:.4f}, "
              f"Commitment: {losses['commitment_loss']:.4f}")

Generating and Sampling¤

Generate New Samples¤

# Sample from prior distribution
n_samples = 16
samples = vae.sample(n_samples, temperature=1.0, rngs=rngs)

# Temperature controls diversity
hot_samples = vae.sample(n_samples, temperature=2.0, rngs=rngs)   # More diverse
cold_samples = vae.sample(n_samples, temperature=0.5, rngs=rngs)  # More focused

# Using generate() method (alias for sample)
samples = vae.generate(n_samples, temperature=1.0, rngs=rngs)
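
A common convention, assumed here, is that temperature scales the standard deviation of the latent draw before decoding. Whether Workshop's `sample` does exactly this is not stated in the guide; `sample_prior` below is a hypothetical helper showing the idea.

```python
import jax
import jax.numpy as jnp

def sample_prior(key, n_samples, latent_dim, temperature=1.0):
    """Draw latents from N(0, temperature^2 * I) (assumed meaning of temperature)."""
    return temperature * jax.random.normal(key, (n_samples, latent_dim))

key = jax.random.PRNGKey(0)
base = sample_prior(key, 16, 20, temperature=1.0)
cold = sample_prior(key, 16, 20, temperature=0.5)  # same draw, pulled toward the origin
```

Latents closer to the origin tend to decode to more typical outputs, which is why lower temperatures look more focused.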

Conditional Generation¤

# Generate samples for specific classes
target_classes = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  # One of each digit
labels = jax.nn.one_hot(target_classes, num_classes=10)

samples = cvae.sample(n_samples=10, y=labels, temperature=1.0, rngs=rngs)

Reconstruction¤

# Stochastic reconstruction
reconstructed = vae.reconstruct(x, deterministic=False, rngs=rngs)

# Deterministic reconstruction (use mean of latent distribution)
deterministic_recon = vae.reconstruct(x, deterministic=True, rngs=rngs)

Latent Space Manipulation¤

Interpolation Between Images¤

# Linear interpolation in latent space
x1 = test_images[0]  # First image
x2 = test_images[1]  # Second image

interpolated = vae.interpolate(
    x1=x1,
    x2=x2,
    steps=10,  # Number of interpolation steps
    rngs=rngs,
)

# interpolated.shape = (10, *input_shape)
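
Linear interpolation in latent space can be sketched as a convex combination of two latent codes, which would then be decoded step by step. The helper name `interpolate_latents` is hypothetical; `vae.interpolate` presumably encodes both inputs first, but its internals are not shown here.

```python
import jax.numpy as jnp

def interpolate_latents(z1, z2, steps=10):
    """Return `steps` latent codes evenly spaced on the line from z1 to z2."""
    alphas = jnp.linspace(0.0, 1.0, steps)[:, None]   # (steps, 1)
    return (1.0 - alphas) * z1[None, :] + alphas * z2[None, :]

z1 = jnp.zeros(32)
z2 = jnp.ones(32)
path = interpolate_latents(z1, z2, steps=10)          # shape (10, 32)
```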

Latent Traversal (Disentanglement Analysis)¤

# Traverse a single latent dimension
x = test_images[0]
dim_to_traverse = 3  # Which latent dimension to vary

traversal = vae.latent_traversal(
    x=x,
    dim=dim_to_traverse,
    range_vals=(-3.0, 3.0),  # Range of values
    steps=10,                 # Number of steps
    rngs=rngs,
)

# traversal.shape = (10, *input_shape)
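
A traversal can be sketched as copying one latent code several times and sweeping a single coordinate across the chosen range, leaving the others fixed. `traverse_dimension` is a hypothetical helper that only builds the latent grid; decoding each row would produce the images.

```python
import jax.numpy as jnp

def traverse_dimension(z, dim, range_vals=(-3.0, 3.0), steps=10):
    """Build (steps, latent_dim) latents that differ only in coordinate `dim`."""
    values = jnp.linspace(range_vals[0], range_vals[1], steps)
    z_grid = jnp.tile(z[None, :], (steps, 1))
    return z_grid.at[:, dim].set(values)

z = jnp.full((32,), 0.5)
z_grid = traverse_dimension(z, dim=3)   # shape (10, 32)
```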

Manual Latent Manipulation¤

# Encode image to latent space
mean, log_var = vae.encode(x, rngs=rngs)

# Manipulate specific dimensions
# (JAX arrays are immutable; .at[...].set() returns a new array, so no copy is needed)
modified_mean = mean.at[:, 5].set(2.0)             # Set dimension 5 to 2.0
modified_mean = modified_mean.at[:, 10].set(-1.5)  # Set dimension 10 to -1.5

# Decode modified latent
modified_image = vae.decode(modified_mean, rngs=rngs)

Evaluation and Analysis¤

Reconstruction Quality¤

# Calculate reconstruction error
test_batch = test_images[:100]
reconstructed = vae.reconstruct(test_batch, deterministic=True, rngs=rngs)

mse = jnp.mean((test_batch - reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.4f}")

ELBO (Evidence Lower Bound)¤

# ELBO estimate: negative of (reconstruction loss + unweighted KL)
outputs = vae(test_batch, rngs=rngs)
losses = vae.loss_fn(x=test_batch, outputs=outputs)

elbo = -(losses['reconstruction_loss'] + losses['kl_loss'])
print(f"ELBO: {elbo:.4f}")

Latent Space Statistics¤

# Encode test set
all_means = []
all_logvars = []

for batch in test_loader:
    mean, log_var = vae.encode(batch, rngs=rngs)
    all_means.append(mean)
    all_logvars.append(log_var)

all_means = jnp.concatenate(all_means, axis=0)
all_logvars = jnp.concatenate(all_logvars, axis=0)

# Statistics per dimension
mean_per_dim = jnp.mean(all_means, axis=0)
std_per_dim = jnp.std(all_means, axis=0)
variance_per_dim = jnp.mean(jnp.exp(all_logvars), axis=0)  # mean of per-sample variances

print(f"Latent mean: {mean_per_dim}")
print(f"Latent std: {std_per_dim}")
print(f"Average variance: {variance_per_dim}")

Disentanglement Metrics¤

# Per-dimension KL divergence (detect posterior collapse)
def per_dim_kl(mean, log_var):
    """Calculate KL divergence per dimension."""
    kl_per_dim = -0.5 * (1 + log_var - mean**2 - jnp.exp(log_var))
    return jnp.mean(kl_per_dim, axis=0)

kl_per_dimension = per_dim_kl(all_means, all_logvars)

# Dimensions with very low KL likely collapsed
inactive_dims = jnp.sum(kl_per_dimension < 0.01)
print(f"Inactive dimensions: {inactive_dims}/{vae.latent_dim}")

Hyperparameter Tuning¤

Key Hyperparameters¤

# Architecture
config = {
    # Network architecture
    "latent_dim": 64,              # 10-100 for images, 2-20 for simple data
    "hidden_dims": [512, 256, 128], # Deeper for complex data
    "activation": "relu",          # or "gelu", "swish"

    # Training
    "learning_rate": 1e-3,         # 1e-4 to 1e-3 typical
    "batch_size": 128,             # Larger is more stable
    "num_epochs": 100,

    # VAE-specific
    "kl_weight": 1.0,              # Beta parameter
    "reconstruction_loss": "mse",  # "mse" or "bce"
}

Beta Tuning for β-VAE¤

# Grid search over beta values
beta_values = [0.5, 1.0, 2.0, 4.0, 8.0]
results = {}

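# NOTE: train, evaluate_reconstruction, and measure_disentanglement are
# user-supplied helpers; they are not defined in this guide.
for beta in beta_values:
    # Build a fresh encoder/decoder per run so runs do not share trained weights
    model = BetaVAE(
        encoder=MLPEncoder(hidden_dims=[512, 256], latent_dim=32, input_dim=(784,), rngs=rngs),
        decoder=MLPDecoder(hidden_dims=[256, 512], output_dim=(784,), latent_dim=32, rngs=rngs),
        latent_dim=32,
        beta_default=beta,
        rngs=rngs,
    )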

    # Train model
    trained_model, metrics = train(model, train_loader, num_epochs=50)

    # Evaluate
    recon_error = evaluate_reconstruction(trained_model, test_loader)
    disentanglement = measure_disentanglement(trained_model, test_loader)

    results[beta] = {
        "recon_error": recon_error,
        "disentanglement": disentanglement,
    }

# Find best trade-off
print(results)

Learning Rate Scheduling¤

import optax

# Cosine decay schedule
schedule = optax.cosine_decay_schedule(
    init_value=1e-3,
    decay_steps=num_train_steps,
    alpha=0.1,  # Final learning rate = 0.1 * init_value
)

optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=schedule))

Common Issues and Solutions¤

Problem 1: Posterior Collapse¤

Symptoms: KL divergence near zero, poor generation quality

Solutions:

# Solution 1: Beta annealing
beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=1.0,
    beta_warmup_steps=10000,  # Start with β=0, gradually increase
    rngs=rngs,
)

# Solution 2: Increase latent capacity gradually
capacity_vae = BetaVAEWithCapacity(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    use_capacity_control=True,
    capacity_max=25.0,
    capacity_num_iter=25000,
    gamma=1000.0,
    rngs=rngs,
)

# Solution 3: Weaker decoder (make it harder to ignore latent)
weak_decoder = MLPDecoder(
    hidden_dims=[128, 256],  # Smaller than encoder
    output_dim=(784,),
    latent_dim=32,
    rngs=rngs,
)

Problem 2: Blurry Reconstructions¤

Symptoms: Overly smooth outputs, lack of detail

Solutions:

# Solution 1: Use perceptual loss (for images)
from workshop.generative_models.core.losses import PerceptualLoss

perceptual_loss = PerceptualLoss(feature_layers=[3, 8, 15])

def custom_reconstruction_loss(x_true, x_pred):
    # Combine MSE with perceptual loss
    mse = jnp.mean((x_true - x_pred) ** 2)
    perceptual = perceptual_loss(x_true, x_pred)
    return mse + 0.1 * perceptual

losses = vae.loss_fn(
    x=batch,
    outputs=outputs,
    reconstruction_loss_fn=custom_reconstruction_loss,
)

# Solution 2: Lower beta (emphasize reconstruction)
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    kl_weight=0.5,  # Lower than default 1.0
    rngs=rngs,
)

# Solution 3: Use VQ-VAE (discrete latents often sharper)
vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,
    rngs=rngs,
)

Problem 3: Unstable Training¤

Symptoms: Loss oscillations, NaN values

Solutions:

# Solution 1: Gradient clipping
import optax

optimizer = nnx.Optimizer(
    vae,
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients
        optax.adam(learning_rate=1e-3),
    )
)

# Solution 2: Lower learning rate
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=1e-4))

# Solution 3: Batch normalization in encoder/decoder
# (implement custom encoder/decoder with normalization)
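
For intuition about Solution 1, global-norm clipping rescales the whole gradient tree when its combined L2 norm exceeds the threshold. The plain-JAX sketch below illustrates the idea; it is a hypothetical re-implementation for explanation only, and `optax.clip_by_global_norm` should be used in practice.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients by min(1, max_norm / global_norm) (illustrative sketch)."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-12))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, global_norm

grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
# norm is 5.0, so every gradient is scaled by 0.2
```

Because all leaves share one scale factor, the gradient direction is preserved and only its magnitude is bounded.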

Problem 4: Poor Disentanglement¤

Symptoms: Latent dimensions don't correspond to interpretable factors

Solutions:

# Solution 1: Increase beta
beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,  # Higher beta encourages disentanglement
    rngs=rngs,
)

# Solution 2: More latent dimensions
# Give model more capacity to separate factors
# (the encoder and decoder must be rebuilt with the same latent_dim)
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=128,  # Increased from 32
    rngs=rngs,
)

# Solution 3: Total Correlation penalty (Factor VAE style)
# Implement custom TC loss

Advanced Techniques¤

Custom Loss Functions¤

def custom_loss_fn(x_true, x_pred):
    """Custom reconstruction loss combining multiple terms."""
    # L1 loss for sparsity
    l1_loss = jnp.mean(jnp.abs(x_true - x_pred))

    # L2 loss for overall quality
    l2_loss = jnp.mean((x_true - x_pred) ** 2)

    # Combine
    return 0.5 * l1_loss + 0.5 * l2_loss

# Use in training
losses = vae.loss_fn(
    x=batch,
    outputs=outputs,
    reconstruction_loss_fn=custom_loss_fn,
)

Multi-GPU Training¤

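from jax import devices, pmap

n_devices = len(devices())

# Replicate model across devices
# (check that nnx.Replicated is available in your installed Flax version)
replicated_vae = nnx.Replicated(vae)

@pmap
def parallel_train_step(model, batch):
    def loss_fn(model):
        outputs = model(batch, rngs=rngs)
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["total_loss"]

    # value_and_grad returns a (loss, grads) pair; the grads would still need to be
    # averaged across devices and applied by an optimizer
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    return loss

# Training with multiple GPUs
for batch in train_loader:
    # pmap maps over a leading device axis, so reshape the batch to
    # (n_devices, per_device_batch, ...); the batch size must divide evenly
    batch_split = batch.reshape(n_devices, -1, *batch.shape[1:])
    losses = parallel_train_step(replicated_vae, batch_split)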

Checkpointing¤

import pickle

from flax import nnx

# Save model
state = nnx.state(vae)
with open("vae_checkpoint.pkl", "wb") as f:
    pickle.dump(state, f)

# Load model
with open("vae_checkpoint.pkl", "rb") as f:
    loaded_state = pickle.load(f)

# Restore to new model instance
new_vae = VAE(encoder, decoder, latent_dim=32, rngs=rngs)
nnx.update(new_vae, loaded_state)

Best Practices¤

DO ✅¤

  • Start simple: Begin with standard VAE before trying variants
  • Monitor both losses: Track reconstruction AND KL divergence
  • Use appropriate loss: MSE for continuous, BCE for binary data
  • Visualize latent space: Plot 2D projections to check structure
  • Test interpolation: Smooth interpolation indicates good latent space
  • Check per-dim KL: Detect posterior collapse early
  • Use beta annealing: Helps avoid posterior collapse
  • Larger batch size: More stable training (128+ recommended)

DON'T ❌¤

  • Don't ignore KL: A KL term near zero means the model is ignoring the latent code
  • Don't use too small a latent space: It leads to underfitting
  • Don't overtrain: It can lead to posterior collapse
  • Don't skip validation: Regular evaluation prevents surprises
  • Don't forget temperature: Adjust the sampling temperature to control sample diversity
  • Don't compare losses across different betas directly: Higher beta trades reconstruction quality for disentanglement

Performance Tips¤

Memory Optimization¤

# Use gradient checkpointing (rematerialization) for large models.
# nnx.remat is the module-aware counterpart of jax.checkpoint.
@nnx.remat
def encoder_forward(encoder, x):
    return encoder(x)

# Use lower precision for faster training
vae = VAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    precision=jax.lax.Precision.DEFAULT,  # or HIGH, HIGHEST
    rngs=rngs,
)

Speed Optimization¤

# JIT compile training step
@nnx.jit
def fast_train_step(model, optimizer, batch):
    def loss_fn(model):
        outputs = model(batch, rngs=rngs)
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["total_loss"], losses

    # With has_aux=True, value_and_grad returns ((loss, aux), grads)
    (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(grads)
    return losses

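# Vectorize decoding over an extra leading axis.
# latent_vectors: (num_groups, batch, latent_dim); each mapped call sees (batch, latent_dim)
vmapped_decode = jax.vmap(lambda z: vae.decode(z, rngs=rngs))
samples = vmapped_decode(latent_vectors)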

Summary¤

This guide covered:

  • ✅ Creating encoders, decoders, and VAE models
  • ✅ Training standard VAE, β-VAE, CVAE, and VQ-VAE
  • ✅ Generating samples and manipulating latent space
  • ✅ Evaluation metrics and diagnostics
  • ✅ Hyperparameter tuning strategies
  • ✅ Troubleshooting common issues
  • ✅ Advanced techniques and optimizations

Next Steps¤