VAE User Guide¤

Complete guide to building, training, and using Variational Autoencoders with Artifex.

Overview¤

This guide covers practical usage of VAEs in Artifex, from basic setup to advanced techniques. You'll learn how to:

  • Configure VAEs: Set up encoder/decoder architectures and configure hyperparameters
  • Train Models: Train VAEs with proper loss functions and monitoring
  • Generate Samples: Sample from the prior and manipulate latent representations
  • Tune & Debug: Optimize hyperparameters and troubleshoot common issues


Quick Start¤

Basic VAE Example¤

import jax.numpy as jnp
from flax import nnx
from artifex.generative_models.core.configuration.network_configs import (
    EncoderConfig,
    DecoderConfig,
)
from artifex.generative_models.core.configuration.vae_config import VAEConfig
from artifex.generative_models.models.vae import VAE

# Initialize RNGs
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Configuration
latent_dim = 20

# Create encoder config
encoder_config = EncoderConfig(
    name="mlp_encoder",
    hidden_dims=(256, 128),
    latent_dim=latent_dim,
    activation="relu",
    input_shape=(28, 28, 1),  # Image shape
)

# Create decoder config
decoder_config = DecoderConfig(
    name="mlp_decoder",
    hidden_dims=(128, 256),
    output_shape=(28, 28, 1),
    latent_dim=latent_dim,
    activation="relu",
)

# Create VAE config
vae_config = VAEConfig(
    name="basic_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",  # Use MLP encoder/decoder
    kl_weight=1.0,
)

# Create model
vae = VAE(config=vae_config, rngs=rngs)

# Forward pass
x = jnp.ones((32, 28, 28, 1))
outputs = vae(x)  # Model uses internal RNGs

# Get outputs (handle different key names)
reconstructed = outputs.get("reconstructed")
if reconstructed is None:
    reconstructed = outputs["reconstruction"]
latent = outputs.get("z")
if latent is None:
    latent = outputs["latent"]

print(f"Reconstruction shape: {reconstructed.shape}")
print(f"Latent shape: {latent.shape}")

Creating VAE Models¤

Artifex uses a config-based API where you define configurations first, then create models from them. This provides type safety, validation, and easy serialization.
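
To make the benefits of a config-first design concrete, here is a minimal sketch using only the standard library. `ToyEncoderConfig` is a hypothetical stand-in written for illustration; it is not Artifex's `EncoderConfig`, and the validation rules are assumptions.

```python
from dataclasses import dataclass, asdict
import json


@dataclass(frozen=True)
class ToyEncoderConfig:
    """Hypothetical stand-in for a config class; not Artifex's EncoderConfig."""

    name: str
    hidden_dims: tuple[int, ...]
    latent_dim: int
    activation: str = "relu"

    def __post_init__(self):
        # Validation runs once, when the config is constructed.
        if self.latent_dim <= 0:
            raise ValueError("latent_dim must be positive")
        if not self.hidden_dims or any(d <= 0 for d in self.hidden_dims):
            raise ValueError("hidden_dims must be a non-empty tuple of positive ints")


config = ToyEncoderConfig(name="mlp_encoder", hidden_dims=(256, 128), latent_dim=20)

# Serialization: the config becomes a plain dict that round-trips through JSON.
payload = json.dumps(asdict(config))
raw = json.loads(payload)
raw["hidden_dims"] = tuple(raw["hidden_dims"])  # JSON turns tuples into lists
restored = ToyEncoderConfig(**raw)
```

Because the dataclass is frozen, a validated config cannot be mutated afterwards, which is one reason tuples rather than lists are a natural fit for fields like `hidden_dims`.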

1. Encoder Configurations¤

MLP Encoder (Fully-Connected)¤

Best for tabular data and flattened images:

from artifex.generative_models.core.configuration.network_configs import EncoderConfig
from artifex.generative_models.models.vae.encoders import MLPEncoder

# Define encoder configuration
encoder_config = EncoderConfig(
    name="mlp_encoder",
    hidden_dims=(512, 256, 128),  # Network depth (use tuples)
    latent_dim=32,                 # Latent space dimension
    activation="relu",             # Activation function
    input_shape=(784,),            # Flattened input size
)

# Create encoder from config
encoder = MLPEncoder(config=encoder_config, rngs=rngs)

# Forward pass returns (mean, log_var); x must match input_shape, here (batch, 784)
mean, log_var = encoder(x)
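
The `(mean, log_var)` pair parameterizes a diagonal Gaussian over the latent code. A minimal sketch of how a latent sample is typically drawn from it (the reparameterization trick) follows; the `reparameterize` helper is illustrative and is not taken from Artifex.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, log_var):
    """Draw z ~ N(mean, exp(log_var)) while keeping gradients w.r.t. mean and log_var."""
    std = jnp.exp(0.5 * log_var)              # log-variance -> standard deviation
    eps = jax.random.normal(key, mean.shape)  # noise that does not depend on the parameters
    return mean + std * eps


key = jax.random.key(0)
mean = jnp.zeros((4, 32))
log_var = jnp.full((4, 32), -2.0)
z = reparameterize(key, mean, log_var)
```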

CNN Encoder (Convolutional)¤

Best for image data with spatial structure:

from artifex.generative_models.core.configuration.network_configs import EncoderConfig
from artifex.generative_models.models.vae.encoders import CNNEncoder

encoder_config = EncoderConfig(
    name="cnn_encoder",
    hidden_dims=(32, 64, 128, 256),  # Channel progression
    latent_dim=64,
    activation="relu",
    input_shape=(28, 28, 1),          # (H, W, C)
)

encoder = CNNEncoder(config=encoder_config, rngs=rngs)

# Preserves spatial information through convolutions
mean, log_var = encoder(x)

Conditional Encoder¤

Add class conditioning using ConditionalVAEConfig:

from artifex.generative_models.core.configuration.network_configs import EncoderConfig
from artifex.generative_models.core.configuration.vae_config import ConditionalVAEConfig
from artifex.generative_models.models.vae import ConditionalVAE

# For conditional VAEs, use the ConditionalVAEConfig which handles
# label embedding and conditioning automatically

2. Decoder Configurations¤

MLP Decoder¤

from artifex.generative_models.core.configuration.network_configs import DecoderConfig
from artifex.generative_models.models.vae.decoders import MLPDecoder

decoder_config = DecoderConfig(
    name="mlp_decoder",
    hidden_dims=(128, 256, 512),  # Reversed from encoder
    output_shape=(784,),          # Reconstruction size
    latent_dim=32,
    activation="relu",
)

decoder = MLPDecoder(config=decoder_config, rngs=rngs)

reconstructed = decoder(z)  # Returns JAX array

CNN Decoder (Transposed Convolutions)¤

from artifex.generative_models.core.configuration.network_configs import DecoderConfig
from artifex.generative_models.models.vae.decoders import CNNDecoder

decoder_config = DecoderConfig(
    name="cnn_decoder",
    hidden_dims=(256, 128, 64, 32),  # Reversed channel progression
    output_shape=(28, 28, 1),         # Output image shape
    latent_dim=64,
    activation="relu",
)

decoder = CNNDecoder(config=decoder_config, rngs=rngs)

reconstructed = decoder(z)  # Returns (batch, 28, 28, 1)

Conditional Decoder¤

For conditional generation, use the full ConditionalVAE model which handles conditioning in both encoder and decoder:

# See "Conditional VAE" section below for the full config-based approach

3. Complete VAE Models¤

Standard VAE¤

from artifex.generative_models.core.configuration.network_configs import (
    EncoderConfig,
    DecoderConfig,
)
from artifex.generative_models.core.configuration.vae_config import VAEConfig
from artifex.generative_models.models.vae import VAE

# Define configurations
encoder_config = EncoderConfig(
    name="encoder",
    hidden_dims=(256, 128),
    latent_dim=32,
    activation="relu",
    input_shape=(28, 28, 1),
)

decoder_config = DecoderConfig(
    name="decoder",
    hidden_dims=(128, 256),
    output_shape=(28, 28, 1),
    latent_dim=32,
    activation="relu",
)

vae_config = VAEConfig(
    name="standard_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",     # "dense" for MLP, "cnn" for convolutional
    kl_weight=1.0,            # Beta parameter (1.0 = standard VAE)
)

# Create model
vae = VAE(config=vae_config, rngs=rngs)

β-VAE (Disentangled Representations)¤

from artifex.generative_models.core.configuration.vae_config import BetaVAEConfig
from artifex.generative_models.models.vae import BetaVAE

beta_config = BetaVAEConfig(
    name="beta_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    beta_default=4.0,                    # Higher beta = more disentanglement
    beta_warmup_steps=10000,             # Gradual beta annealing
    reconstruction_loss_type="mse",      # "mse" or "bce"
)

beta_vae = BetaVAE(config=beta_config, rngs=rngs)
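
To illustrate what a warmup over `beta_warmup_steps` could look like, here is a sketch that assumes a linear ramp from 0 to the target beta. The schedule Artifex actually applies is not shown in this guide, so treat `linear_beta` as a hypothetical helper.

```python
def linear_beta(step, beta_target=4.0, warmup_steps=10_000):
    """Ramp beta linearly from 0 to beta_target over warmup_steps, then hold it."""
    if warmup_steps <= 0:
        return beta_target
    fraction = min(step / warmup_steps, 1.0)
    return beta_target * fraction


# Beta at the start, halfway through warmup, and after warmup has finished.
schedule_preview = [linear_beta(s) for s in (0, 5_000, 20_000)]
```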

Conditional VAE¤

import jax
import jax.numpy as jnp

from artifex.generative_models.core.configuration.vae_config import ConditionalVAEConfig
from artifex.generative_models.models.vae import ConditionalVAE

cvae_config = ConditionalVAEConfig(
    name="conditional_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    num_classes=10,                      # Number of classes for conditioning
    kl_weight=1.0,
)

cvae = ConditionalVAE(config=cvae_config, rngs=rngs)

# Forward pass with condition (one-hot encoded labels, one per example in x)
class_ids = jnp.arange(x.shape[0]) % 10  # Placeholder labels; replace with real ones
labels = jax.nn.one_hot(class_ids, num_classes=10)
outputs = cvae(x, y=labels)

VQ-VAE (Discrete Latents)¤

from artifex.generative_models.core.configuration.vae_config import VQVAEConfig
from artifex.generative_models.models.vae import VQVAE

vqvae_config = VQVAEConfig(
    name="vqvae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    num_embeddings=512,                  # Codebook size
    embedding_dim=64,                    # Embedding dimension
    commitment_cost=0.25,                # Commitment loss weight
)

vqvae = VQVAE(config=vqvae_config, rngs=rngs)
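
The roles of `num_embeddings`, `embedding_dim`, and `commitment_cost` are easier to see in a small sketch of the quantization step. This follows the usual VQ-VAE formulation and is not Artifex's internal code; the `quantize` helper and its toy inputs are assumptions.

```python
import jax
import jax.numpy as jnp


def quantize(z_e, codebook, commitment_cost=0.25):
    """Nearest-neighbour codebook lookup with a straight-through gradient.

    z_e: (batch, embedding_dim) encoder outputs
    codebook: (num_embeddings, embedding_dim)
    """
    # Squared distance between every encoder output and every code vector.
    distances = jnp.sum((z_e[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    indices = jnp.argmin(distances, axis=-1)
    z_q = codebook[indices]

    # Codebook term pulls codes toward encoder outputs; commitment term does the reverse.
    codebook_loss = jnp.mean((jax.lax.stop_gradient(z_e) - z_q) ** 2)
    commitment_loss = jnp.mean((z_e - jax.lax.stop_gradient(z_q)) ** 2)
    vq_loss = codebook_loss + commitment_cost * commitment_loss

    # Straight-through estimator: forward value is z_q, gradients flow to z_e.
    z_q_st = z_e + jax.lax.stop_gradient(z_q - z_e)
    return z_q_st, indices, vq_loss


codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z_e = jnp.array([[0.9, 1.2], [-0.8, 1.7], [0.1, -0.2]])
z_q, indices, vq_loss = quantize(z_e, codebook)
```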

Training VAEs¤

Basic Training Loop¤

import jax
import jax.numpy as jnp
import optax
from flax import nnx

from artifex.generative_models.core.configuration.network_configs import (
    EncoderConfig,
    DecoderConfig,
)
from artifex.generative_models.core.configuration.vae_config import VAEConfig
from artifex.generative_models.models.vae import VAE

# 1. Create synthetic data (replace with real data loading)
key = jax.random.key(42)
train_data = jax.random.uniform(key, (1000, 28, 28, 1))

# 2. Create model configuration
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

encoder_config = EncoderConfig(
    name="encoder",
    hidden_dims=(256, 128),
    latent_dim=32,
    activation="relu",
    input_shape=(28, 28, 1),
)

decoder_config = DecoderConfig(
    name="decoder",
    hidden_dims=(128, 256),
    output_shape=(28, 28, 1),
    latent_dim=32,
    activation="relu",
)

vae_config = VAEConfig(
    name="mnist_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    kl_weight=1.0,
)

# 3. Initialize model and optimizer
vae = VAE(config=vae_config, rngs=rngs)
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=1e-3), wrt=nnx.Param)

# 4. Training step (JIT-compiled for speed)
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        outputs = model(batch)  # Model uses internal RNGs
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["loss"], losses

    (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)

    return losses

# 5. Training loop
batch_size = 32
num_epochs = 5

for epoch in range(num_epochs):
    # Simple batching
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i : i + batch_size]
        losses = train_step(vae, optimizer, batch)

    print(f"Epoch {epoch + 1} | Loss: {losses['loss']:.4f}")

Training β-VAE with Annealing¤

from artifex.generative_models.core.configuration.vae_config import BetaVAEConfig
from artifex.generative_models.models.vae import BetaVAE

beta_config = BetaVAEConfig(
    name="beta_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    beta_default=4.0,
    beta_warmup_steps=10000,
    reconstruction_loss_type="mse",
)

beta_vae = BetaVAE(config=beta_config, rngs=rngs)
optimizer = nnx.Optimizer(beta_vae, optax.adam(learning_rate=1e-3), wrt=nnx.Param)

step = 0
for epoch in range(num_epochs):
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i : i + batch_size]

        def loss_fn(model):
            outputs = model(batch)
            # Pass current step for beta annealing
            losses = model.loss_fn(x=batch, outputs=outputs, step=step)
            return losses["loss"], losses

        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(beta_vae)
        optimizer.update(beta_vae, grads)
        step += 1

    print(f"Epoch {epoch + 1}, Beta: {losses.get('beta', 1.0):.4f}")

Training Conditional VAE¤

from artifex.generative_models.core.configuration.vae_config import ConditionalVAEConfig
from artifex.generative_models.models.vae import ConditionalVAE

cvae_config = ConditionalVAEConfig(
    name="conditional_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    num_classes=10,
    kl_weight=1.0,
)

cvae = ConditionalVAE(config=cvae_config, rngs=rngs)
optimizer = nnx.Optimizer(cvae, optax.adam(learning_rate=1e-3), wrt=nnx.Param)

# Create synthetic labels (replace with real labels)
train_labels = jax.random.randint(jax.random.key(0), (1000,), 0, 10)
train_labels_onehot = jax.nn.one_hot(train_labels, num_classes=10)

for epoch in range(num_epochs):
    for i in range(0, len(train_data), batch_size):
        batch_x = train_data[i : i + batch_size]
        batch_y = train_labels_onehot[i : i + batch_size]

        def loss_fn(model):
            outputs = model(batch_x, y=batch_y)  # Condition on labels
            losses = model.loss_fn(x=batch_x, outputs=outputs)
            return losses["loss"], losses

        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(cvae)
        optimizer.update(cvae, grads)

    print(f"Epoch {epoch + 1} | Loss: {losses['loss']:.4f}")

Training VQ-VAE¤

from artifex.generative_models.core.configuration.vae_config import VQVAEConfig
from artifex.generative_models.models.vae import VQVAE

vqvae_config = VQVAEConfig(
    name="vqvae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    num_embeddings=512,
    embedding_dim=64,
    commitment_cost=0.25,
)

vqvae = VQVAE(config=vqvae_config, rngs=rngs)
optimizer = nnx.Optimizer(vqvae, optax.adam(learning_rate=1e-3), wrt=nnx.Param)

for epoch in range(num_epochs):
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i : i + batch_size]

        def loss_fn(model):
            outputs = model(batch)
            losses = model.loss_fn(x=batch, outputs=outputs)
            return losses["loss"], losses

        (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(vqvae)
        optimizer.update(vqvae, grads)

    # VQ-VAE specific metrics
    print(f"Epoch {epoch + 1} | Recon: {losses.get('reconstruction_loss', 0.0):.4f}")

Generating and Sampling¤

Generate New Samples¤

# Sample from prior distribution
n_samples = 16
samples = vae.sample(n_samples, temperature=1.0)

# Temperature controls diversity
hot_samples = vae.sample(n_samples, temperature=2.0)   # More diverse
cold_samples = vae.sample(n_samples, temperature=0.5)  # More focused

# Using generate() method (alias for sample)
samples = vae.generate(n_samples, temperature=1.0)
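
One common way to implement temperature, assumed here for illustration, is to scale the standard deviation of the prior noise before decoding. Whether `vae.sample` does exactly this is not confirmed by this guide; `sample_prior` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sample_prior(key, n_samples, latent_dim, temperature=1.0):
    """Draw latents from N(0, temperature^2 * I)."""
    return temperature * jax.random.normal(key, (n_samples, latent_dim))


key = jax.random.key(0)
cold = sample_prior(key, 1000, 20, temperature=0.5)  # concentrated near the prior mean
hot = sample_prior(key, 1000, 20, temperature=2.0)   # spread further into the tails
```

Under this assumption, low temperatures keep latents near the high-density region of the prior, which tends to give more typical but less varied outputs.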

Conditional Generation¤

# Generate samples for specific classes
target_classes = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  # One of each digit
labels = jax.nn.one_hot(target_classes, num_classes=10)

samples = cvae.sample(n_samples=10, y=labels, temperature=1.0)

Reconstruction¤

# Stochastic reconstruction (uses internal RNGs)
reconstructed = vae.reconstruct(x, deterministic=False)

# Deterministic reconstruction (use mean of latent distribution)
deterministic_recon = vae.reconstruct(x, deterministic=True)

Latent Space Manipulation¤

Interpolation Between Images¤

# Linear interpolation in latent space
x1 = test_images[0:1]  # First image (keep batch dim)
x2 = test_images[1:2]  # Second image

interpolated = vae.interpolate(
    x1=x1,
    x2=x2,
    steps=10,  # Number of interpolation steps
)

# interpolated.shape = (10, *input_shape)
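
Conceptually, interpolation encodes both images, walks a straight line between the two latent codes, and decodes each point. The latent-space part can be sketched as below; `interpolate_latents` is an illustrative helper, and the real `vae.interpolate` may differ in detail.

```python
import jax.numpy as jnp


def interpolate_latents(z1, z2, steps=10):
    """Return `steps` latents on the straight line from z1 to z2 (endpoints included)."""
    alphas = jnp.linspace(0.0, 1.0, steps)[:, None]  # (steps, 1)
    return (1.0 - alphas) * z1 + alphas * z2         # (steps, latent_dim)


z1 = jnp.zeros((1, 4))
z2 = jnp.ones((1, 4))
path = interpolate_latents(z1, z2, steps=5)
```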

Latent Traversal (Disentanglement Analysis)¤

# Traverse a single latent dimension
x = test_images[0:1]
dim_to_traverse = 3  # Which latent dimension to vary

traversal = vae.latent_traversal(
    x=x,
    dim=dim_to_traverse,
    range_vals=(-3.0, 3.0),  # Range of values
    steps=10,                 # Number of steps
)

# traversal.shape = (10, *input_shape)
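
A traversal can be thought of as copying one latent code several times and sweeping a single coordinate across a range before decoding. The sketch below shows only that latent grid; `traverse_dimension` is a hypothetical helper, not the Artifex implementation.

```python
import jax.numpy as jnp


def traverse_dimension(z, dim, range_vals=(-3.0, 3.0), steps=10):
    """Tile a latent `z` of shape (latent_dim,) `steps` times and sweep one dimension."""
    values = jnp.linspace(range_vals[0], range_vals[1], steps)
    tiled = jnp.tile(z[None, :], (steps, 1))
    return tiled.at[:, dim].set(values)


z = jnp.zeros(8)
grid = traverse_dimension(z, dim=3, steps=5)
```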

Manual Latent Manipulation¤

# Encode image to latent space
mean, log_var = vae.encode(x)

# Manipulate specific dimensions
modified_mean = mean.at[:, 5].set(2.0)    # Increase dimension 5
modified_mean = modified_mean.at[:, 10].set(-1.5)  # Decrease dimension 10

# Decode modified latent
modified_image = vae.decode(modified_mean)

Evaluation and Analysis¤

Reconstruction Quality¤

# Calculate reconstruction error
test_batch = test_images[:100]
reconstructed = vae.reconstruct(test_batch, deterministic=True)

mse = jnp.mean((test_batch - reconstructed) ** 2)
print(f"Reconstruction MSE: {mse:.4f}")

ELBO (Evidence Lower Bound)¤

# Full ELBO calculation
outputs = vae(test_batch)  # Model uses internal RNGs
losses = vae.loss_fn(x=test_batch, outputs=outputs)

elbo = -(losses['reconstruction_loss'] + losses['kl_loss'])
print(f"ELBO: {elbo:.4f}")
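
For reference, the quantity estimated above can be written as follows. This assumes a standard-normal prior, a diagonal Gaussian posterior with mean $\mu$ and variance $\sigma^2$, and that `reconstruction_loss` is the negative log-likelihood up to a constant. With a `kl_weight` other than 1.0, the training loss is no longer the negative ELBO.

```latex
\mathrm{ELBO}(x)
  = \mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]
  - \mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big),
\qquad
\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)
  = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```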

Latent Space Statistics¤

# Encode test set
all_means = []
all_logvars = []

batch_size = 32
for i in range(0, len(test_images), batch_size):
    batch = test_images[i : i + batch_size]
    mean, log_var = vae.encode(batch)  # Uses internal RNGs
    all_means.append(mean)
    all_logvars.append(log_var)

all_means = jnp.concatenate(all_means, axis=0)
all_logvars = jnp.concatenate(all_logvars, axis=0)

# Statistics per dimension
mean_per_dim = jnp.mean(all_means, axis=0)
std_per_dim = jnp.std(all_means, axis=0)
variance_per_dim = jnp.mean(jnp.exp(all_logvars), axis=0)  # Mean of the variances, not exp of the mean log-variance

print(f"Latent mean: {mean_per_dim}")
print(f"Latent std: {std_per_dim}")
print(f"Average variance: {variance_per_dim}")

Disentanglement Metrics¤

# Per-dimension KL divergence (detect posterior collapse)
def per_dim_kl(mean, log_var):
    """Calculate KL divergence per dimension."""
    kl_per_dim = -0.5 * (1 + log_var - mean**2 - jnp.exp(log_var))
    return jnp.mean(kl_per_dim, axis=0)

kl_per_dimension = per_dim_kl(all_means, all_logvars)

# Dimensions with very low KL likely collapsed
inactive_dims = jnp.sum(kl_per_dimension < 0.01)
print(f"Inactive dimensions: {inactive_dims}/{vae.latent_dim}")

Hyperparameter Tuning¤

Key Hyperparameters¤

# Typical starting values; tune per dataset
config = {
    # Network architecture
    "latent_dim": 64,              # 10-100 for images, 2-20 for simple data
    "hidden_dims": (512, 256, 128), # Deeper for complex data (tuple, as in the config examples)
    "activation": "relu",          # or "gelu", "swish"

    # Training
    "learning_rate": 1e-3,         # 1e-4 to 1e-3 typical
    "batch_size": 128,             # Larger is more stable
    "num_epochs": 100,

    # VAE-specific
    "kl_weight": 1.0,              # Beta parameter
    "reconstruction_loss": "mse",  # "mse" or "bce"
}

Beta Tuning for β-VAE¤

from artifex.generative_models.core.configuration.vae_config import BetaVAEConfig
from artifex.generative_models.models.vae import BetaVAE

# Grid search over beta values
beta_values = [0.5, 1.0, 2.0, 4.0, 8.0]
results = {}

for beta in beta_values:
    # Create config with different beta
    beta_config = BetaVAEConfig(
        name=f"beta_vae_{beta}",
        encoder=encoder_config,
        decoder=decoder_config,
        encoder_type="dense",
        beta_default=beta,
    )

    rngs = nnx.Rngs(params=0, dropout=1, sample=2)
    model = BetaVAE(config=beta_config, rngs=rngs)

    # Train and evaluate (implement your train/evaluate functions)
    # trained_model = train(model, train_data, num_epochs=50)
    # recon_error = evaluate_reconstruction(trained_model, test_data)

    results[beta] = {"beta": beta}

# Find best trade-off
print(results)

Learning Rate Scheduling¤

import optax

# Cosine decay schedule
schedule = optax.cosine_decay_schedule(
    init_value=1e-3,
    decay_steps=num_train_steps,
    alpha=0.1,  # Final learning rate = 0.1 * init_value
)

optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=schedule), wrt=nnx.Param)
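
To see what the schedule produces at each step, here is a plain-Python sketch of a cosine decay with a floor of `alpha * init_value`. It mirrors the standard formula rather than calling Optax, and `cosine_decay` is an illustrative name.

```python
import math


def cosine_decay(step, init_value=1e-3, decay_steps=1000, alpha=0.1):
    """Cosine decay from init_value down to alpha * init_value over decay_steps."""
    progress = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return init_value * ((1.0 - alpha) * cosine + alpha)


# Learning rate at the start, midpoint, and end of the decay.
lr_preview = [cosine_decay(s) for s in (0, 500, 1000)]
```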

Common Issues and Solutions¤

Problem 1: Posterior Collapse¤

Symptoms: KL divergence near zero, poor generation quality

Solutions:

from artifex.generative_models.core.configuration.vae_config import BetaVAEConfig
from artifex.generative_models.models.vae import BetaVAE

# Solution 1: Beta annealing - start with β=0, gradually increase
beta_config = BetaVAEConfig(
    name="beta_annealing_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    beta_default=1.0,
    beta_warmup_steps=10000,  # Gradual warmup
)
beta_vae = BetaVAE(config=beta_config, rngs=rngs)

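# Solution 2: Weaker decoder (a less expressive decoder has to rely on the latent code)
# Use smaller hidden_dims in the decoder than in the encoder
from artifex.generative_models.core.configuration.network_configs import DecoderConfig

weak_decoder_config = DecoderConfig(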
    name="weak_decoder",
    hidden_dims=(64, 128),  # Smaller than encoder
    output_shape=(28, 28, 1),
    latent_dim=32,
    activation="relu",
)

Problem 2: Blurry Reconstructions¤

Symptoms: Overly smooth outputs, lack of detail

Solutions:

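# Solution 1: Lower kl_weight (emphasize reconstruction)
from artifex.generative_models.core.configuration.vae_config import VAEConfig
from artifex.generative_models.models.vae import VAE

vae_config = VAEConfig(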
    name="sharp_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    kl_weight=0.5,  # Lower than default 1.0
)
vae = VAE(config=vae_config, rngs=rngs)

# Solution 2: Use VQ-VAE (discrete latents often produce sharper outputs)
from artifex.generative_models.core.configuration.vae_config import VQVAEConfig
from artifex.generative_models.models.vae import VQVAE

vqvae_config = VQVAEConfig(
    name="vqvae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    num_embeddings=512,
    embedding_dim=64,
)
vqvae = VQVAE(config=vqvae_config, rngs=rngs)

Problem 3: Unstable Training¤

Symptoms: Loss oscillations, NaN values

Solutions:

# Solution 1: Gradient clipping
import optax

optimizer = nnx.Optimizer(
    vae,
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients
        optax.adam(learning_rate=1e-3),
    ),
    wrt=nnx.Param
)

# Solution 2: Lower learning rate
optimizer = nnx.Optimizer(vae, optax.adam(learning_rate=1e-4), wrt=nnx.Param)

# Solution 3: Batch normalization in encoder/decoder
# (implement custom encoder/decoder with normalization)
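
To show what global-norm clipping does to a gradient tree, here is a hand-written sketch in plain JAX. It is meant as an explanation of the idea, not a replacement for `optax.clip_by_global_norm`, and small numerical details may differ from Optax.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a gradient pytree so its global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale is 1.0 when the norm is already small enough, otherwise max_norm / norm.
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads), global_norm


grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
```

All leaves share one scale factor, so the direction of the overall update is preserved and only its length is capped.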

Problem 4: Poor Disentanglement¤

Symptoms: Latent dimensions don't correspond to interpretable factors

Solutions:

# Solution 1: Increase beta for more disentanglement
from artifex.generative_models.core.configuration.vae_config import BetaVAEConfig
from artifex.generative_models.models.vae import BetaVAE

beta_config = BetaVAEConfig(
    name="high_beta_vae",
    encoder=encoder_config,
    decoder=decoder_config,
    encoder_type="dense",
    beta_default=4.0,  # Higher beta encourages disentanglement
)
beta_vae = BetaVAE(config=beta_config, rngs=rngs)

# Solution 2: More latent dimensions - give model more capacity
# Update encoder_config with a larger latent_dim (set the same latent_dim in the DecoderConfig)
encoder_config_large = EncoderConfig(
    name="encoder_large_latent",
    hidden_dims=(256, 128),
    latent_dim=128,  # Increased from 32
    activation="relu",
    input_shape=(28, 28, 1),
)

Advanced Techniques¤

Custom Loss Functions¤

def custom_loss_fn(x_true, x_pred):
    """Custom reconstruction loss combining multiple terms."""
    # L1 loss for sparsity
    l1_loss = jnp.mean(jnp.abs(x_true - x_pred))

    # L2 loss for overall quality
    l2_loss = jnp.mean((x_true - x_pred) ** 2)

    # Combine
    return 0.5 * l1_loss + 0.5 * l2_loss

# Use in training
losses = vae.loss_fn(
    x=batch,
    outputs=outputs,
    reconstruction_loss_fn=custom_loss_fn,
)

Multi-GPU Training¤

import jax

# Check available devices
print(f"Available devices: {jax.devices()}")

# For multi-GPU training, consider using JAX's sharding API
# or Artifex's distributed training utilities
from artifex.generative_models.training.distributed import DistributedTrainer

Checkpointing¤

import orbax.checkpoint as ocp

# Create checkpointer
checkpointer = ocp.StandardCheckpointer()

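# Save model state
graphdef, state = nnx.split(vae)
checkpointer.save("/tmp/vae_checkpoint", state)
checkpointer.wait_until_finished()  # Saves may run asynchronously; block before restoring

# Create a new model with the same config and RNG stream names
# so its state structure matches the saved one
new_vae = VAE(config=vae_config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))
_, abstract_state = nnx.split(new_vae)

# Restore using the new model's state as the target structure, then load it into the model
restored_state = checkpointer.restore("/tmp/vae_checkpoint", abstract_state)
nnx.update(new_vae, restored_state)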

Best Practices¤

DO ✅¤

  • Start simple: Begin with standard VAE before trying variants
  • Monitor both losses: Track reconstruction AND KL divergence
  • Use appropriate loss: MSE for continuous, BCE for binary data
  • Visualize latent space: Plot 2D projections to check structure
  • Test interpolation: Smooth interpolation indicates good latent space
  • Check per-dim KL: Detect posterior collapse early
  • Use beta annealing: Helps avoid posterior collapse
  • Larger batch size: More stable training (128+ recommended)
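
As a concrete illustration of the "appropriate loss" point above, the two reconstruction losses can be sketched as follows. These helpers are illustrative and use mean reduction; Artifex's built-in `"mse"` and `"bce"` options may reduce or scale differently.

```python
import jax.numpy as jnp


def mse_loss(x, x_hat):
    """Mean squared error; suits continuous-valued targets."""
    return jnp.mean((x - x_hat) ** 2)


def bce_loss(x, x_hat, eps=1e-7):
    """Binary cross-entropy; expects targets and predictions in [0, 1]."""
    x_hat = jnp.clip(x_hat, eps, 1.0 - eps)  # avoid log(0)
    return -jnp.mean(x * jnp.log(x_hat) + (1.0 - x) * jnp.log(1.0 - x_hat))


x = jnp.array([0.0, 1.0, 1.0, 0.0])       # binary targets
x_hat = jnp.array([0.1, 0.9, 0.8, 0.2])   # predicted probabilities
mse = mse_loss(x, x_hat)
bce = bce_loss(x, x_hat)
```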

DON'T ❌¤

  • Don't ignore KL: Zero KL means model ignores latent code
  • Don't use too small a latent space: Leads to underfitting
  • Don't overtrain: Can lead to posterior collapse
  • Don't skip validation: Regular evaluation prevents surprises
  • Don't forget temperature: Use temperature for diverse sampling
  • Don't compare losses across different betas directly: Higher beta trades reconstruction for disentanglement, so the totals are not on the same scale

Performance Tips¤

Memory Optimization¤

# Use gradient checkpointing (rematerialization) for large models
# nnx.remat is the NNX-aware counterpart of jax.checkpoint and accepts modules as arguments
from flax import nnx

@nnx.remat
def encoder_forward(encoder, x):
    return encoder(x)

# Use lower matmul precision for faster training
# Set precision in config or at JAX level
import jax
jax.config.update("jax_default_matmul_precision", "bfloat16")  # "float32" keeps full precision

Speed Optimization¤

# JIT compile training step
@nnx.jit
def fast_train_step(model, optimizer, batch):
    def loss_fn(model):
        outputs = model(batch)  # Model uses internal RNGs
        losses = model.loss_fn(x=batch, outputs=outputs)
        return losses["loss"], losses

    (loss, losses), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return losses

# Batch sampling: decode() already accepts a leading batch dimension,
# so decode all latent vectors in one call instead of looping
samples = vae.decode(latent_vectors)  # latent_vectors: (n_samples, latent_dim)

Summary¤

This guide covered:

  • ✅ Creating encoders, decoders, and VAE models
  • ✅ Training standard VAE, β-VAE, CVAE, and VQ-VAE
  • ✅ Generating samples and manipulating latent space
  • ✅ Evaluation metrics and diagnostics
  • ✅ Hyperparameter tuning strategies
  • ✅ Troubleshooting common issues
  • ✅ Advanced techniques and optimizations

Next Steps¤