Variational Autoencoders (VAEs) Explained¤
- Probabilistic Framework: Learn distributions over latent codes rather than deterministic encodings
- Structured Latent Space: A continuous, smooth latent space enabling interpolation and controlled generation
- Principled Generation: Sample from a learned prior distribution to generate new, realistic data
- Differentiable Training: End-to-end optimization using the reparameterization trick
Overview¤
Variational Autoencoders (VAEs) are a class of deep generative models that combine neural networks with variational inference to learn probabilistic representations of data. Unlike standard autoencoders that learn deterministic mappings, VAEs learn probability distributions over latent representations, enabling principled data generation and interpretable latent spaces.
What makes VAEs special?
VAEs solve a fundamental challenge in generative modeling: how to learn a structured, continuous latent space that can be sampled to generate new, realistic data. By imposing a probabilistic structure through variational inference, VAEs create smooth latent spaces where:
- Interpolation works naturally - moving between two points in latent space produces meaningful intermediate outputs
- Random sampling generates valid data - sampling from the prior produces realistic new samples
- The representation is interpretable - the latent space has structure that can be understood and controlled
The Intuition: Compression and Blueprints¤
Think of VAEs like an architect creating blueprints:
- The Encoder compresses a complex house (your data) into essential instructions (a latent vector), capturing key features—number of rooms, architectural style, materials—while discarding minor details like exact nail positions.
- The Latent Space is a structured blueprint repository where similar designs cluster together: ranch houses near each other, Victorian mansions in another region, modern apartments elsewhere.
- The Decoder rebuilds houses from blueprints, reconstructing recognizable structures even though minor details differ from the original.
The critical distinction: VAEs encode to probability distributions, not single points. Each house maps to a probability cloud of similar blueprints, ensuring the latent space remains smooth and continuous. This enables generation—sample a random blueprint from the structured space, and the decoder builds a valid house, even one never seen before.
Mathematical Foundation¤
The Generative Story¤
VAEs model the data generation process as a two-step procedure:
- Sample latent code: \(z \sim p(z)\) from a simple prior distribution (typically standard normal)
- Generate data: \(x \sim p_\theta(x|z)\) using a decoder network parameterized by \(\theta\)
The goal is to learn parameters \(\theta\) (decoder) and \(\phi\) (encoder) that maximize the likelihood of observed data \(p_\theta(x)\).
graph LR
subgraph "True Generative Model"
A["Prior p(z)<br/>𝒩(0,I)"] --> B["Decoder p(x|z)"]
B --> C["Data x"]
end
subgraph "Inference Model"
C2["Data x"] --> D["Encoder q(z|x)"]
D --> E["Approximate<br/>Posterior"]
end
style A fill:#e1f5ff
style B fill:#fff3e0
style D fill:#f3e5f5
Variational Inference: Why We Need Approximation¤
The true posterior \(p_\theta(z|x)\) tells us which latent codes likely generated our data. By Bayes' rule, computing it requires:

\[
p_\theta(z|x) = \frac{p_\theta(x|z)\,p(z)}{p_\theta(x)} = \frac{p_\theta(x|z)\,p(z)}{\int p_\theta(x|z)\,p(z)\,dz}
\]

The integral in the denominator (the evidence \(p_\theta(x)\)) is intractable for high-dimensional \(z\)—we'd need to integrate over all possible latent codes. VAEs sidestep this by learning an approximate posterior \(q_\phi(z|x)\) (the encoder) that's easy to compute.
The ELBO: Evidence Lower BOund¤
The key insight of VAEs is to maximize a tractable lower bound on the log-likelihood called the Evidence Lower BOund (ELBO):

\[
\log p_\theta(x) \;\geq\; \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{\text{KL}}(q_\phi(z|x) \,\|\, p(z)) \;=\; \text{ELBO}
\]

This inequality states that the log-likelihood is always at least as large as the ELBO. The gap between them equals exactly \(D_{\text{KL}}(q_\phi(z|x) \| p_\theta(z|x))\)—when our approximate posterior perfectly matches the true posterior, there's no gap and we achieve the true likelihood.
Derivation from First Principles¤
Starting with the log-likelihood and introducing our approximate posterior:

\[
\log p_\theta(x) = \log \int p_\theta(x, z)\,dz = \log \mathbb{E}_{q_\phi(z|x)}\!\left[\frac{p_\theta(x, z)}{q_\phi(z|x)}\right]
\]

Applying Jensen's inequality (since log is concave), we get:

\[
\log p_\theta(x) \;\geq\; \mathbb{E}_{q_\phi(z|x)}\!\left[\log \frac{p_\theta(x, z)}{q_\phi(z|x)}\right]
\]

Expanding \(p_\theta(x, z) = p_\theta(x|z)p(z)\):

\[
\text{ELBO} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{\text{KL}}(q_\phi(z|x) \,\|\, p(z))
\]
Two Interpretable Terms¤
The ELBO naturally decomposes into two competing objectives:
- Reconstruction Term: \(\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]\)
  - Measures how well we can reconstruct the input from sampled latent codes
  - Encourages the model to preserve information
  - Higher is better (less negative)
- KL Divergence: \(D_{\text{KL}}(q_\phi(z|x) \| p(z))\)
  - Measures how close our learned encoding is to the prior
  - Regularizes the latent space to be smooth and structured
  - Prevents "cheating" by spreading encodings arbitrarily far apart
  - Lower is better (closer to prior)
The fundamental trade-off: The reconstruction term wants to encode all information to perfectly reconstruct. The KL term wants to compress encodings to match the simple prior. Training finds the optimal balance, creating a structured latent space that retains essential information while remaining smooth for generation.
Architecture Components¤
Encoder: Variational Posterior \(q_\phi(z|x)\)¤
The encoder is a neural network that maps an input to the parameters of a probability distribution over latent codes:

\[
q_\phi(z|x) = \mathcal{N}\!\big(z;\, \mu_\phi(x),\, \operatorname{diag}(\sigma^2_\phi(x))\big)
\]

For this diagonal Gaussian (the most common choice), the encoder outputs:
- Mean \(\mu_\phi(x) \in \mathbb{R}^d\) - the center of the latent distribution
- Log-variance \(\log \sigma^2_\phi(x) \in \mathbb{R}^d\) - the spread/uncertainty
graph LR
A["Input x<br/>(e.g., 28×28 image)"] --> B["Encoder Network<br/>(Conv layers or FC)"]
B --> C["Mean μ<br/>(d dimensions)"]
B --> D["Log-variance log σ²<br/>(d dimensions)"]
C --> E["Latent Distribution<br/>𝒩(μ, σ²I)"]
D --> E
style B fill:#f3e5f5
style E fill:#e8eaf6
Why output log-variance? Numerical stability. Variance must be positive, and learning \(\log \sigma^2\) allows the network to output any real number while ensuring \(\sigma^2 = \exp(\log \sigma^2) > 0\).
Why diagonal covariance? Full covariance matrices require \(O(d^2)\) parameters and are harder to optimize. Diagonal covariance assumes independence between dimensions, requiring only \(O(d)\) parameters while working well in practice.
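As a concrete illustration, a minimal fully connected encoder in Flax NNX might look like the sketch below; the class name, layer sizes, and `rngs` plumbing are illustrative assumptions rather than a prescribed API.

```python
import jax.numpy as jnp
from flax import nnx


class Encoder(nnx.Module):
    """Maps a flattened input to the mean and log-variance of q(z|x)."""

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int, *, rngs: nnx.Rngs):
        self.hidden = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
        self.mu_head = nnx.Linear(hidden_dim, latent_dim, rngs=rngs)
        self.log_var_head = nnx.Linear(hidden_dim, latent_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray):
        h = nnx.relu(self.hidden(x))
        # Two heads: parameters of the diagonal Gaussian posterior
        return self.mu_head(h), self.log_var_head(h)
```

Calling `Encoder(784, 400, 20, rngs=nnx.Rngs(0))` on a `(batch, 784)` array returns two `(batch, 20)` arrays holding \(\mu\) and \(\log \sigma^2\).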
Decoder: Likelihood \(p_\theta(x|z)\)¤
The decoder is a neural network that maps latent codes back to data space by parameterizing the likelihood:

\[
p_\theta(x|z) = \mathcal{N}\!\big(x;\, \mu_\theta(z),\, \sigma^2\mathbf{I}\big) \quad \text{(for continuous data)}
\]
The choice of output distribution depends on your data:
- Gaussian (continuous): For real-valued images (often simplified to MSE loss with fixed variance)
- Bernoulli (binary): For binary images or features (use sigmoid + BCE loss)
- Categorical: For discrete data (use softmax + cross-entropy)
graph LR
A["Latent z<br/>(d dimensions)"] --> B["Decoder Network<br/>(Transposed Conv or FC)"]
B --> C["Reconstruction μ(z)<br/>(same shape as input)"]
style B fill:#fff3e0
style C fill:#e8f5e9
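A matching decoder sketch for binary data (e.g. binarized MNIST), again with illustrative names and sizes; the sigmoid keeps outputs in [0, 1] so they can serve as Bernoulli means:

```python
from flax import nnx


class Decoder(nnx.Module):
    """Maps a latent code back to per-pixel Bernoulli probabilities."""

    def __init__(self, latent_dim: int, hidden_dim: int, output_dim: int, *, rngs: nnx.Rngs):
        self.hidden = nnx.Linear(latent_dim, hidden_dim, rngs=rngs)
        self.out = nnx.Linear(hidden_dim, output_dim, rngs=rngs)

    def __call__(self, z):
        h = nnx.relu(self.hidden(z))
        return nnx.sigmoid(self.out(h))  # reconstruction in [0, 1]
```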
The Reparameterization Trick¤
The Problem: Backpropagation Through Sampling¤
We need to compute gradients of \(\mathbb{E}_{q_\phi(z|x)}[f(z)]\) with respect to \(\phi\). Naively sampling \(z \sim q_\phi(z|x)\) and computing \(\nabla_\phi f(z)\) doesn't work because the sampling operation itself depends on \(\phi\) but isn't differentiable.
The Solution: Separate Randomness from Parameters¤
Instead of sampling \(z\) directly from \(q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma^2_\phi(x))\), reparameterize the sample as:

\[
z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})
\]

where \(\odot\) denotes element-wise multiplication.
graph TD
A["Input x"] --> B["Encoder"]
B --> C["μ (mean)"]
B --> D["σ (std dev)"]
E["ε ~ 𝒩(0,I)<br/>(random noise)"] --> F["z = μ + σ ⊙ ε"]
C --> F
D --> F
F --> G["Decoder"]
G --> H["Reconstruction x̂"]
style F fill:#ffebee
style E fill:#e1f5ff
Why this works:
- The randomness (\(\epsilon\)) is now independent of our parameters \(\phi\)
- Gradients flow through the deterministic operations \(\mu_\phi\) and \(\sigma_\phi\)
- The expectation becomes \(\mathbb{E}_{p(\epsilon)}[f(g_\phi(\epsilon, x))]\) where \(g_\phi\) is deterministic
- We can approximate this expectation with Monte Carlo sampling: sample \(\epsilon\), compute gradients, average
This clever trick enabled practical VAE training and has since become fundamental to probabilistic deep learning.
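In code, the trick amounts to a few lines; here is a sketch in JAX, assuming `mu` and `log_var` come from an encoder like the one sketched earlier:

```python
import jax
import jax.numpy as jnp


def reparameterize(rng_key, mu, log_var):
    # Noise is drawn from a fixed N(0, I), independent of the encoder parameters
    epsilon = jax.random.normal(rng_key, mu.shape)
    # Deterministic transform: gradients flow through mu and sigma
    sigma = jnp.exp(0.5 * log_var)
    return mu + sigma * epsilon
```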
Loss Function and Training¤
The VAE loss is derived directly from the negative ELBO:

\[
\mathcal{L}(\theta, \phi; x) = -\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] + D_{\text{KL}}(q_\phi(z|x) \,\|\, p(z))
\]
Practical Implementation¤
For a Gaussian encoder \(q_\phi(z|x) = \mathcal{N}(\mu, \sigma^2\mathbf{I})\) and standard normal prior \(p(z) = \mathcal{N}(0, \mathbf{I})\):

Reconstruction loss (assuming a Gaussian decoder with fixed variance, this reduces to the squared error):

\[
\mathcal{L}_{\text{recon}} = \|x - \hat{x}\|^2
\]

KL divergence (closed form for Gaussians):

\[
D_{\text{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\]

Total loss:

\[
\mathcal{L} = \mathcal{L}_{\text{recon}} + D_{\text{KL}}
\]
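These two terms translate directly into a few lines of JAX; a sketch assuming `x`, `x_recon`, `mu`, and `log_var` are batched arrays with the feature or latent dimension last:

```python
import jax.numpy as jnp


def vae_loss(x, x_recon, mu, log_var):
    # Reconstruction term: squared error corresponds to a fixed-variance Gaussian decoder
    recon_loss = jnp.sum((x_recon - x) ** 2, axis=-1)
    # Closed-form KL between N(mu, diag(sigma^2)) and the standard normal prior
    kl = -0.5 * jnp.sum(1.0 + log_var - mu**2 - jnp.exp(log_var), axis=-1)
    return jnp.mean(recon_loss + kl)  # average over the batch
```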
Training Algorithm¤
import jax
import jax.numpy as jnp
from flax import nnx
def loss_fn(model, batch, rng_key):
    # model: an nnx.Module bundling the encoder and decoder
    # Forward pass: encode to the parameters of q(z|x)
    mu, log_var = model.encoder(batch)
    # Reparameterization trick
    epsilon = jax.random.normal(rng_key, mu.shape)
    z = mu + jnp.exp(0.5 * log_var) * epsilon
    # Decode
    x_recon = model.decoder(z)
    # Reconstruction loss (Gaussian decoder with fixed variance -> MSE)
    recon_loss = jnp.mean((x_recon - batch) ** 2)
    # Closed-form KL divergence to the standard normal prior, averaged over the batch
    kl_loss = -0.5 * jnp.mean(jnp.sum(1 + log_var - mu**2 - jnp.exp(log_var), axis=-1))
    return recon_loss + kl_loss

for epoch in range(num_epochs):
    for batch in dataloader:
        # rng_key: a fresh PRNG key per step, e.g. from jax.random.split
        loss, grads = nnx.value_and_grad(loss_fn)(model, batch, rng_key)
        # Gradient update
        optimizer.update(grads)
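The loop above assumes `model`, `optimizer`, and the data pipeline already exist. One way to set them up and compile the update step with Flax NNX and Optax, where the `VAE` class and its arguments are illustrative assumptions:

```python
import optax
from flax import nnx

model = VAE(latent_dim=32, rngs=nnx.Rngs(0))        # hypothetical nnx.Module with encoder/decoder attributes
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # holds parameters and optimizer state


@nnx.jit  # compile the whole update step
def train_step(model, optimizer, batch, rng_key):
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, rng_key)
    optimizer.update(grads)  # in-place parameter update
    return loss
```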
Key Training Metrics to Monitor¤
- Reconstruction loss: Should decrease steadily (lower = better reconstruction)
- KL divergence: Should stabilize at a positive value (5-20 is typical for well-trained models)
- ELBO: Combination of both, the primary metric
- Per-dimension KL: Helps detect posterior collapse (all values near 0 indicates problem)
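The per-dimension KL is easy to log alongside the other metrics; a small sketch, assuming `mu` and `log_var` are `(batch, latent_dim)` arrays:

```python
import jax.numpy as jnp


def kl_per_dimension(mu, log_var):
    # KL(q(z|x) || N(0, I)) computed separately for each latent dimension,
    # averaged over the batch; dimensions stuck near 0 suggest posterior collapse
    kl = -0.5 * (1.0 + log_var - mu**2 - jnp.exp(log_var))
    return kl.mean(axis=0)  # shape: (latent_dim,)
```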
VAE Variants¤
β-VAE: Disentangled Representations¤
β-VAE modifies the objective to encourage disentanglement, where individual latent dimensions capture independent factors of variation:

\[
\mathcal{L}_{\beta} = -\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] + \beta\, D_{\text{KL}}(q_\phi(z|x) \,\|\, p(z))
\]
Effect of β:
- β = 1: Standard VAE (no additional emphasis on disentanglement)
- β > 1: Stronger regularization → encourages independent latent dimensions, improves disentanglement, but reduces reconstruction quality
- β < 1: Weaker regularization → better reconstruction, less structured latent space
graph LR
subgraph "β < 1: Reconstruction Focus"
A1[Sharp Images] --> B1[Entangled Latents]
end
subgraph "β = 1: Standard VAE"
A2[Balanced] --> B2[Some Structure]
end
subgraph "β > 1: Disentanglement Focus"
A3[Blurrier Images] --> B3[Disentangled Latents]
end
style A1 fill:#c8e6c9
style A3 fill:#ffccbc
style B3 fill:#c8e6c9
Practical β values: Start with β=1, try β=4-10 for image disentanglement tasks (dSprites, CelebA), use β=0.1-0.5 for text (to avoid posterior collapse).
Applications:
- Interpretable representations for analysis and visualization
- Fair AI by removing sensitive attributes from representations
- Controllable generation by manipulating specific latent factors
Conditional VAE (CVAE)¤
Conditional VAEs incorporate additional information \(y\) (class labels, attributes, text descriptions) to enable controlled generation. Both the encoder and the decoder receive \(y\), and the objective becomes the conditional ELBO:

\[
\mathbb{E}_{q_\phi(z|x,y)}[\log p_\theta(x|z,y)] - D_{\text{KL}}(q_\phi(z|x,y) \,\|\, p(z|y))
\]

where the conditional prior \(p(z|y)\) is often simplified to \(p(z)\).
graph TD
A["Input x"] --> E["Encoder"]
B["Condition y<br/>(e.g., class label)"] --> E
E --> C["Latent z"]
C --> D["Decoder"]
B --> D
D --> F["Reconstruction x̂"]
style B fill:#fff9c4
style E fill:#f3e5f5
style D fill:#fff3e0
How conditioning works:
- Concatenation: Append \(y\) to the input before encoding and to the sampled \(z\) before decoding (a sketch follows this list)
- Conditional Batch Normalization: Modulate batch norm parameters based on \(y\)
- FiLM (Feature-wise Linear Modulation): Scale and shift features based on \(y\)
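A sketch of the simplest option, concatenation, for a label-conditioned model; the one-hot encoding and the `encoder`/`decoder` call signatures are assumptions for illustration:

```python
import jax
import jax.numpy as jnp


def cvae_forward(encoder, decoder, x, y, num_classes, rng_key):
    y_onehot = jax.nn.one_hot(y, num_classes)
    # Condition the encoder on y by concatenating it to the input
    mu, log_var = encoder(jnp.concatenate([x, y_onehot], axis=-1))
    epsilon = jax.random.normal(rng_key, mu.shape)
    z = mu + jnp.exp(0.5 * log_var) * epsilon
    # Condition the decoder on y as well, so the class can be chosen at generation time
    x_recon = decoder(jnp.concatenate([z, y_onehot], axis=-1))
    return x_recon, mu, log_var
```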
Applications:
- Class-conditional generation: Generate specific digit classes in MNIST
- Attribute manipulation: Change hair color, age, expression in face images
- Text-to-image: Generate images matching text descriptions
Vector Quantized VAE (VQ-VAE)¤
VQ-VAE replaces continuous latent representations with discrete codes from a learned codebook. The encoder output \(z_e(x)\) is quantized to its nearest codebook entry:

\[
z_q(x) = e_k, \qquad k = \arg\min_j \|z_e(x) - e_j\|_2
\]

where \(\mathcal{C} = \{e_1, \ldots, e_K\}\) is a learned codebook of \(K\) embedding vectors.
graph TD
A["Input x"] --> B["Encoder"]
B --> C["Continuous z_e"]
C --> D["Vector<br/>Quantization"]
E["Learned<br/>Codebook"] --> D
D --> F["Discrete z_q"]
F --> G["Decoder"]
G --> H["Reconstruction x̂"]
style D fill:#ffebee
style E fill:#e1f5ff
VQ-VAE loss function:

\[
\mathcal{L} = \|x - \hat{x}\|^2 + \|\,\text{sg}[z_e(x)] - e\,\|^2 + \beta\,\|\,z_e(x) - \text{sg}[e]\,\|^2
\]

where \(\text{sg}[\cdot]\) is the stop-gradient operator, also used in the straight-through estimator sketched below. The three terms are:
- Reconstruction loss: Standard pixel-wise error
- Codebook loss: Pulls codebook embeddings toward the encoder outputs (often replaced in practice by an exponential-moving-average update)
- Commitment loss: Encourages the encoder to commit to its chosen codebook entries
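A minimal sketch of the quantization step with the straight-through gradient estimator, assuming a codebook of shape `(K, d)` and encoder outputs of shape `(batch, d)`; this is an illustration, not a library implementation:

```python
import jax
import jax.numpy as jnp


def quantize(z_e, codebook):
    # Squared distance from every encoder output to every codebook entry
    distances = jnp.sum((z_e[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    indices = jnp.argmin(distances, axis=-1)
    z_q = codebook[indices]
    # Straight-through estimator: the forward pass uses z_q,
    # the backward pass copies gradients from z_q to z_e unchanged
    z_q_st = z_e + jax.lax.stop_gradient(z_q - z_e)
    return z_q_st, z_q, indices
```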
Key advantages:
- ✅ No posterior collapse - discrete latents can't collapse to uninformative distributions
- ✅ Better codebook utilization - all codebook entries get used
- ✅ Powerful for hierarchical models - DALL-E's discrete image tokens and VQ-VAE-2 build on this foundation
- ✅ Near-GAN quality - produces sharper images than standard VAEs
Applications:
- DALL-E: Text-to-image generation using discrete visual codes
- Jukebox: High-fidelity music generation
- High-resolution image synthesis: VQ-GAN combines VQ-VAE with adversarial training
Training Dynamics and Common Challenges¤
Posterior Collapse¤
What is it?
The encoder learns to ignore the input, producing latent codes that are essentially identical to the prior \(q_\phi(z|x) \approx p(z)\). The decoder learns to generate data without using latent information, defeating the purpose of the model.
How to detect:
- KL divergence ≈ 0 across all dimensions
- Random samples from prior produce diverse outputs, but encoding-decoding produces generic/blurry results
- Reconstructions look generic and don't reflect the specific input, because the decoder generates without using \(z\)
Why does it happen?
Powerful autoregressive decoders (especially in text VAEs) can model \(p(x)\) without needing latent information. The KL term drives encodings toward the prior, and if the decoder doesn't need \(z\), the KL term wins.
graph TD
A[Strong Decoder] --> B{Can generate<br/>without z?}
B -->|Yes| C[Ignores Latent Code]
B -->|No| D[Uses Latent Code]
C --> E[Posterior Collapse]
D --> F[Healthy Training]
G[KL Annealing] --> D
H[Weak Decoder] --> D
I[Free Bits] --> D
style E fill:#ffccbc
style F fill:#c8e6c9
Solutions ranked by effectiveness:
1. KL Annealing (CRITICAL for text): Start with β=0 and gradually increase it to 1 over 20-40 epochs
   - Linear: `β = min(1.0, epoch / 40)`
   - Cyclical (BEST): cycle β from 0→1 multiple times during training
2. Free Bits: Only penalize the KL when it drops below a per-dimension threshold, `KL_constrained = max(KL_per_dim, λ)`, where λ=0.5 works well (see the sketch after this list)
3. β-VAE with β < 1: Reduce the KL penalty (β=0.1-0.5 for text)
4. Word Dropout (for text): Randomly replace 25-50% of input words with `<UNK>`
5. Weakening the Decoder: Use a simpler decoder architecture or add noise
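A minimal free-bits implementation in JAX, where `free_bits` plays the role of the per-dimension threshold λ mentioned above:

```python
import jax.numpy as jnp


def free_bits_kl(mu, log_var, free_bits=0.5):
    # Per-dimension KL to N(0, I), averaged over the batch
    kl_per_dim = -0.5 * jnp.mean(1.0 + log_var - mu**2 - jnp.exp(log_var), axis=0)
    # Dimensions whose KL is already below the threshold are not penalized further
    return jnp.sum(jnp.maximum(kl_per_dim, free_bits))
```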
Blurry Reconstructions¤
Why it happens:
MSE loss encourages the decoder to output \(\mathbb{E}[x|z]\), the average of all plausible outputs. Averaging sharp images produces blur—this is a fundamental consequence of the Gaussian likelihood assumption, not a bug.
Solutions:
- Perceptual Loss: Replace pixel-wise MSE with VGG/AlexNet feature matching
  - Significantly improves sharpness while maintaining structure
  - Used in Deep Feature Consistent VAE (DFC-VAE)
- Adversarial Training: Add a discriminator to penalize unrealistic outputs (VAE-GAN)
  - Used in Stable Diffusion's VAE component
  - Combines reconstruction, KL, and adversarial losses
- Multi-scale SSIM: Structural similarity loss instead of MSE
  - Better captures perceptual quality
- VQ-VAE: Discrete latents naturally produce sharper outputs
- Learned Variance: Let the decoder predict per-pixel variance instead of a fixed σ²
Optimization Challenges¤
NaN losses:
- Check activation functions: ensure a sigmoid on the decoder output for [0, 1] images
- Add gradient clipping: `grads = jax.tree.map(lambda g: jnp.clip(g, -1.0, 1.0), grads)`
- Keep the variance positive, e.g. parameterize the standard deviation with softplus: `sigma = nnx.softplus(sigma_raw) + 1e-6`
- Reduce the learning rate if gradients explode
Loss not decreasing:
- Verify loss signs: minimize negative ELBO
- Check data normalization: should be [0,1] or [-1,1]
- Ensure encoder-decoder dimension matching
- Monitor gradient norms: should be in range [0.1, 10]
Imbalanced loss terms:
- Reconstruction loss sums over many pixels; KL sums over few latent dimensions
- Solution: normalize by dimension count or manually weight with β
Advanced Topics¤
Hierarchical VAEs¤
Stack multiple layers of latent variables for richer, more structured representations:

\[
p_\theta(x, z_{1:L}) = p(z_L)\,\prod_{l=1}^{L-1} p_\theta(z_l \mid z_{l+1})\; p_\theta(x \mid z_1)
\]
Benefits:
- Coarse features (object class) at top levels
- Fine details (texture, color) at lower levels
- Better for complex, high-resolution data
State-of-the-art: NVAE (Vahdat & Kautz, 2020) uses 36 hierarchical groups, first VAE to successfully model 256×256 natural images.
Importance Weighted VAE (IWAE)¤
Use multiple samples to obtain a tighter bound on the log-likelihood:

\[
\mathcal{L}_K = \mathbb{E}_{z_1, \ldots, z_K \sim q_\phi(z|x)}\!\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p_\theta(x, z_k)}{q_\phi(z_k|x)}\right]
\]
With \(K\) samples, IWAE provides a strictly tighter bound than standard VAE (K=1). Typical values: K=5-50.
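The bound is typically evaluated with a log-sum-exp for numerical stability; a sketch assuming the per-sample log-weights \(\log p_\theta(x, z_k) - \log q_\phi(z_k|x)\) are already stacked into an array `log_w` of shape `(K, batch)`:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def iwae_bound(log_w):
    # log (1/K) * sum_k exp(log_w_k), computed stably per example, then batch-averaged
    K = log_w.shape[0]
    return jnp.mean(logsumexp(log_w, axis=0) - jnp.log(K))
```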
Normalizing Flow VAE¤
Replace the Gaussian posterior with a more flexible distribution by passing samples through a chain of invertible transformations:

\[
z = f_K \circ \cdots \circ f_1(z_0), \qquad \log q_\phi(z|x) = \log q_0(z_0|x) - \sum_{k=1}^{K} \log \left|\det \frac{\partial f_k}{\partial z_{k-1}}\right|
\]

where each \(f_k\) is an invertible function (Real NVP, MAF, IAF, Glow, etc.)
Benefits:
- Arbitrarily complex posterior distributions
- Better approximation of true posterior
- Improved generation quality
Trade-off: Increased computational cost during training
Latent Space Properties and Interpretation¤
Continuity and Interpolation¤
A well-trained VAE has a continuous latent space where:
- Nearby points decode to similar outputs
- Linear interpolation produces smooth transitions
- The space is "covered" - no holes where sampling produces garbage
Testing interpolation:
import jax.numpy as jnp

# Encode two images (take the posterior mean, ignore the log-variance)
z1, _ = encoder(x1)
z2, _ = encoder(x2)
# Interpolate
alphas = jnp.linspace(0, 1, num=10)
z_interp = [(1-α)*z1 + α*z2 for α in alphas]
# Decode interpolated points
x_interp = [decoder(z) for z in z_interp]
Disentanglement: Independent Factors of Variation¤
In a disentangled representation, each latent dimension captures a single, interpretable factor:
- \(z_1\): Object class (digit identity)
- \(z_2\): Rotation angle
- \(z_3\): Stroke width
- \(z_4\): Position
- ...
graph TD
subgraph "Disentangled Latent Space"
A["z₁: Rotation"] --> E["Decoder"]
B["z₂: Size"] --> E
C["z₃: Color"] --> E
D["z₄: Position"] --> E
end
E --> F["Generated Image"]
subgraph "Entangled Latent Space"
G["z₁: Mixed<br/>(rotation + size)"] --> H["Decoder"]
I["z₂: Mixed<br/>(color + position)"] --> H
end
H --> J["Generated Image"]
style E fill:#c8e6c9
style H fill:#ffccbc
Achieving disentanglement:
- Train with β-VAE (β > 1)
- Use structured datasets (dSprites, 3D shapes)
- Apply supervision or weak supervision
- Consider Factor-VAE or TC-VAE variants
Measuring disentanglement:
- MIG (Mutual Information Gap): Measures how informative each latent is about one specific factor
- SAP (Separated Attribute Predictability): Measures how predictable factors are from individual latents
- DCI (Disentanglement, Completeness, Informativeness): Three-metric framework
Comparing VAEs with Other Generative Models¤
| Aspect | VAE | GAN | Diffusion | Normalizing Flow |
|---|---|---|---|---|
| Likelihood | Lower bound (ELBO) | Implicit | Tractable | Exact |
| Training Stability | Stable | Unstable | Stable | Stable |
| Sample Quality | Good (blurry) | Excellent (sharp) | Excellent | Good |
| Sampling Speed | Fast | Fast | Slow (50-1000 steps) | Fast |
| Latent Space | Structured, smooth | None (no encoder) | Gradual diffusion | Exact bijection |
| Mode Coverage | Excellent | Poor (mode collapse) | Excellent | Excellent |
| Architecture Constraints | Flexible | Flexible | Flexible | Invertible only |
When to Use VAEs¤
VAEs Excel When:
- You need structured latent representations for downstream tasks
- Training stability is more important than peak image quality
- You want both generation and reconstruction capabilities
- Interpretability matters (anomaly detection, representation learning)
- You're working with non-image data (text, graphs, molecules)
Example Applications:
- Medical image anomaly detection via reconstruction error
- Molecular design with controllable chemical properties
- Semi-supervised learning with limited labels
- Data compression and denoising
- Recommendation systems
When to Use GANs¤
GANs Excel When:
- Image quality is paramount (super-resolution, photorealistic faces)
- You don't need an encoder (generation-only tasks)
- You're willing to handle training instability
- Mode coverage isn't critical
Limitations:
- No structured latent space for interpolation/arithmetic
- Training instability (mode collapse, oscillation)
- No reconstruction capability
When to Use Diffusion Models¤
Diffusion Models Excel When:
- You want state-of-the-art quality (DALL-E 2, Imagen, Stable Diffusion)
- Computational cost is acceptable
- You need excellent mode coverage and diversity
Limitations:
- Slow sampling (requires many iterative steps)
- Higher computational cost
- Often combined with VAEs (Latent Diffusion Models)
Practical Implementation Guide¤
Architecture Recommendations¤
For Images (MNIST, CIFAR-10, CelebA):
# Encoder (using Flax NNX)
nnx.Conv(3, 32, kernel_size=(4, 4), strides=2) → nnx.BatchNorm → nnx.relu
nnx.Conv(32, 64, kernel_size=(4, 4), strides=2) → nnx.BatchNorm → nnx.relu
nnx.Conv(64, 128, kernel_size=(4, 4), strides=2) → nnx.BatchNorm → nnx.relu
Flatten → nnx.Linear(latent_dim × 2) → Split into μ and log(σ²)
# Decoder (mirror)
nnx.Linear(latent_dim, 128×4×4) → Reshape
nnx.ConvTranspose(128, 64, kernel_size=(4, 4), strides=2) → nnx.BatchNorm → nnx.relu
nnx.ConvTranspose(64, 32, kernel_size=(4, 4), strides=2) → nnx.BatchNorm → nnx.relu
nnx.ConvTranspose(32, 3, kernel_size=(4, 4), strides=2) → nnx.sigmoid
For Text/Sequential Data:
# Encoder (using Flax NNX)
nnx.Embed(vocab_size, embed_dim) → Bidirectional nnx.LSTM/nnx.GRU (2-3 layers)
→ Take final hidden state → nnx.Linear(latent_dim × 2)
# Decoder
Repeat latent vector for each timestep
→ nnx.LSTM/nnx.GRU → nnx.Linear(vocab_size) → nnx.softmax
Hyperparameter Recommendations¤
Latent Dimensions:
- MNIST (28×28): 2-20 dimensions
- CIFAR-10 (32×32): 128-256 dimensions
- CelebA (64×64): 256-512 dimensions
- Text (sentences): 32-128 dimensions
Learning Rates:
- Simple datasets (MNIST): 1e-3 to 5e-3
- Complex images: 1e-4 to 1e-3
- Text: 5e-4 to 1e-3
- Always use Adam or AdamW optimizer
Batch Sizes:
- 64-128 works well across domains
- Larger batches improve gradient estimates but require more memory
Training Epochs:
- MNIST: 50-100 epochs
- CIFAR-10/CelebA: 100-300 epochs
- Text: 50-200 epochs
Essential Training Techniques¤
- KL Annealing (CRITICAL for text, helpful for images):
# Linear annealing
beta = min(1.0, epoch / 40)
loss = recon_loss + beta * kl_loss
# Cyclical annealing (BEST for NLP)
cycle_length = 10
t = epoch % cycle_length
if t <= 0.5 * cycle_length:
    beta = t / (0.5 * cycle_length)
else:
    beta = 1.0
- Numerical Stability:
# Parameterize the standard deviation with softplus + epsilon so it stays positive
sigma = nnx.softplus(sigma_raw) + 1e-6
log_var = 2.0 * jnp.log(sigma)  # recover the log-variance if the KL term needs it
# Gradient clipping
grads = jax.tree.map(lambda g: jnp.clip(g, -1.0, 1.0), grads)
- Loss Balancing:
# Normalize by dimensions
recon_loss = jnp.mean((x_recon - x) ** 2) # averages over pixels
kl_loss = kl_divergence.mean() # average over batch and dimensions
Summary and Key Takeaways¤
VAEs are powerful generative models that combine deep learning with variational inference to learn structured, interpretable latent representations. Understanding VAEs provides essential foundations for modern generative modeling, from Stable Diffusion's latent space to DALL-E's discrete representations.
Core Principles:
- ELBO objective balances reconstruction quality with latent space structure
- Reparameterization trick enables efficient gradient-based optimization
- Probabilistic framework creates smooth, continuous latent spaces suitable for generation
- Variational inference provides principled approximations to intractable posteriors
Key Variants:
- β-VAE trades reconstruction for disentangled, interpretable representations
- VQ-VAE uses discrete latents for improved quality and codebook learning
- Conditional VAE enables controlled generation with auxiliary information
- Hierarchical VAE captures multi-scale structure in complex data
Best Practices:
- Use KL annealing, especially for text
- Monitor both reconstruction and KL losses during training
- Consider perceptual or adversarial losses for sharper images
- Apply appropriate architecture choices for your data modality
- Start simple, add complexity as needed
Next Steps¤
- Practical usage guide with implementation examples and training workflows
- Complete API documentation for VAE, β-VAE, CVAE, and VQ-VAE classes
- Step-by-step hands-on tutorial: train a VAE on MNIST from scratch
- Explore hierarchical VAEs, VQ-VAE applications, and multi-modal learning
Additional Readings¤
Seminal Papers (Must Read)¤
Kingma, D. P., & Welling, M. (2013). "Auto-Encoding Variational Bayes"
arXiv:1312.6114
The original VAE paper introducing the framework and reparameterization trick
Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). "Stochastic Backpropagation and Approximate Inference in Deep Generative Models"
arXiv:1401.4082
Independent development of similar ideas with deep latent Gaussian models
Higgins, I., et al. (2017). "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"
ICLR 2017
Introduces β-VAE for disentangled representations
Van den Oord, A., Vinyals, O., & Kavukcuoglu, K. (2017). "Neural Discrete Representation Learning"
arXiv:1711.00937
VQ-VAE for discrete latent representations
Tutorial Papers and Books¤
Kingma, D. P., & Welling, M. (2019). "An Introduction to Variational Autoencoders"
arXiv:1906.02691
Authoritative modern tutorial by the original authors
Doersch, C. (2016). "Tutorial on Variational Autoencoders"
arXiv:1606.05908
Excellent intuitive introduction with minimal prerequisites
Ghojogh, B., et al. (2021). "Factor Analysis, Probabilistic PCA, Variational Inference, and VAE: Tutorial and Survey"
arXiv:2101.00734
Connects VAEs to classical dimensionality reduction methods
Important VAE Variants¤
Burda, Y., Grosse, R., & Salakhutdinov, R. (2015). "Importance Weighted Autoencoders"
arXiv:1509.00519
Tighter likelihood bounds using importance sampling
Burgess, C. P., et al. (2018). "Understanding Disentangling in β-VAE"
arXiv:1804.03599
Theory and practice of disentanglement in β-VAE
Sønderby, C. K., et al. (2016). "Ladder Variational Autoencoders"
arXiv:1602.02282
Hierarchical VAEs with bidirectional inference
Vahdat, A., & Kautz, J. (2020). "NVAE: A Deep Hierarchical Variational Autoencoder"
arXiv:2007.03898
State-of-the-art deep hierarchical VAE for high-resolution images
Rezende, D., & Mohamed, S. (2015). "Variational Inference with Normalizing Flows"
arXiv:1505.05770
Flexible posterior distributions using invertible transformations
Kingma, D. P., et al. (2016). "Improved Variational Inference with Inverse Autoregressive Flow"
arXiv:1606.04934
Scalable flexible posteriors for complex distributions
Tomczak, J., & Welling, M. (2017). "VAE with a VampPrior"
arXiv:1705.07120
Learned mixture-of-posteriors prior for better modeling
Makhzani, A., et al. (2015). "Adversarial Autoencoders"
arXiv:1511.05644
Combining VAEs with adversarial training
Recent Advances (2023-2025)¤
Sadat, A., et al. (2024). "LiteVAE: Lightweight and Efficient Variational Autoencoders for Latent Diffusion Models"
arXiv:2405.14477
6× parameter reduction using wavelet transforms
Li, Y., et al. (2024). "WF-VAE: Enhancing Video VAE via Wavelet-Driven Energy Flow for Latent Video Diffusion Model"
arXiv:2411.17459
Efficient video VAE with wavelet flows
Kouzelis, T., et al. (2025). "EQ-VAE: Equivariance VAE with Application to Image Generation"
arXiv:2502.09509
Geometric equivariance for improved generation
Anonymous (2025). "Posterior Collapse as a Phase Transition"
arXiv:2510.01621
Statistical physics analysis of posterior collapse
Application Papers¤
Bowman, S. R., et al. (2015). "Generating Sentences from a Continuous Space"
arXiv:1511.06349
VAEs for text generation (pioneering work)
Sohn, K., Lee, H., & Yan, X. (2015). "Learning Structured Output Representation using Deep Conditional Generative Models"
NeurIPS 2015
Conditional VAE framework
Gomez-Bombarelli, R., et al. (2018). "Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules"
ACS Central Science
VAEs for molecular design and drug discovery
Online Resources and Code¤
Lil'Log: From Autoencoder to Beta-VAE
lilianweng.github.io/posts/2018-08-12-vae
Comprehensive blog post with excellent visualizations
Jaan Altosaar's VAE Tutorial
jaan.io/what-is-variational-autoencoder-vae-tutorial
Clear mathematical derivations with intuitive explanations
Pythae: Unifying VAE Framework
github.com/clementchadebec/benchmark_VAE
Production-ready implementations with 15+ VAE variants
AntixK/PyTorch-VAE
github.com/AntixK/PyTorch-VAE
18+ VAE variants trained on CelebA for comparison
Awesome VAEs Collection
github.com/matthewvowels1/Awesome-VAEs
Curated list of ~900 papers on VAEs and disentanglement
Books and Surveys¤
Murphy, K. P. (2022). "Probabilistic Machine Learning: Advanced Topics"
Chapter on variational inference and deep generative models
Comprehensive treatment connecting theory and practice
Foster, D. (2019). "Generative Deep Learning"
O'Reilly book with practical VAE implementations
Covers VAE, GAN, and autoregressive models
Zhang, C., et al. (2021). "An Overview of Variational Autoencoders for Source Separation, Finance, and Bio-Signal Applications"
PMC8774760
Survey of VAE applications across domains
Ready to implement VAEs? Start with the VAE User Guide for practical usage, check the API Reference for complete documentation, or dive into the MNIST Tutorial for hands-on experience!