Multi-β VAE Controllable Generation Benchmark¤
Level: Intermediate | Runtime: ~2-3 minutes (CPU) / ~1 minute (GPU) | Format: Python + Jupyter
Prerequisites: Understanding of VAEs and disentangled representations | Target Audience: Researchers in controllable generation and representation learning
Overview¤
This example demonstrates how to benchmark β-VAE models for controllable generation using disentanglement and image quality metrics. Learn how the β parameter affects the trade-off between disentanglement and reconstruction quality, and how to systematically evaluate models using MIG, FID, LPIPS, and SSIM metrics.
What You'll Learn¤
- β-VAE Framework - Understand how β controls the disentanglement vs reconstruction trade-off
- MIG Score - Measure disentanglement using the Mutual Information Gap
- Image Quality Metrics - Evaluate generation quality with FID, LPIPS, and SSIM
- Quality Trade-offs - Balance disentanglement, reconstruction, and training time
- Model Comparison - Systematically compare different model configurations
- Benchmark Suite - Comprehensive evaluation across multiple metrics
Files¤
This example is available in two formats:
- Python Script: multi_beta_vae_benchmark_demo.py
- Jupyter Notebook: multi_beta_vae_benchmark_demo.ipynb
Quick Start¤
Run the Python Script¤
# Activate environment
source activate.sh
# Run the benchmark demo
python examples/generative_models/vae/multi_beta_vae_benchmark_demo.py
Run the Jupyter Notebook¤
# Activate environment
source activate.sh
# Launch Jupyter
jupyter lab examples/generative_models/vae/multi_beta_vae_benchmark_demo.ipynb
Key Concepts¤
1. β-VAE Framework¤
β-VAE modifies the standard VAE loss by weighting the KL divergence term with a factor β:
\[
\mathcal{L}_{\beta\text{-VAE}} = -\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + \beta \, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
\]
β Parameter Effects:
- β = 1: Standard VAE
- β > 1: Encourages disentanglement, may reduce reconstruction quality
- β < 1: Prioritizes reconstruction, may reduce disentanglement
# Different β values for different goals
beta_configs = {
    "reconstruction_focused": 0.5,   # Better reconstruction
    "balanced": 1.0,                 # Standard VAE
    "disentanglement_focused": 4.0,  # Better disentanglement
}
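To make the role of β concrete, here is a minimal, framework-agnostic sketch (a hypothetical helper, not part of the benchmark API) showing how a β value from the configuration above weights the KL term:
def beta_vae_loss(recon_loss, kl_divergence, beta):
    # Negative ELBO with a β-weighted KL term (see the equation above)
    return recon_loss + beta * kl_divergence

# The same per-batch reconstruction/KL terms, weighted by each configuration
for name, beta in beta_configs.items():
    print(name, beta_vae_loss(recon_loss=0.25, kl_divergence=1.3, beta=beta))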
2. Disentanglement¤
Disentangled representations have independent latent dimensions that each capture a single factor of variation:
Disentangled Latent Space:
z[0] → Controls rotation
z[1] → Controls size
z[2] → Controls color
...
Entangled Latent Space:
z[0] → Affects rotation AND size
z[1] → Affects size AND color
z[2] → Affects rotation AND color
Why Disentanglement Matters:
- Better interpretability
- More controllable generation
- Improved generalization
- Easier downstream task learning
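Controllable generation follows directly from this structure: edit one latent dimension at a time and decode. The sketch below is a simple latent traversal; the encode/decode calls are only assumed to match the model's encode-decode-generate pipeline, so treat the exact method names and shapes as illustrative:
import numpy as np

def traverse_latent(model, x_ref, dim, values):
    # Encode a reference image, then sweep one latent dimension while keeping the others fixed
    z = np.asarray(model.encode(x_ref))      # assumed encode() returning a latent code of shape (latent_dim,)
    frames = []
    for v in values:
        z_edit = z.copy()
        z_edit[dim] = v                       # change only the chosen factor (e.g. rotation)
        frames.append(model.decode(z_edit))   # assumed decode() returning an image
    return frames

# Example: sweep z[0] over [-3, 3] to visualize the factor it controls
# frames = traverse_latent(model, x_ref, dim=0, values=np.linspace(-3, 3, 7))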
3. MIG Score (Mutual Information Gap)¤
MIG measures how much each latent dimension encodes a single ground-truth factor:
\[
\mathrm{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} \Big( I\big(z_{j^{(k)}}; v_k\big) - \max_{j \neq j^{(k)}} I\big(z_j; v_k\big) \Big)
\]
where \(I\) is mutual information, \(v_k\) are the \(K\) ground-truth factors, \(H(v_k)\) is the entropy of factor \(k\), and \(j^{(k)}\) is the latent dimension with the highest MI for factor \(k\).
from workshop.benchmarks.metrics.disentanglement import MIGMetric
# Compute MIG score
mig_metric = MIGMetric(rngs=rngs)
mig_score = mig_metric.compute(
    latent_codes=z,               # (batch, latent_dim)
    ground_truth_factors=factors  # (batch, num_factors)
)
# mig_score: 0-1 (higher is better)
# >0.3: Good disentanglement
# >0.5: Excellent disentanglement
4. FID Score (Fréchet Inception Distance)¤
Measures distribution distance between real and generated images in feature space:
from workshop.benchmarks.metrics.image import FIDMetric
fid_metric = FIDMetric(rngs=rngs)
fid_score = fid_metric.compute(
    real_images=real_imgs,
    generated_images=gen_imgs
)
# fid_score: 0-∞ (lower is better)
# <30: Excellent quality
# <50: Good quality
# <100: Acceptable quality
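For reference, FID is the Fréchet distance between Gaussian fits \(\mathcal{N}(\mu_r, \Sigma_r)\) and \(\mathcal{N}(\mu_g, \Sigma_g)\) to the Inception features of the real and generated image sets:
\[
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big)
\]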
5. LPIPS (Learned Perceptual Image Patch Similarity)¤
Uses deep features to measure perceptual similarity:
from workshop.benchmarks.metrics.image import LPIPSMetric
lpips_metric = LPIPSMetric(rngs=rngs)
lpips_score = lpips_metric.compute(
    images1=original_imgs,
    images2=reconstructed_imgs
)
# lpips_score: 0-1 (lower is better)
# <0.1: Excellent perceptual quality
# <0.2: Good perceptual quality
# <0.3: Acceptable perceptual quality
6. SSIM (Structural Similarity Index)¤
Measures structural similarity between images:
from workshop.benchmarks.metrics.image import SSIMMetric
ssim_metric = SSIMMetric(rngs=rngs)
ssim_score = ssim_metric.compute(
    images1=original_imgs,
    images2=reconstructed_imgs
)
# ssim_score: 0-1 (higher is better)
# >0.9: Excellent structural similarity
# >0.8: Good structural similarity
# >0.7: Acceptable structural similarity
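For reference, SSIM compares local luminance, contrast, and structure statistics; for image windows \(x\) and \(y\) it is defined as
\[
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}
\]
where \(\mu\), \(\sigma^2\), and \(\sigma_{xy}\) are local means, variances, and covariance, and \(C_1\), \(C_2\) are small stabilizing constants.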
7. Multi-β VAE Benchmark Suite¤
Comprehensive evaluation across all metrics:
from workshop.benchmarks.suites.multi_beta_vae_suite import MultiBetaVAEBenchmarkSuite
suite = MultiBetaVAEBenchmarkSuite(
    dataset_config={
        "num_samples": 100,
        "image_size": 64,
        "include_attributes": True,  # For disentanglement metrics
    },
    benchmark_config={
        "num_samples": 50,
        "batch_size": 10,
    },
    rngs=rngs
)
# Run evaluation
results = suite.run_all(model)
# results = {
#     "multi_beta_vae_benchmark": {
#         "mig_score": 0.35,
#         "fid_score": 42.3,
#         "lpips_score": 0.18,
#         "ssim_score": 0.84,
#         "training_time_per_epoch": 7.5,
#     }
# }
Code Structure¤
The example consists of three main components:
1. MockMultiBetaVAE - Simulates a β-VAE with controllable quality levels
   - Three quality modes: low, medium, high
   - Demonstrates the encode-decode-generate pipeline
   - Shows proper RNG handling patterns
2. Benchmark Suite - Comprehensive evaluation system
   - Disentanglement metrics (MIG)
   - Image quality metrics (FID, LPIPS, SSIM)
   - Training efficiency metrics
3. Model Comparison - Systematic evaluation (see the sketch after this list)
   - Compare across quality levels
   - Analyze trade-offs
   - Performance targets
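A minimal comparison loop might look like the sketch below; it reuses the suite from section 7 and assumes the MockMultiBetaVAE constructor arguments shown in the experiments section:
# Evaluate the same suite across the three quality levels
results_by_level = {}
for level in ["low", "medium", "high"]:
    model = MockMultiBetaVAE(latent_dim=64, image_size=64, quality_level=level, rngs=rngs)
    results_by_level[level] = suite.run_all(model)

# Summarize the disentanglement vs quality trade-off
for level, res in results_by_level.items():
    metrics = res["multi_beta_vae_benchmark"]
    print(level, metrics["mig_score"], metrics["fid_score"], metrics["ssim_score"])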
Features Demonstrated¤
- ✅ β-VAE framework for controllable generation
- ✅ Disentanglement evaluation (MIG score)
- ✅ Image quality assessment (FID, LPIPS, SSIM)
- ✅ Model comparison across quality levels
- ✅ Trade-off analysis (disentanglement vs quality vs training time)
- ✅ Comprehensive benchmark suite
- ✅ Performance target assessment
Experiments to Try¤
- Adjust Latent Dimensionality
model = MockMultiBetaVAE(
    latent_dim=256,  # Try different sizes: 16, 32, 64, 128, 256
    image_size=64,
    quality_level="high",
    rngs=rngs
)
- Vary β Values
# In real β-VAE training
beta_values = [0.5, 1.0, 2.0, 4.0, 8.0]
for beta in beta_values:
    model = BetaVAE(latent_dim=64, beta=beta, rngs=rngs)
    # Train and evaluate
- Change Dataset Size
suite = MultiBetaVAEBenchmarkSuite(
    dataset_config={
        "num_samples": 500,  # More samples for stable metrics
        "image_size": 128,   # Higher resolution
    },
    # ...
)
- Custom Quality Configurations
model.quality_params["custom"] = {
"mig_score": 0.40,
"fid_score": 35.0,
"lpips_score": 0.12,
"ssim_score": 0.88,
}
model.quality_level = "custom"
Next Steps¤
- VAE Training - Train β-VAE on real datasets
- Advanced VAE - Explore FactorVAE and β-TCVAE
- Latent Space Analysis - Visualize and interpret disentangled representations
- Loss Functions - Understand VAE loss components
Troubleshooting¤
Benchmark Runs Slowly¤
Symptom: Evaluation takes too long
Solution: Reduce dataset or batch size
dataset_config = {
    "num_samples": 50,  # Smaller dataset
}
benchmark_config = {
    "batch_size": 5,  # Smaller batches
}
Models Don't Meet Targets¤
Symptom: All models fail to meet performance targets
Cause: Insufficient model capacity or training
Solution: Increase latent dimensionality or improve quality
model = MockMultiBetaVAE(
    latent_dim=128,        # Larger capacity
    quality_level="high",  # Better quality
    rngs=rngs
)
High Memory Usage¤
Symptom: Out of memory errors during evaluation
Solution: Reduce image size or batch size
dataset_config = {
    "image_size": 32,  # Smaller images
}
benchmark_config = {
    "batch_size": 4,  # Smaller batches
}
Inconsistent MIG Scores¤
Symptom: MIG scores vary significantly between runs
Cause: Too few samples for stable metric computation
Solution: Increase number of evaluation samples
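For example, raising the sample counts in both configurations (mirroring the options used in section 7) gives the mutual-information estimator more data and typically stabilizes MIG across runs:
dataset_config = {
    "num_samples": 500,   # more data points for MI estimation
}
benchmark_config = {
    "num_samples": 200,   # evaluate more latent codes
}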
Additional Resources¤
Documentation¤
- VAE Guide - Complete VAE documentation
- Disentanglement Metrics - Detailed metric explanations
- Benchmark Suite - Benchmarking guide
Related Examples¤
- VAE MNIST Tutorial - Basic VAE training
- Advanced VAE - Advanced VAE variants
- Loss Examples - VAE loss functions
- Framework Features Demo - Configuration system
Papers and Resources¤
- β-VAE: "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework" (Higgins et al., 2017)
- FactorVAE: "Disentangling by Factorising" (Kim & Mnih, 2018)
- β-TCVAE: "Isolating Sources of Disentanglement in Variational Autoencoders" (Chen et al., 2018)
- MIG: introduced in the β-TCVAE paper (Chen et al., 2018); see also "A Framework for the Quantitative Evaluation of Disentangled Representations" (Eastwood & Williams, 2018)
- FID: "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium" (Heusel et al., 2017)
- LPIPS: "The Unreasonable Effectiveness of Deep Features as a Perceptual Metric" (Zhang et al., 2018)
External Libraries¤
- disentanglement_lib: Google's disentanglement benchmark
- pytorch-fid: FID score implementation
- LPIPS: Official LPIPS implementation