Multi-β VAE Controllable Generation Benchmark¤

Level: Intermediate | Runtime: ~2-3 minutes (CPU) / ~1 minute (GPU) | Format: Python + Jupyter

Prerequisites: Understanding of VAEs and disentangled representations | Target Audience: Researchers in controllable generation and representation learning

Overview¤

This example demonstrates how to benchmark β-VAE models for controllable generation using disentanglement and image quality metrics. Learn how the β parameter affects the trade-off between disentanglement and reconstruction quality, and how to systematically evaluate models using MIG, FID, LPIPS, and SSIM metrics.

What You'll Learn¤

  • β-VAE Framework: Understand how β controls the disentanglement vs reconstruction trade-off
  • MIG Score: Measure disentanglement using the Mutual Information Gap
  • Image Quality Metrics: Evaluate generation quality with FID, LPIPS, and SSIM
  • Quality Trade-offs: Balance disentanglement, reconstruction, and training time
  • Model Comparison: Systematically compare different model configurations
  • Benchmark Suite: Run a comprehensive evaluation across multiple metrics

Files¤

This example is available in two formats:

  • Python script: examples/generative_models/vae/multi_beta_vae_benchmark_demo.py
  • Jupyter notebook: examples/generative_models/vae/multi_beta_vae_benchmark_demo.ipynb

Quick Start¤

Run the Python Script¤

# Activate environment
source activate.sh

# Run the benchmark demo
python examples/generative_models/vae/multi_beta_vae_benchmark_demo.py

Run the Jupyter Notebook¤

# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/vae/multi_beta_vae_benchmark_demo.ipynb

Key Concepts¤

1. β-VAE Framework¤

β-VAE modifies the standard VAE loss by adding a weight β to the KL divergence term:

\[\mathcal{L}_{\beta\text{-VAE}} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\beta \cdot \text{KL}(q(z|x) \| p(z))}_{\text{Regularization}}\]

β Parameter Effects:

  • β = 1: Standard VAE
  • β > 1: Encourages disentanglement, may reduce reconstruction quality
  • β < 1: Prioritizes reconstruction, may reduce disentanglement
# Different β values for different goals
beta_configs = {
    "reconstruction_focused": 0.5,   # Better reconstruction
    "balanced": 1.0,                 # Standard VAE
    "disentanglement_focused": 4.0,  # Better disentanglement
}
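
For reference, here is a minimal sketch of the β-weighted objective itself, assuming a diagonal-Gaussian posterior parameterized by (mean, logvar) and a squared-error reconstruction term; the names below are illustrative, not the library API.

import jax.numpy as jnp

def beta_vae_loss(x, x_recon, mean, logvar, beta=1.0):
    # Reconstruction term: per-image sum of squared errors (a stand-in for -log p(x|z)).
    recon = jnp.sum((x - x_recon) ** 2, axis=(1, 2, 3)).mean()
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1).mean()
    # β rescales the regularization term relative to reconstruction.
    return recon + beta * kl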

2. Disentanglement¤

Disentangled representations have independent latent dimensions that each capture a single factor of variation:

Disentangled Latent Space:
z[0] → Controls rotation
z[1] → Controls size
z[2] → Controls color
...

Entangled Latent Space:
z[0] → Affects rotation AND size
z[1] → Affects size AND color
z[2] → Affects rotation AND color

Why Disentanglement Matters:

  • Better interpretability
  • More controllable generation
  • Improved generalization
  • Easier downstream task learning
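
Controllability follows directly from this structure: a latent traversal varies one dimension while holding the others fixed. A hedged sketch, assuming the model exposes encode/decode methods similar to the mock model in this example (exact signatures may differ):

import jax.numpy as jnp

def traverse_latent(model, x, dim=0, values=(-2.0, -1.0, 0.0, 1.0, 2.0)):
    z = model.encode(x)  # assumed to return a (1, latent_dim) latent code
    frames = []
    for v in values:
        z_mod = z.at[:, dim].set(v)          # vary a single latent dimension
        frames.append(model.decode(z_mod))   # in a disentangled model, only one factor changes
    return jnp.concatenate(frames, axis=0)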

3. MIG Score (Mutual Information Gap)¤

MIG measures how much each latent dimension encodes a single ground-truth factor:

\[\text{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} \left( I(z_{j^{(k)}}; v_k) - \max_{j \neq j^{(k)}} I(z_j; v_k) \right)\]

where \(I\) is mutual information, \(v_k\) are the ground-truth factors, \(H(v_k)\) is the entropy of factor \(k\), and \(j^{(k)} = \arg\max_j I(z_j; v_k)\) is the latent dimension with the highest MI for factor \(k\). The gap is measured against the second-highest MI dimension, so a high MIG means each factor is captured by essentially one latent dimension.

from workshop.benchmarks.metrics.disentanglement import MIGMetric

# Compute MIG score
mig_metric = MIGMetric(rngs=rngs)
mig_score = mig_metric.compute(
    latent_codes=z,          # (batch, latent_dim)
    ground_truth_factors=factors  # (batch, num_factors)
)
# mig_score: 0-1 (higher is better)
# >0.3: Good disentanglement
# >0.5: Excellent disentanglement
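
Under the hood, MIG reduces to the gap between the top two mutual-information values per factor. A minimal NumPy sketch of the formula above, assuming a precomputed MI matrix mi[j, k] = I(z_j; v_k) and factor entropies (the library's MIGMetric handles discretization and MI estimation internally):

import numpy as np

def mig_from_mi(mi, factor_entropy):
    # mi: (latent_dim, num_factors); factor_entropy: (num_factors,)
    sorted_mi = np.sort(mi, axis=0)[::-1]   # descending MI per factor
    gap = sorted_mi[0] - sorted_mi[1]       # highest minus second-highest MI
    return float(np.mean(gap / factor_entropy))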

4. FID Score (Fréchet Inception Distance)¤

Measures distribution distance between real and generated images in feature space:

\[\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2\sqrt{\Sigma_r \Sigma_g})\]
from workshop.benchmarks.metrics.image import FIDMetric

fid_metric = FIDMetric(rngs=rngs)
fid_score = fid_metric.compute(
    real_images=real_imgs,
    generated_images=gen_imgs
)
# fid_score: 0-∞ (lower is better)
# <30: Excellent quality
# <50: Good quality
# <100: Acceptable quality
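
As a rough sketch of the formula, assuming real and generated feature vectors have already been extracted (the library's FIDMetric wraps feature extraction internally):

import numpy as np
from scipy.linalg import sqrtm

def fid_from_features(feat_real, feat_gen):
    # Gaussian statistics of the two feature distributions.
    mu_r, mu_g = feat_real.mean(axis=0), feat_gen.mean(axis=0)
    sigma_r = np.cov(feat_real, rowvar=False)
    sigma_g = np.cov(feat_gen, rowvar=False)
    covmean = sqrtm(sigma_r @ sigma_g).real  # drop tiny imaginary parts from sqrtm
    return float(np.sum((mu_r - mu_g) ** 2) + np.trace(sigma_r + sigma_g - 2.0 * covmean))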

5. LPIPS (Learned Perceptual Image Patch Similarity)¤

Uses deep features to measure perceptual similarity:

from workshop.benchmarks.metrics.image import LPIPSMetric

lpips_metric = LPIPSMetric(rngs=rngs)
lpips_score = lpips_metric.compute(
    images1=original_imgs,
    images2=reconstructed_imgs
)
# lpips_score: 0-1 (lower is better)
# <0.1: Excellent perceptual quality
# <0.2: Good perceptual quality
# <0.3: Acceptable perceptual quality

6. SSIM (Structural Similarity Index)¤

Measures structural similarity between images:

\[\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}\]
from workshop.benchmarks.metrics.image import SSIMMetric

ssim_metric = SSIMMetric(rngs=rngs)
ssim_score = ssim_metric.compute(
    images1=original_imgs,
    images2=reconstructed_imgs
)
# ssim_score: 0-1 (higher is better)
# >0.9: Excellent structural similarity
# >0.8: Good structural similarity
# >0.7: Acceptable structural similarity
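
For intuition, here is a simplified global-statistics version of the SSIM formula, computed over whole images rather than local windows (the library metric uses the standard windowed variant):

import numpy as np

def global_ssim(x, y, c1=0.01**2, c2=0.03**2):
    # Means, variances, and covariance over the whole image.
    mu_x, mu_y = x.mean(), y.mean()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x**2 + mu_y**2 + c1) * (x.var() + y.var() + c2)
    return float(num / den)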

7. Multi-β VAE Benchmark Suite¤

Comprehensive evaluation across all metrics:

from workshop.benchmarks.suites.multi_beta_vae_suite import MultiBetaVAEBenchmarkSuite

suite = MultiBetaVAEBenchmarkSuite(
    dataset_config={
        "num_samples": 100,
        "image_size": 64,
        "include_attributes": True,  # For disentanglement metrics
    },
    benchmark_config={
        "num_samples": 50,
        "batch_size": 10,
    },
    rngs=rngs
)

# Run evaluation
results = suite.run_all(model)
# results = {
#     "multi_beta_vae_benchmark": {
#         "mig_score": 0.35,
#         "fid_score": 42.3,
#         "lpips_score": 0.18,
#         "ssim_score": 0.84,
#         "training_time_per_epoch": 7.5,
#     }
# }
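
The same suite can be reused to compare configurations; an illustrative loop over the mock model's quality levels (assuming MockMultiBetaVAE as defined in this example):

for level in ("low", "medium", "high"):
    model = MockMultiBetaVAE(latent_dim=64, image_size=64,
                             quality_level=level, rngs=rngs)
    metrics = suite.run_all(model)["multi_beta_vae_benchmark"]
    print(f"{level:>6}: MIG={metrics['mig_score']:.2f}  "
          f"FID={metrics['fid_score']:.1f}  SSIM={metrics['ssim_score']:.2f}")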

Code Structure¤

The example consists of three main components:

  1. MockMultiBetaVAE - Simulates β-VAE with controllable quality levels
     • Three quality modes: low, medium, high
     • Demonstrates encode-decode-generate pipeline
     • Shows proper RNG handling patterns

  2. Benchmark Suite - Comprehensive evaluation system
     • Disentanglement metrics (MIG)
     • Image quality metrics (FID, LPIPS, SSIM)
     • Training efficiency metrics

  3. Model Comparison - Systematic evaluation
     • Compare across quality levels
     • Analyze trade-offs
     • Performance targets

Features Demonstrated¤

  • ✅ β-VAE framework for controllable generation
  • ✅ Disentanglement evaluation (MIG score)
  • ✅ Image quality assessment (FID, LPIPS, SSIM)
  • ✅ Model comparison across quality levels
  • ✅ Trade-off analysis (disentanglement vs quality vs training time)
  • ✅ Comprehensive benchmark suite
  • ✅ Performance target assessment

Experiments to Try¤

  1. Adjust Latent Dimensionality
model = MockMultiBetaVAE(
    latent_dim=256,  # Try different sizes: 16, 32, 64, 128, 256
    image_size=64,
    quality_level="high",
    rngs=rngs
)
  2. Vary β Values
# In real β-VAE training
beta_values = [0.5, 1.0, 2.0, 4.0, 8.0]
for beta in beta_values:
    model = BetaVAE(latent_dim=64, beta=beta, rngs=rngs)
    # Train and evaluate
  3. Change Dataset Size
suite = MultiBetaVAEBenchmarkSuite(
    dataset_config={
        "num_samples": 500,  # More samples for stable metrics
        "image_size": 128,   # Higher resolution
    },
    # ...
)
  4. Custom Quality Configurations
model.quality_params["custom"] = {
    "mig_score": 0.40,
    "fid_score": 35.0,
    "lpips_score": 0.12,
    "ssim_score": 0.88,
}
model.quality_level = "custom"

Troubleshooting¤

Benchmark Runs Slowly¤

Symptom: Evaluation takes too long

Solution: Reduce dataset or batch size

dataset_config = {
    "num_samples": 50,  # Smaller dataset
}
benchmark_config = {
    "batch_size": 5,    # Smaller batches
}

Models Don't Meet Targets¤

Symptom: All models fail to meet performance targets

Cause: Insufficient model capacity or training

Solution: Increase latent dimensionality or improve quality

model = MockMultiBetaVAE(
    latent_dim=128,      # Larger capacity
    quality_level="high", # Better quality
    rngs=rngs
)

High Memory Usage¤

Symptom: Out of memory errors during evaluation

Solution: Reduce image size or batch size

dataset_config = {
    "image_size": 32,   # Smaller images
}
benchmark_config = {
    "batch_size": 4,    # Smaller batches
}

Inconsistent MIG Scores¤

Symptom: MIG scores vary significantly between runs

Cause: Too few samples for stable metric computation

Solution: Increase number of evaluation samples

benchmark_config = {
    "num_samples": 100,  # More samples for stability
}
