VAE MNIST Example - Variational Autoencoder Demonstration¤
Level: Beginner | Runtime: ~2-3 minutes (CPU/GPU) | Format: Python + Jupyter
This example demonstrates how to build a Variational Autoencoder (VAE) on MNIST using Workshop's modular encoder/decoder components. It showcases explicit component creation, proper RNG handling, and VAE inference (no training - this is an architecture demonstration).
Files¤
- Python Script: examples/generative_models/image/vae/vae_mnist.py
- Jupyter Notebook: examples/generative_models/image/vae/vae_mnist.ipynb
Dual-Format Implementation
This example is available in two synchronized formats:
- Python Script (.py) - For version control, IDE development, and CI/CD integration
- Jupyter Notebook (.ipynb) - For interactive learning, experimentation, and exploration
Both formats contain identical content and can be used interchangeably. Choose the format that best suits your workflow.
Quick Start¤
# Activate Workshop environment
source activate.sh
# Run the Python script
python examples/generative_models/image/vae/vae_mnist.py
# Or launch Jupyter notebook
jupyter lab examples/generative_models/image/vae/vae_mnist.ipynb
Overview¤
Learning Objectives:
- Understand VAE architecture: encoder → latent space → decoder
- Use Workshop's MLPEncoder and MLPDecoder components
- Handle RNGs properly in Flax NNX with sample streams
- Understand the reparameterization trick
- Generate samples from learned latent space
- Visualize reconstructions and generations
Prerequisites:
- Basic understanding of autoencoders and latent representations
- Familiarity with JAX and Flax NNX basics
- Understanding of variational inference concepts (ELBO, KL divergence)
- Workshop installed
Estimated Time: 5 minutes
What's Covered¤
- Modular Components: MLPEncoder and MLPDecoder for building VAEs from reusable parts
- RNG Handling: Proper random number generation with separate streams for sampling
- VAE Architecture: Encoder (x → μ, σ), reparameterization (z), decoder (z → x̂)
- Visualization: Original, reconstructed, and generated samples side-by-side
Expected Results:
- Quick demonstration (~2-3 minutes on CPU, ~30 seconds on GPU)
- Synthetic MNIST-like data (for fast execution without downloads)
- Visualization showing three rows: original, reconstructed, generated images
- Understanding of how to assemble VAE from Workshop components
Theory Background¤
Variational Autoencoder (VAE)¤
A VAE is a generative model that learns a probabilistic latent representation:
Mathematical Framework:
- Encoder: \(q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x))\) - Approximate posterior
- Decoder: \(p_\theta(x|z)\) - Likelihood of data given latent code
- Prior: \(p(z) = \mathcal{N}(0, I)\) - Standard normal prior
VAE Loss (ELBO - Evidence Lower Bound):
\[\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \text{KL}(q_\phi(z|x) \,\|\, p(z))\]
The VAE is trained by maximizing this bound (equivalently, minimizing its negative as the loss).
Where:
- Reconstruction term: \(\mathbb{E}_{q(z|x)}[\log p(x|z)] \approx -\|x - \hat{x}\|^2\) (MSE)
- KL term: \(\text{KL}(q(z|x) \| p(z))\) has closed form for Gaussians
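To make these two terms concrete, here is a minimal JAX sketch of the MSE reconstruction term and the closed-form Gaussian KL. The function name and reduction choices are illustrative only; Workshop's VAE class computes the ELBO internally.

```python
import jax.numpy as jnp

def elbo_terms(x, x_hat, mean, log_var):
    """Illustrative ELBO terms for a diagonal-Gaussian posterior and N(0, I) prior."""
    # Reconstruction term approximated by MSE, summed over pixels
    recon = jnp.sum((x - x_hat) ** 2, axis=(1, 2, 3))
    # Closed-form KL( N(mean, exp(log_var)) || N(0, I) ), summed over latent dims
    kl = -0.5 * jnp.sum(1.0 + log_var - mean**2 - jnp.exp(log_var), axis=-1)
    return jnp.mean(recon), jnp.mean(kl)
```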
Reparameterization Trick¤
To enable backpropagation through stochastic sampling, the latent code is expressed as:
\[z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)\]
This separates the stochastic component (ε) from the learnable parameters (μ, σ), allowing gradients to flow through the sampling operation.
Why Reparameterization?
Without this trick, we couldn't backpropagate through random sampling because sampling is not differentiable. By expressing z as a deterministic function of ε, μ, and σ, we can compute gradients with respect to μ and σ.
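In code, the trick amounts to a couple of lines. This is a plain-JAX sketch with illustrative names; the actual sampling happens inside Workshop's VAE forward pass.

```python
import jax
import jax.numpy as jnp

def reparameterize(key, mean, log_var):
    """Sample z = mean + sigma * eps with eps ~ N(0, I)."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * log_var) * eps
```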
Imports¤
Import Workshop's modular VAE components:
- MLPEncoder: Maps inputs to latent distribution parameters (μ, log σ²)
- MLPDecoder: Maps latent codes to reconstructions
- VAE: Base VAE class that combines encoder + decoder with ELBO loss
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from flax import nnx
from workshop.generative_models.models.vae import VAE
from workshop.generative_models.models.vae.decoders import MLPDecoder
from workshop.generative_models.models.vae.encoders import MLPEncoder
Alternative: Factory Pattern
Workshop also provides a factory pattern for model creation:
The factory pattern is useful when you want consistent model creation across different model types. This example uses explicit component creation to demonstrate the modular design, but both approaches are valid. See the "Factory Pattern Alternative" section below for details.
Data Loading¤
For this example, we create synthetic MNIST-like data. In production, you would use real MNIST from tensorflow_datasets or torchvision.
Data Format:
- Images: 28×28×1 (grayscale)
- Values: [0, 1] range (normalized)
- Shape: (batch_size, height, width, channels)
def load_mnist_data():
"""Load MNIST dataset.
In this example, we use synthetic data for quick demonstration.
Replace this with real MNIST loading for production use.
Returns:
Tuple of (train_images, test_images)
Note:
Real MNIST loading would look like:
```python
import tensorflow_datasets as tfds
ds = tfds.load('mnist', split='train', as_supervised=True)
images = ds.map(lambda x, y: x / 255.0) # Normalize to [0, 1]
```
"""
# Create synthetic MNIST-like data with proper dimensions
key = jax.random.key(42)
train_key, test_key = jax.random.split(key)
# Create synthetic data: 28×28×1 images in [0, 1] range
train_images = jax.random.uniform(train_key, (1000, 28, 28, 1))
test_images = jax.random.uniform(test_key, (100, 28, 28, 1))
return train_images, test_images
Visualization Function¤
The visualization function shows three rows to assess VAE quality:
- Original: Input images from the dataset
- Reconstructed: VAE reconstructions (tests encoder + decoder quality)
- Generated: Samples from random latent codes (tests learned prior)
def visualize_vae_results(original, reconstructed, generated, num_samples=5):
"""Visualize VAE results side-by-side.
Args:
original: Original images [batch, height, width, channels]
reconstructed: Reconstructed images (same shape as original)
generated: Generated images from random latent codes
num_samples: Number of samples to display (default: 5)
Returns:
matplotlib.figure.Figure: The created figure
Note:
All images should be in [0, 1] range for proper visualization.
Images are clipped to [0, 1] before display to handle any overshooting.
"""
fig, axes = plt.subplots(3, num_samples, figsize=(12, 7))
for i in range(num_samples):
# Row 1: Original images
axes[0, i].imshow(jnp.clip(original[i, :, :, 0], 0, 1), cmap="gray", vmin=0, vmax=1)
axes[0, i].axis("off")
if i == 0:
axes[0, i].set_ylabel("Original", fontsize=12, fontweight="bold")
# Row 2: Reconstructed images
axes[1, i].imshow(jnp.clip(reconstructed[i, :, :, 0], 0, 1), cmap="gray", vmin=0, vmax=1)
axes[1, i].axis("off")
if i == 0:
axes[1, i].set_ylabel("Reconstructed", fontsize=12, fontweight="bold")
# Row 3: Generated images (from random latent codes)
axes[2, i].imshow(jnp.clip(generated[i, :, :, 0], 0, 1), cmap="gray", vmin=0, vmax=1)
axes[2, i].axis("off")
if i == 0:
axes[2, i].set_ylabel("Generated", fontsize=12, fontweight="bold")
plt.tight_layout()
return fig
Main Pipeline¤
The main function demonstrates the VAE workflow using Workshop's modular components:
- Setup: Initialize RNG for reproducibility
- Data: Load MNIST (synthetic in this demo)
- Encoder: Create MLPEncoder to map x → (μ, log σ²)
- Decoder: Create MLPDecoder to map z → x̂
- VAE: Combine encoder + decoder into full VAE
- Forward Pass: Run reconstruction
- Generation: Sample from prior
- Visualization: Display results
Why Explicit Component Creation?
We explicitly create encoder and decoder to demonstrate:
- How to configure Workshop's components
- The modular design pattern (reusable, swappable parts)
- How components connect in the VAE
- Easy customization (swap MLP → CNN, adjust layers, etc.)
This approach gives you full control and understanding. For production workflows where you want consistency across model types, the factory pattern (see below) is also available.
def main():
"""Run the VAE MNIST example.
This function demonstrates the complete VAE pipeline using Workshop's
modular encoder/decoder components.
"""
print("=" * 80)
print("VAE MNIST Example - Using Workshop's MLPEncoder & MLPDecoder")
print("=" * 80)
Step 1: Setup RNG¤
In Flax NNX, we use nnx.Rngs to manage random number generation. We need separate streams for:
- params: Parameter initialization
- dropout: Dropout layers (if used)
- sample: Stochastic sampling in VAE (reparameterization trick)
# Step 1: Set random seed for reproducibility
seed = 42
key = jax.random.key(seed)
params_key, dropout_key, sample_key = jax.random.split(key, 3)
# Create RNG streams for different purposes
rngs = nnx.Rngs(params=params_key, dropout=dropout_key, sample=sample_key)
RNG Best Practices
Always split your RNG into separate streams for different purposes. This ensures:
- Reproducibility across runs
- Proper handling of stochastic operations
- Thread-safe random state management
- No interference between parameter init and sampling
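As a side note, nnx.Rngs can also be constructed from plain integer seeds, and each stream hands out a fresh key on every call. A brief sketch; exact behaviour may differ slightly between Flax versions.

```python
# Equivalent setup with integer seeds (Flax NNX derives PRNG keys internally)
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Each stream yields a fresh JAX PRNG key on demand
sample_key_1 = rngs.sample()
sample_key_2 = rngs.sample()  # different from sample_key_1
```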
Step 2: Load Data¤
MNIST consists of 28×28 grayscale images of handwritten digits (0-9).
- Training set: 60,000 images (we use 1,000 synthetic for demo)
- Test set: 10,000 images (we use 100 synthetic for demo)
Images are normalized to [0, 1] range for stable training.
# Step 2: Load data
print("\n📊 Loading MNIST data...")
train_images, test_images = load_mnist_data()
print(f" Train data shape: {train_images.shape}") # (1000, 28, 28, 1)
print(f" Test data shape: {test_images.shape}") # (100, 28, 28, 1)
Output:
📊 Loading MNIST data...
Train data shape: (1000, 28, 28, 1)
Test data shape: (100, 28, 28, 1)
Step 3: Create Encoder¤
Workshop's MLPEncoder maps inputs to latent distribution parameters:
- Input: x (28×28×1 = 784 features after flattening)
- Output: (mean, log_var) for latent distribution q(z|x)
Parameters:
- hidden_dims=[256, 128]: Two hidden layers with decreasing dimensions
- latent_dim=32: Dimension of latent space z
- activation="relu": ReLU activation between layers
- input_dim=(28, 28, 1): Shape of input images (auto-flattened to 784)
# Step 3: Create encoder using Workshop's MLPEncoder
print("\n🔧 Creating VAE components using Workshop APIs...")
latent_dim = 32
encoder = MLPEncoder(
hidden_dims=[256, 128], # Encoder architecture
latent_dim=latent_dim, # Latent space dimension
activation="relu", # Activation function
input_dim=(28, 28, 1), # Input image shape
rngs=rngs,
)
print(f" ✅ Encoder created: hidden_dims=[256, 128], latent_dim={latent_dim}")
Output:
🔧 Creating VAE components using Workshop APIs...
✅ Encoder created: hidden_dims=[256, 128], latent_dim=32
Step 4: Create Decoder¤
Workshop's MLPDecoder maps latent codes to reconstructions:
- Input: z (32-dimensional latent vector)
- Output: x̂ (28×28×1 reconstructed image)
Parameters:
- hidden_dims=[128, 256]: Reversed encoder dims (symmetric architecture)
- output_dim=(28, 28, 1): Shape of reconstructed images
- latent_dim=32: Dimension of latent space (must match encoder)
- activation="relu": ReLU activation (except final layer, which uses sigmoid)
Automatic Output Activation
The decoder automatically applies sigmoid activation to the output to ensure pixel values are in [0, 1] range. This matches the input image range.
# Step 4: Create decoder using Workshop's MLPDecoder
decoder = MLPDecoder(
hidden_dims=[128, 256], # Decoder architecture (reversed)
output_dim=(28, 28, 1), # Output image shape
latent_dim=latent_dim, # Latent space dimension
activation="relu", # Activation function
rngs=rngs,
)
print(f" ✅ Decoder created: hidden_dims=[128, 256], output_size={decoder.output_size}")
Output:
Step 5: Create VAE Model¤
Workshop's VAE class combines encoder + decoder with:
- Forward pass: x → encoder → (μ, log σ²) → sample z → decoder → x̂
- ELBO loss: Reconstruction loss + KL divergence
- Sampling methods: Generate from prior p(z) = N(0, I)
Parameters:
- encoder: The MLPEncoder we created above
- decoder: The MLPDecoder we created above
- latent_dim=32: Must match encoder/decoder latent dimensions
- kl_weight=1.0: Weight for KL term (β-VAE uses β≠1 for disentanglement)
# Step 5: Create VAE model with encoder and decoder
model = VAE(
encoder=encoder,
decoder=decoder,
latent_dim=latent_dim,
kl_weight=1.0, # Standard VAE (β=1), increase for β-VAE
rngs=rngs,
)
print(f" ✅ VAE model created: latent_dim={model.latent_dim}, kl_weight={model.kl_weight}")
Output:
✅ VAE model created: latent_dim=32, kl_weight=1.0
Step 6: Forward Pass (Reconstruction)¤
The forward pass demonstrates the full VAE pipeline:
- Encoding: x → encoder → (μ, log σ²)
- Reparameterization: z = μ + σ ⊙ ε, where ε ~ N(0, I)
- Decoding: z → decoder → x̂ (reconstruction)
Output Dictionary:
- reconstructed or reconstruction: Reconstructed images x̂
- mean: Latent distribution mean μ
- log_var or logvar: Latent distribution log variance log σ²
- z: Sampled latent codes (used for reconstruction)
RNG for Sampling
We pass rngs with a sample stream for the reparameterization trick's random sampling. Without this, the VAE would use deterministic (mean) latent codes.
# Step 6: Test the model with a batch
print("\n🧪 Testing model forward pass...")
test_batch = train_images[:8] # Use 8 images for testing
# Forward pass with proper RNG handling
# The 'sample' RNG stream is used for reparameterization trick
outputs = model(test_batch, rngs=rngs)
# Extract reconstructions (check both possible keys)
reconstructed = outputs.get("reconstructed")
if reconstructed is None:
reconstructed = outputs["reconstruction"]
print(f" ✅ Reconstruction shape: {reconstructed.shape}")
# Extract latent codes
latent = outputs.get("z")
if latent is None:
latent = outputs["latent"]
print(f" ✅ Latent shape: {latent.shape}")
# Show latent statistics to verify reasonable values
print(" 📊 Latent statistics:")
print(f" Mean: {jnp.mean(latent):.4f} (should be near 0)")
print(f" Std: {jnp.std(latent):.4f} (should be near 1 for standard normal)")
Output:
🧪 Testing model forward pass...
✅ Reconstruction shape: (8, 28, 28, 1)
✅ Latent shape: (8, 32)
📊 Latent statistics:
Mean: 0.0315 (should be near 0)
Std: 1.0909 (should be near 1 for standard normal)
Interpreting Latent Statistics
The latent codes should have:
- Mean near 0: The KL term pushes the posterior toward the prior N(0, I)
- Std near 1: Standard deviation close to 1 indicates good regularization
If mean or std deviate significantly, consider adjusting kl_weight or using KL annealing during training.
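If you later add a training loop, a KL annealing schedule like the one mentioned above could be sketched as follows; the linear ramp and warmup_steps value are illustrative choices, and wiring the weight into the loss depends on your training code.

```python
import jax.numpy as jnp

def kl_annealing_weight(step, warmup_steps=1_000, max_weight=1.0):
    """Linearly ramp the KL weight from 0 to max_weight over warmup_steps."""
    return max_weight * jnp.minimum(1.0, step / warmup_steps)
```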
Step 7: Generation from Prior¤
To generate new samples:
- Sample z ~ N(0, I) from the standard normal prior
- Decode: x_new = decoder(z)
This tests whether the VAE has learned a meaningful latent space.
Quality Indicators:
- Diversity: Generated samples should vary (not all identical)
- Realism: Samples should resemble training data distribution
- Smoothness: Similar z should produce similar x (interpolation works)
Synthetic Data Limitation
With synthetic random data, generations won't be realistic digits, but the shapes should match the training distribution. With real MNIST, you'd see clear digit reconstructions and realistic generated digits.
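To probe the smoothness indicator listed above, you can decode points along a line between two latent codes. This sketch assumes the decoder can be called directly on a batch of latent vectors; adjust it to the actual MLPDecoder call signature if it differs.

```python
# Hypothetical interpolation check between two encoded latent vectors
z_a, z_b = latent[0], latent[1]                  # from the forward pass in Step 6
alphas = jnp.linspace(0.0, 1.0, num=8)
z_path = jnp.stack([(1 - a) * z_a + a * z_b for a in alphas])
interpolated = decoder(z_path)                   # expected shape: (8, 28, 28, 1)
```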
# Step 7: Generate new samples from the prior
print("\n🎨 Generating new samples from prior...")
n_samples = 5
generated = model.generate(n_samples=n_samples, rngs=rngs)
print(f" ✅ Generated shape: {generated.shape}")
print(f" 📊 Generated pixels range: [{jnp.min(generated):.3f}, {jnp.max(generated):.3f}]")
Output:
🎨 Generating new samples from prior...
✅ Generated shape: (5, 28, 28, 1)
📊 Generated pixels range: [0.059, 0.943]
Step 8: Visualization¤
The visualization shows:
- Top row: Original input images
- Middle row: Reconstructions (tests encoder + decoder quality)
- Bottom row: Generated samples (tests learned prior)
What to look for:
- Reconstructions should closely match originals (good reconstruction loss)
- Generated samples should look plausible (good latent space)
- Diversity in generated samples indicates good latent space coverage
# Step 8: Visualize results
print("\n📊 Visualizing results...")
fig = visualize_vae_results(
original=test_batch[:n_samples],
reconstructed=reconstructed[:n_samples],
generated=generated[:n_samples],
)
# Step 9: Save figure
import os
output_dir = "examples_output"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "vae_mnist_results.png")
fig.savefig(output_path, dpi=150, bbox_inches="tight")
print(f" ✅ Results saved to {output_path}")
Output:
📊 Visualizing results...
✅ Results saved to examples_output/vae_mnist_results.png
Summary¤
In this example, you learned:
- ✅ VAE architecture: encoder → latent space → decoder
- ✅ Reparameterization trick: enables backpropagation through sampling
- ✅ Workshop's MLPEncoder and MLPDecoder: modular, reusable components
- ✅ Proper RNG handling: use rngs with 'sample' stream for stochastic operations
- ✅ VAE base class: handles ELBO loss computation automatically
Key Insights:
- VAEs trade reconstruction quality for smooth, structured latent spaces
- The latent dimension (32) controls representation capacity
- KL weight controls reconstruction vs. regularization tradeoff
- Modular design allows easy swapping (MLP → CNN, different layers, etc.)
Workshop APIs Used:
- MLPEncoder: Maps inputs → (μ, log σ²)
- MLPDecoder: Maps latent codes → reconstructions
- VAE: Combines encoder/decoder with ELBO loss
Factory Pattern Alternative¤
While this example uses explicit component creation to demonstrate the modular design, Workshop also provides a factory pattern for consistent model creation:
from workshop.generative_models.factory import create_model
from workshop.generative_models.core.configuration import ModelConfiguration
# Create configuration
config = ModelConfiguration(
name="vae_mnist",
model_class="workshop.generative_models.models.vae.VAE",
input_dim=(28, 28, 1),
hidden_dims=[256, 128],
output_dim=32, # Latent dimension
activation="relu",
parameters={
"latent_dim": 32,
"kl_weight": 1.0,
"encoder_type": "mlp", # Or "cnn" for CNN encoder
"decoder_type": "mlp", # Or "cnn" for CNN decoder
},
)
# Create model with factory
model = create_model(config, rngs=rngs)
When to use each approach:
| Explicit Components | Factory Pattern |
|---|---|
| ✅ Learning and tutorials | ✅ Production workflows |
| ✅ Custom architectures | ✅ Consistent model creation |
| ✅ Full control | ✅ Configuration-driven |
| ✅ Easy debugging | ✅ Model registry |
Both approaches are valid and can be used interchangeably!
Experiments to Try¤
1. CNN Architecture¤
CNNs often work better for image data than MLPs:
from workshop.generative_models.models.vae.encoders import CNNEncoder
from workshop.generative_models.models.vae.decoders import CNNDecoder
encoder = CNNEncoder(
hidden_dims=[32, 64, 128],
latent_dim=32,
activation="relu",
input_dim=(28, 28, 1),
rngs=rngs
)
decoder = CNNDecoder(
hidden_dims=[128, 64, 32],
output_dim=(28, 28, 1),
latent_dim=32,
activation="relu",
rngs=rngs
)
2. Latent Dimension Experiments¤
- Try latent_dim=16: Smaller capacity, faster training, may lose details
- Try latent_dim=64: Larger capacity, better reconstructions
- Try latent_dim=128: Very high capacity, risk of overfitting
Trade-off: Larger latent dims → better reconstruction but less structured space
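A quick, untrained comparison can be sketched by reusing the constructor arguments from the main pipeline (the loop and key handling below are illustrative):

```python
# Sketch: build VAEs at several latent sizes for side-by-side comparison
for ld in (16, 32, 64, 128):
    enc = MLPEncoder(hidden_dims=[256, 128], latent_dim=ld,
                     activation="relu", input_dim=(28, 28, 1), rngs=rngs)
    dec = MLPDecoder(hidden_dims=[128, 256], output_dim=(28, 28, 1),
                     latent_dim=ld, activation="relu", rngs=rngs)
    vae = VAE(encoder=enc, decoder=dec, latent_dim=ld, kl_weight=1.0, rngs=rngs)
    outputs = vae(test_batch, rngs=rngs)
    z = outputs.get("z", outputs.get("latent"))   # key name may vary, as in Step 6
    print(f"latent_dim={ld}: latent shape {z.shape}")
```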
3. β-VAE for Disentanglement¤
model = VAE(
encoder=encoder,
decoder=decoder,
latent_dim=32,
kl_weight=4.0, # β=4 encourages disentangled representations
rngs=rngs
)
- Higher β (4.0-10.0): More regularization, worse reconstruction, better disentanglement
- Lower β (0.1-0.5): Better reconstruction, less structured latent space
- β=1: Standard VAE
4. Architecture Variations¤
# Deeper network
encoder = MLPEncoder(
hidden_dims=[512, 256, 128],
latent_dim=32,
activation="gelu", # Try different activations
input_dim=(28, 28, 1),
rngs=rngs
)
- More layers: Higher capacity but slower training
- Different activations: GELU often works better than ReLU
- Batch normalization: Can help with deeper networks
5. Real MNIST Data¤
import tensorflow as tf
import tensorflow_datasets as tfds
# Load real MNIST
ds = tfds.load('mnist', split='train', as_supervised=True)
def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0  # Normalize to [0, 1]
    return image
ds = ds.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
Real MNIST will give much better results and realistic digit generation.
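One way to feed those batches into the JAX model is to iterate the pipeline as NumPy arrays; this sketch assumes the same model call signature as in the main pipeline above.

```python
# Iterate the tf.data pipeline as NumPy batches and run them through the VAE
for batch in tfds.as_numpy(ds.take(1)):
    batch = jnp.asarray(batch)        # shape (32, 28, 28, 1), values in [0, 1]
    outputs = model(batch, rngs=rngs)
```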
Troubleshooting¤
Encoder/decoder creation fails¤
Solution: Check that input_dim matches your data shape
Common mistake: Forgetting to pass rngs parameter
Fix: All Workshop modules require rngs for initialization
Reconstructions are very blurry¤
Solution: This is expected with MSE loss on images
Explanation: MSE averages over pixel space, causing blur
Alternatives:
- Use CNNEncoder/CNNDecoder instead of MLP
- Try perceptual loss or adversarial training
- Use VQVAE for sharper reconstructions
Generated samples look like noise¤
Solution: VAE needs training; this example only demonstrates architecture
Note: With synthetic random data, generations won't be meaningful
Fix: Train on real MNIST with a proper training loop (see training examples)
KL collapse (all latent codes become identical)¤
Solution: Reduce kl_weight to allow more latent variance
Monitoring: Check jnp.std(latent) - should be > 0.5
Fix: Use KL annealing schedule (start with kl_weight=0.1, increase gradually)
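A quick diagnostic, assuming latent holds a batch of sampled codes as in Step 6 (the 0.1 threshold is an illustrative cutoff):

```python
# Per-dimension std of the sampled latent codes; near-zero std signals collapse
per_dim_std = jnp.std(latent, axis=0)
n_collapsed = int(jnp.sum(per_dim_std < 0.1))
print(f"Collapsed latent dimensions: {n_collapsed} / {latent.shape[-1]}")
```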
Model output shape mismatch¤
Solution: Ensure encoder latent_dim matches decoder latent_dim
Check: Verify output_dim matches input shape for reconstruction
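A minimal sanity check using the names from the main pipeline (purely illustrative):

```python
# The reconstruction should have exactly the same shape as the input batch
assert reconstructed.shape == test_batch.shape, (
    f"Shape mismatch: {reconstructed.shape} vs {test_batch.shape}"
)
```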
Next Steps¤
After understanding this basic VAE demonstration, explore:
Related Examples¤
- Training VAE: See training examples for full training loops with optimizers and loss monitoring
- Advanced VAEs: β-VAE, Conditional VAE, VQ-VAE, Hierarchical VAE with disentanglement
- Disentanglement: Multi-β-VAE benchmark with MIG score evaluation and metrics
- VAE vs GAN: Compare with simple_gan.py to understand trade-offs between approaches
Documentation Resources¤
- VAE Concepts: Deep dive into VAE theory
- VAE User Guide: Advanced usage patterns
- VAE API Reference: Complete API documentation
- Training Guide: How to train VAEs from scratch
Research Papers¤
- Auto-Encoding Variational Bayes (Kingma & Welling, 2014). Original VAE paper: https://arxiv.org/abs/1312.6114
- β-VAE: Learning Basic Visual Concepts (Higgins et al., 2017). Disentanglement via the β parameter: https://openreview.net/forum?id=Sy2fzU9gl
- Understanding disentangling in β-VAE (Burgess et al., 2018). Analysis of β-VAE disentanglement: https://arxiv.org/abs/1804.03599
Congratulations! You've learned how to build VAEs with Workshop's modular components! 🎉