Building Your First VAE from Scratch

This comprehensive tutorial walks you through building, training, and evaluating a Variational Autoencoder (VAE) from scratch using Workshop. By the end, you'll have a working VAE trained on MNIST that can generate new digit images.

What You'll Learn

  • VAE Architecture: Understand encoder-decoder structure and latent space representation
  • Data Loading: Load and preprocess the MNIST dataset with JAX-compatible pipelines
  • Training Loop: Implement complete training with loss monitoring and checkpointing
  • Evaluation: Generate samples, reconstruct images, and visualize latent space

Prerequisites

Time Estimate

  • Reading time: 20-30 minutes
  • Hands-on coding: 45-60 minutes
  • Training time: 10-15 minutes (CPU) or 2-3 minutes (GPU)

Part 1: Understanding VAEs

What is a VAE?

A Variational Autoencoder (VAE) is a generative model that learns to compress data into a lower-dimensional latent space and reconstruct it. Unlike a standard autoencoder, which maps each input to a single point, a VAE's encoder outputs a probability distribution over latent codes. This probabilistic latent space enables:

  • Generation: sample new data by decoding draws from the prior
  • Interpolation: smoothly transition between data points
  • Disentanglement: separate meaningful factors of variation (not guaranteed; variants such as β-VAE with β > 1 encourage it)

Architecture Overview

graph LR
    A[Input Image<br/>28×28] --> B[Encoder<br/>Neural Network]
    B --> C[Latent Mean μ]
    B --> D[Latent Log-Var log σ²]
    C --> E["Sampling<br/>z ~ N(μ, σ²)"]
    D --> E
    E --> F[Decoder<br/>Neural Network]
    F --> G[Reconstructed Image<br/>28×28]

    style A fill:#e1f5ff
    style G fill:#e1f5ff
    style E fill:#fff4e1
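The sampling node in the diagram is usually implemented with the reparameterization trick: draw ε ~ N(0, I) and set z = μ + σ·ε with σ = exp(½ log σ²), so gradients reach μ and log σ² through a deterministic path. The sketch below is a standalone JAX illustration; the helper name `reparameterize` is ours, and we assume (without having checked its source) that Workshop's VAE samples in an equivalent way.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, log_var):
    """Sample z ~ N(mean, exp(log_var)) so gradients flow to mean and log_var."""
    std = jnp.exp(0.5 * log_var)               # sigma = exp(log(sigma^2) / 2)
    eps = jax.random.normal(key, mean.shape)   # noise that does not depend on params
    return mean + std * eps


key = jax.random.key(0)
mean = jnp.zeros((4, 32))      # (batch_size, latent_dim)
log_var = jnp.zeros((4, 32))   # log-variance 0 means sigma = 1
z = reparameterize(key, mean, log_var)
print(z.shape)  # (4, 32)
```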

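Loss Function

VAEs are trained by maximizing the Evidence Lower Bound (ELBO). Written with the β weight of β-VAE (β = 1 gives the standard ELBO), the objective is:

\[ \mathcal{L}_{\text{VAE}} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\beta \cdot D_{\text{KL}}(q(z|x) \,\|\, p(z))}_{\text{KL Divergence}} \]

In practice the training loss is its negative, \(-\mathcal{L}_{\text{VAE}}\): reconstruction error plus β times the KL term, which is minimized.

Where:

  • Reconstruction term: how well the decoder reconstructs the input from a latent sample z
  • KL divergence term: how close the approximate posterior q(z|x) is to the prior p(z) (usually \(\mathcal{N}(0, I)\))
  • β: weight balancing reconstruction vs. regularization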
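With a diagonal Gaussian posterior and a standard normal prior, the KL term has the closed form \(D_{\text{KL}} = -\tfrac{1}{2}\sum_j (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2)\). Below is a minimal sketch of the resulting training loss (the negative objective), assuming a Bernoulli decoder so that the reconstruction term is binary cross-entropy, which matches the sigmoid output used later. Workshop's own `model.loss_fn` may use a different reconstruction loss or reduction; the function names here are illustrative.

```python
import jax.numpy as jnp


def kl_to_standard_normal(mean, log_var):
    """Closed-form KL(N(mean, exp(log_var)) || N(0, I)) per example."""
    return -0.5 * jnp.sum(1.0 + log_var - mean**2 - jnp.exp(log_var), axis=-1)


def negative_elbo(x, x_recon, mean, log_var, beta=1.0, eps=1e-7):
    """Batch-averaged -ELBO for a Bernoulli decoder; x, x_recon: (B, H, W, C)."""
    x_recon = jnp.clip(x_recon, eps, 1.0 - eps)  # keep log() finite
    log_px_z = jnp.sum(                         # log p(x|z), summed over pixels
        x * jnp.log(x_recon) + (1.0 - x) * jnp.log(1.0 - x_recon),
        axis=(1, 2, 3),
    )
    kl = kl_to_standard_normal(mean, log_var)
    return jnp.mean(-log_px_z + beta * kl)       # reconstruction + beta * KL
```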

Part 2: Setting Up Your Environment

Step 1: Create Project Directory

# Create project structure
mkdir vae_tutorial && cd vae_tutorial
mkdir -p {data,checkpoints,outputs,logs}

# Create main training script
touch train_vae.py

Step 2: Import Dependencies

Open train_vae.py and add the imports:

"""Train a VAE on MNIST from scratch."""

import os
from pathlib import Path
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax import nnx
from flax.training import checkpoints
import numpy as np

# Workshop imports
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.vae.base import VAE

# Set random seed for reproducibility
SEED = 42
# JAX runs on a GPU/TPU automatically when one is available.
# Uncomment to force CPU execution instead:
# jax.config.update("jax_platforms", "cpu")
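Whichever way the platform is configured, you can confirm which backend JAX actually selected before starting a long run:

```python
import jax

# Which backend did JAX pick ("cpu", "gpu", or "tpu"), and which devices?
print(jax.default_backend())
print(jax.devices())
```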

Part 3: Loading MNIST Data

Understanding the Dataset

MNIST contains 70,000 grayscale images of handwritten digits:

  • Training set: 60,000 images
  • Test set: 10,000 images
  • Image shape: 28×28 pixels
  • Values: 0-255 (we'll normalize to 0-1)

Implementation

def load_mnist_data() -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Load and preprocess MNIST dataset.

    Returns:
        Tuple of (train_images, test_images) as JAX arrays
    """
    try:
        # Try using TensorFlow Datasets (recommended)
        import tensorflow_datasets as tfds

        # Load data
        ds_train = tfds.load('mnist', split='train', as_supervised=True)
        ds_test = tfds.load('mnist', split='test', as_supervised=True)

        # Convert to numpy arrays
        train_images = np.array([np.array(img) for img, _ in tfds.as_numpy(ds_train)])
        test_images = np.array([np.array(img) for img, _ in tfds.as_numpy(ds_test)])

    except ImportError:
        print("TensorFlow Datasets not found. Using sklearn MNIST...")
        # Fallback to sklearn
        from sklearn.datasets import fetch_openml

        mnist = fetch_openml('mnist_784', version=1, as_frame=False)
        data = mnist.data.reshape(-1, 28, 28, 1)

        # Split train/test
        train_images = data[:60000]
        test_images = data[60000:]

    # Normalize to [0, 1]
    train_images = train_images.astype(np.float32) / 255.0
    test_images = test_images.astype(np.float32) / 255.0

    # Ensure correct shape: (N, 28, 28, 1)
    if train_images.ndim == 3:
        train_images = train_images[..., None]
        test_images = test_images[..., None]

    # Convert to JAX arrays
    train_images = jnp.array(train_images)
    test_images = jnp.array(test_images)

    print(f"✓ Loaded MNIST dataset:")
    print(f"  Train: {train_images.shape}")
    print(f"  Test: {test_images.shape}")
    print(f"  Value range: [{train_images.min():.2f}, {train_images.max():.2f}]")

    return train_images, test_images


def create_batches(data: jnp.ndarray, batch_size: int, shuffle: bool = True):
    """Create batches from data.

    Args:
        data: Input data array
        batch_size: Batch size
        shuffle: Whether to shuffle data

    Yields:
        Batches of data
    """
    n_samples = data.shape[0]
    indices = np.arange(n_samples)

    if shuffle:
        np.random.shuffle(indices)

    for start_idx in range(0, n_samples, batch_size):
        batch_indices = indices[start_idx:start_idx + batch_size]
        yield data[batch_indices]
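Two details of `create_batches` are worth knowing: it shuffles with NumPy's global random state, so the order is not tied to `SEED`, and the final batch is smaller than the rest, which makes `nnx.jit` trace and compile `train_step` a second time for that shape. A variant that addresses both might look like this; `create_batches_seeded` and its `drop_last` flag are hypothetical additions, not part of Workshop.

```python
import numpy as np


def create_batches_seeded(data, batch_size, seed=0, drop_last=True):
    """Yield shuffled batches in a reproducible order.

    With drop_last=True the final partial batch is skipped, so every batch has
    the same shape and a jitted train step is compiled for a single shape.
    """
    rng = np.random.default_rng(seed)        # local generator, no global state
    indices = rng.permutation(data.shape[0])
    n = data.shape[0]
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        yield data[indices[start:start + batch_size]]
```

Passing `seed=SEED + epoch` from inside the epoch loop gives a different but reproducible order each epoch.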

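Alternative: Synthetic Data for Quick Testing

For rapid iteration during development, you can substitute uniform random noise with MNIST's shape. This checks that the pipeline runs end to end, but the data contains no digits, so losses and samples will not be meaningful:

def create_synthetic_mnist(n_train=1000, n_test=200):
    """Create random data with MNIST's shape (N, 28, 28, 1) for smoke tests."""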
    key = jax.random.key(42)
    train_key, test_key = jax.random.split(key)

    train_images = jax.random.uniform(train_key, (n_train, 28, 28, 1))
    test_images = jax.random.uniform(test_key, (n_test, 28, 28, 1))

    return train_images, test_images

Part 4: Building the VAE Model

Step 1: Define Encoder

The encoder maps input images to latent parameters:

class Encoder(nnx.Module):
    """Encoder network for VAE.

    Maps input images to latent distribution parameters (mean and log-variance).
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: list[int] = [256, 128],
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize encoder.

        Args:
            latent_dim: Dimension of latent space
            hidden_dims: Hidden layer dimensions
            rngs: Random number generators
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Convolutional layers for spatial feature extraction
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), strides=(2, 2), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), strides=(2, 2), rngs=rngs)
        self.conv3 = nnx.Conv(64, 128, kernel_size=(3, 3), strides=(2, 2), rngs=rngs)

        # Dense layers for distribution parameters.
        # nnx.Conv defaults to padding='SAME', so each stride-2 conv outputs
        # ceil(n / 2): 28x28 -> 14x14 -> 7x7 -> 4x4
        self.dense1 = nnx.Linear(128 * 4 * 4, hidden_dims[0], rngs=rngs)
        self.dense2 = nnx.Linear(hidden_dims[0], hidden_dims[1], rngs=rngs)

        # Output layers for mean and log-variance
        self.mean_layer = nnx.Linear(hidden_dims[1], latent_dim, rngs=rngs)
        self.logvar_layer = nnx.Linear(hidden_dims[1], latent_dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Encode input to latent parameters.

        Args:
            x: Input images (batch_size, 28, 28, 1)

        Returns:
            Tuple of (mean, log_var) with shape (batch_size, latent_dim)
        """
        # Convolutional feature extraction
        h = nnx.relu(self.conv1(x))
        h = nnx.relu(self.conv2(h))
        h = nnx.relu(self.conv3(h))

        # Flatten
        h = jnp.reshape(h, (h.shape[0], -1))

        # Dense layers
        h = nnx.relu(self.dense1(h))
        h = nnx.relu(self.dense2(h))

        # Output parameters
        mean = self.mean_layer(h)
        log_var = self.logvar_layer(h)

        return mean, log_var
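`dense1` must accept exactly the number of features produced by flattening the last convolution. You can check that size by pushing a dummy batch through the same three stride-2, 3×3 convolutions with `jax.lax.conv_general_dilated`, the JAX primitive that Flax convolution layers call. Zero kernels are enough because only shapes matter; `conv_out_shape` is an illustrative helper. With 'SAME' padding (the `nnx.Conv` default) the result is 4×4×128 = 2048 features, while 'VALID' padding would give 2×2×128.

```python
import jax
import jax.numpy as jnp


def conv_out_shape(height, width, padding):
    """Shape after three stride-2, 3x3 convs with 1 -> 32 -> 64 -> 128 channels."""
    x = jnp.zeros((1, height, width, 1))                  # NHWC dummy image
    channels = [1, 32, 64, 128]
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        kernel = jnp.zeros((3, 3, c_in, c_out))           # HWIO kernel layout
        x = jax.lax.conv_general_dilated(
            x, kernel, window_strides=(2, 2), padding=padding,
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
    return x.shape


print(conv_out_shape(28, 28, "SAME"))   # (1, 4, 4, 128) -> 2048 features
print(conv_out_shape(28, 28, "VALID"))  # (1, 2, 2, 128) -> 512 features
```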

Step 2: Define Decoder

The decoder reconstructs images from latent codes:

class Decoder(nnx.Module):
    """Decoder network for VAE.

    Maps latent codes back to reconstructed images.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: list[int] = [128, 256],
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize decoder.

        Args:
            latent_dim: Dimension of latent space
            hidden_dims: Hidden layer dimensions
            rngs: Random number generators
        """
        super().__init__()

        # Dense layers to expand latent code
        self.dense1 = nnx.Linear(latent_dim, hidden_dims[0], rngs=rngs)
        self.dense2 = nnx.Linear(hidden_dims[0], hidden_dims[1], rngs=rngs)
        self.dense3 = nnx.Linear(hidden_dims[1], 128 * 3 * 3, rngs=rngs)

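        # Transposed convolutions for upsampling. nnx.ConvTranspose defaults to
        # padding='SAME', which only doubles the size (3 -> 6 -> 12 -> 24, too
        # small to crop to 28x28). With padding='VALID' each layer outputs
        # (n - 1) * 2 + 3, giving the 3 -> 7 -> 15 -> 31 sizes used in __call__.
        self.conv_transpose1 = nnx.ConvTranspose(
            128, 64, kernel_size=(3, 3), strides=(2, 2), padding='VALID', rngs=rngs
        )
        self.conv_transpose2 = nnx.ConvTranspose(
            64, 32, kernel_size=(3, 3), strides=(2, 2), padding='VALID', rngs=rngs
        )
        self.conv_transpose3 = nnx.ConvTranspose(
            32, 1, kernel_size=(3, 3), strides=(2, 2), padding='VALID', rngs=rngs
        )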

    def __call__(self, z: jax.Array) -> jax.Array:
        """Decode latent code to reconstructed image.

        Args:
            z: Latent codes (batch_size, latent_dim)

        Returns:
            Reconstructed images (batch_size, 28, 28, 1)
        """
        # Expand through dense layers
        h = nnx.relu(self.dense1(z))
        h = nnx.relu(self.dense2(h))
        h = nnx.relu(self.dense3(h))

        # Reshape to spatial
        h = jnp.reshape(h, (h.shape[0], 3, 3, 128))

        # Transposed convolutions for upsampling
        h = nnx.relu(self.conv_transpose1(h))  # 3x3 -> 7x7
        h = nnx.relu(self.conv_transpose2(h))  # 7x7 -> 15x15
        h = self.conv_transpose3(h)             # 15x15 -> 31x31

        # Crop to 28x28
        h = h[:, :28, :28, :]

        # Apply sigmoid to get values in [0, 1]
        reconstruction = nnx.sigmoid(h)

        return reconstruction
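The size comments in `Decoder.__call__` (3×3 → 7×7 → 15×15 → 31×31) hold for 'VALID' padding, where a stride-2, 3×3 transposed convolution outputs (n − 1)·2 + 3. With 'SAME' padding, the `nnx.ConvTranspose` default, each layer exactly doubles the size and ends at 24×24, too small to crop to 28×28. This sketch checks both with `jax.lax.conv_transpose`; `deconv_out_shape` is an illustrative helper.

```python
import jax
import jax.numpy as jnp


def deconv_out_shape(size, padding):
    """Shape after three stride-2, 3x3 transposed convs: 128 -> 64 -> 32 -> 1 channels."""
    x = jnp.zeros((1, size, size, 128))                   # NHWC feature map
    channels = [128, 64, 32, 1]
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        kernel = jnp.zeros((3, 3, c_in, c_out))           # HWIO kernel layout
        x = jax.lax.conv_transpose(
            x, kernel, strides=(2, 2), padding=padding,
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
    return x.shape


print(deconv_out_shape(3, "VALID"))  # (1, 31, 31, 1): can be cropped to 28x28
print(deconv_out_shape(3, "SAME"))   # (1, 24, 24, 1): too small to crop
```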

Step 3: Create Complete VAE

Now we combine encoder and decoder using Workshop's VAE base class:

def create_vae_model(
    latent_dim: int = 32,
    beta: float = 1.0,
    rngs: nnx.Rngs = None,
) -> VAE:
    """Create VAE model with encoder and decoder.

    Args:
        latent_dim: Dimension of latent space
        beta: Beta parameter for β-VAE (1.0 = standard VAE)
        rngs: Random number generators

    Returns:
        VAE model instance
    """
    if rngs is None:
        rngs = nnx.Rngs(SEED)

    # Create encoder and decoder
    encoder = Encoder(latent_dim, hidden_dims=[256, 128], rngs=rngs)
    decoder = Decoder(latent_dim, hidden_dims=[128, 256], rngs=rngs)

    # Create VAE using Workshop's base class
    vae = VAE(encoder, decoder, latent_dim, beta=beta, rngs=rngs)

    print(f"✓ Created VAE model:")
    print(f"  Latent dimension: {latent_dim}")
    print(f"  Beta parameter: {beta}")

    return vae

Part 5: Training the VAE

Step 1: Define Training Step

@nnx.jit
def train_step(
    model: VAE,
    optimizer: nnx.Optimizer,
    batch: jax.Array,
    rng: jax.Array,
) -> Tuple[float, Dict[str, float]]:
    """Single training step.

    Args:
        model: VAE model
        optimizer: NNX optimizer
        batch: Batch of images
        rng: Random key for sampling

    Returns:
        Tuple of (total_loss, loss_dict)
    """
    def loss_fn(model):
        # Forward pass
        sample_rngs = nnx.Rngs(sample=rng)
        outputs = model(batch, rngs=sample_rngs)

        # Compute losses
        loss_dict = model.loss_fn(x=batch, outputs=outputs)
        return loss_dict['total_loss'], loss_dict

    # Compute gradients
    (loss, loss_dict), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

    # Update parameters (signature for Flax versions before 0.11; newer
    # releases use optimizer.update(model, grads))
    optimizer.update(grads)

    return loss, loss_dict
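The tuple unpacking in `train_step` follows JAX's `has_aux` convention: the loss function returns `(scalar, aux)`, only the scalar is differentiated, and the transformed function returns `((scalar, aux), grads)`. `nnx.value_and_grad` is used the same way in `train_step`. A minimal plain-JAX illustration:

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x):
    """Return (scalar loss, aux metrics); only the scalar is differentiated."""
    pred = x @ w
    mse = jnp.mean(pred ** 2)
    return mse, {"mse": mse, "max_pred": jnp.max(pred)}


w = jnp.array([1.0, 2.0])
x = jnp.eye(2)  # identity input, so pred == w

# Same ((loss, aux), grads) structure that train_step unpacks
(loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(w, x)
# loss == 2.5, metrics["max_pred"] == 2.0, grads == [1.0, 2.0]
```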

Step 2: Evaluation Function

def evaluate_model(
    model: VAE,
    test_data: jnp.ndarray,
    batch_size: int = 128,
    rng: jax.Array = None,
) -> Dict[str, float]:
    """Evaluate model on test data.

    Args:
        model: VAE model
        test_data: Test dataset
        batch_size: Batch size for evaluation
        rng: Random key

    Returns:
        Dictionary of average metrics
    """
    if rng is None:
        rng = jax.random.key(0)

    total_losses = {'total_loss': 0.0, 'reconstruction_loss': 0.0, 'kl_loss': 0.0}
    n_batches = 0

    for batch in create_batches(test_data, batch_size, shuffle=False):
        # Fresh subkey per batch; never reuse a key that has already been consumed
        rng, batch_rng = jax.random.split(rng)

        # Forward pass
        sample_rngs = nnx.Rngs(sample=batch_rng)
        outputs = model(batch, rngs=sample_rngs)

        # Compute losses
        loss_dict = model.loss_fn(x=batch, outputs=outputs)

        # Accumulate
        for key in total_losses:
            total_losses[key] += loss_dict[key]
        n_batches += 1

    # Average losses
    avg_losses = {k: v / n_batches for k, v in total_losses.items()}

    return avg_losses

Step 3: Training Loop

def train_vae(
    model: VAE,
    train_data: jnp.ndarray,
    test_data: jnp.ndarray,
    n_epochs: int = 50,
    batch_size: int = 128,
    learning_rate: float = 1e-3,
    checkpoint_dir: str = "checkpoints",
):
    """Train VAE model.

    Args:
        model: VAE model
        train_data: Training dataset
        test_data: Test dataset
        n_epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        checkpoint_dir: Directory to save checkpoints
    """
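    # Create optimizer (Flax versions before 0.11; newer releases require
    # nnx.Optimizer(model, optax.adam(learning_rate), wrt=nnx.Param))
    optimizer = nnx.Optimizer(model, optax.adam(learning_rate))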

    # Training state
    rng = jax.random.key(SEED)
    history = {
        'train_loss': [],
        'train_recon': [],
        'train_kl': [],
        'test_loss': [],
        'test_recon': [],
        'test_kl': [],
    }

    print(f"\n{'='*60}")
    print(f"Training VAE for {n_epochs} epochs")
    print(f"{'='*60}\n")

    best_test_loss = float('inf')

    for epoch in range(n_epochs):
        # Training
        epoch_losses = {'total_loss': 0.0, 'reconstruction_loss': 0.0, 'kl_loss': 0.0}
        n_batches = 0

        for batch in create_batches(train_data, batch_size, shuffle=True):
            # Training step
            rng, step_rng = jax.random.split(rng)
            loss, loss_dict = train_step(model, optimizer, batch, step_rng)

            # Accumulate losses
            for key in epoch_losses:
                epoch_losses[key] += loss_dict[key]
            n_batches += 1

        # Average training losses
        avg_train_losses = {k: v / n_batches for k, v in epoch_losses.items()}

        # Evaluation
        rng, eval_rng = jax.random.split(rng)
        avg_test_losses = evaluate_model(model, test_data, batch_size, eval_rng)

        # Store history
        history['train_loss'].append(float(avg_train_losses['total_loss']))
        history['train_recon'].append(float(avg_train_losses['reconstruction_loss']))
        history['train_kl'].append(float(avg_train_losses['kl_loss']))
        history['test_loss'].append(float(avg_test_losses['total_loss']))
        history['test_recon'].append(float(avg_test_losses['reconstruction_loss']))
        history['test_kl'].append(float(avg_test_losses['kl_loss']))

        # Print progress
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{n_epochs} | "
                  f"Train Loss: {avg_train_losses['total_loss']:.4f} "
                  f"(Recon: {avg_train_losses['reconstruction_loss']:.4f}, "
                  f"KL: {avg_train_losses['kl_loss']:.4f}) | "
                  f"Test Loss: {avg_test_losses['total_loss']:.4f}")

        # Save best model
        if avg_test_losses['total_loss'] < best_test_loss:
            best_test_loss = avg_test_losses['total_loss']
            save_checkpoint(model, checkpoint_dir, epoch, is_best=True)

        # Save regular checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_checkpoint(model, checkpoint_dir, epoch)

    print(f"\n✓ Training complete!")
    print(f"  Best test loss: {best_test_loss:.4f}")

    return history


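def save_checkpoint(model: VAE, checkpoint_dir: str, epoch: int, is_best: bool = False):
    """Save model state with Orbax, the checkpointing library Flax NNX recommends."""
    import orbax.checkpoint as ocp  # replaces the deprecated flax.training.checkpoints

    os.makedirs(checkpoint_dir, exist_ok=True)

    # One subdirectory per checkpoint; Orbax expects an absolute path
    name = "best_model" if is_best else f"checkpoint_epoch_{epoch+1}"
    path = Path(checkpoint_dir).resolve() / name

    # Get model state
    state = nnx.state(model)

    # force=True lets a newer best model overwrite the previous one
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state, force=True)
    checkpointer.wait_until_finished()  # saving runs asynchronously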

Part 6: Visualization and Analysis

Visualize Training Progress

def plot_training_history(history: Dict[str, list], save_path: str = "outputs/training_history.png"):
    """Plot training and validation losses."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    epochs = range(1, len(history['train_loss']) + 1)

    # Total loss
    axes[0].plot(epochs, history['train_loss'], label='Train', linewidth=2)
    axes[0].plot(epochs, history['test_loss'], label='Test', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Total Loss')
    axes[0].set_title('Total Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Reconstruction loss
    axes[1].plot(epochs, history['train_recon'], label='Train', linewidth=2)
    axes[1].plot(epochs, history['test_recon'], label='Test', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Reconstruction Loss')
    axes[1].set_title('Reconstruction Loss')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # KL divergence
    axes[2].plot(epochs, history['train_kl'], label='Train', linewidth=2)
    axes[2].plot(epochs, history['test_kl'], label='Test', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('KL Divergence')
    axes[2].set_title('KL Divergence')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved training history to {save_path}")
    plt.close()

Visualize Reconstructions

def visualize_reconstructions(
    model: VAE,
    test_data: jnp.ndarray,
    n_samples: int = 10,
    save_path: str = "outputs/reconstructions.png",
):
    """Visualize original vs reconstructed images."""
    # Get random samples
    indices = np.random.choice(len(test_data), n_samples, replace=False)
    samples = test_data[indices]

    # Reconstruct
    rngs = nnx.Rngs(sample=jax.random.key(42))
    reconstructed = model.reconstruct(samples, deterministic=True, rngs=rngs)

    # Plot
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 1.5, 3))

    for i in range(n_samples):
        # Original
        axes[0, i].imshow(samples[i, :, :, 0], cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=10)

        # Reconstructed
        axes[1, i].imshow(reconstructed[i, :, :, 0], cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved reconstructions to {save_path}")
    plt.close()

Generate New Samples

def visualize_generated_samples(
    model: VAE,
    n_samples: int = 20,
    save_path: str = "outputs/generated_samples.png",
):
    """Visualize generated samples from random latent codes."""
    # Generate samples
    rngs = nnx.Rngs(sample=jax.random.key(42))
    samples = model.sample(n_samples, rngs=rngs)

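    # Plot on a 4-row grid with enough columns for every sample
    rows = 4
    cols = (n_samples + rows - 1) // rows  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')  # hide frames, including any unused cells
        if idx < n_samples:
            ax.imshow(samples[idx, :, :, 0], cmap='gray')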

    plt.suptitle('Generated Samples from VAE', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved generated samples to {save_path}")
    plt.close()
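`model.sample` belongs to Workshop's VAE class and its internals are not shown here. Conceptually, sampling from a VAE draws z from the prior N(0, I) and decodes it. The sketch below assumes that behavior and uses a toy stand-in decoder (`toy_decode` and `sample_images` are hypothetical); with this tutorial's modules you would pass a `Decoder` instance as `decode` instead.

```python
import jax
import jax.numpy as jnp


def sample_images(key, decode, n_samples, latent_dim):
    """Generate images by decoding draws from the N(0, I) prior."""
    z = jax.random.normal(key, (n_samples, latent_dim))  # z ~ p(z)
    return decode(z)


def toy_decode(z):
    """Hypothetical stand-in decoder: (n, latent_dim) -> (n, 28, 28, 1) in [0, 1]."""
    w = jnp.ones((z.shape[-1], 28 * 28)) / z.shape[-1]
    return jax.nn.sigmoid(z @ w).reshape(-1, 28, 28, 1)


images = sample_images(jax.random.key(0), toy_decode, n_samples=20, latent_dim=32)
print(images.shape)  # (20, 28, 28, 1)
```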

Visualize Latent Space Interpolation

def visualize_latent_interpolation(
    model: VAE,
    test_data: jnp.ndarray,
    n_steps: int = 10,
    save_path: str = "outputs/interpolation.png",
):
    """Visualize interpolation between two images in latent space."""
    # Get two random samples
    idx1, idx2 = np.random.choice(len(test_data), 2, replace=False)
    img1 = test_data[idx1]
    img2 = test_data[idx2]

    # Interpolate in latent space
    rngs = nnx.Rngs(sample=jax.random.key(42))
    interpolated = model.interpolate(img1, img2, steps=n_steps, rngs=rngs)

    # Plot
    fig, axes = plt.subplots(1, n_steps, figsize=(n_steps * 1.5, 2))

    for i in range(n_steps):
        axes[i].imshow(interpolated[i, :, :, 0], cmap='gray')
        axes[i].axis('off')

    plt.suptitle('Latent Space Interpolation', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved interpolation to {save_path}")
    plt.close()
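`model.interpolate` comes from Workshop's VAE class. A common way to implement latent interpolation, which we assume here without having checked Workshop's code, is to encode both images (usually keeping the posterior means), blend the two latent vectors linearly, and decode each blend. The blending step is sketched below; `interpolate_latents` is a hypothetical helper, and spherical interpolation is a common alternative for Gaussian latents.

```python
import jax.numpy as jnp


def interpolate_latents(z1, z2, n_steps):
    """Linear path from z1 to z2 (both endpoints included): (n_steps, latent_dim)."""
    t = jnp.linspace(0.0, 1.0, n_steps)[:, None]   # mixing weights, shape (n_steps, 1)
    return (1.0 - t) * z1[None, :] + t * z2[None, :]


z1 = jnp.zeros(32)
z2 = jnp.ones(32)
path = interpolate_latents(z1, z2, n_steps=10)
print(path.shape)  # (10, 32)
```

With the tutorial's modules, `z1` and `z2` could be the first rows of the means returned by `encoder(img1[None])` and `encoder(img2[None])`, and `decoder(path)` would produce the image strip.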

Part 7: Putting It All Together

Complete Training Script

def main():
    """Main training function."""
    print("=" * 60)
    print("VAE Training on MNIST")
    print("=" * 60)

    # Configuration
    LATENT_DIM = 32
    BETA = 1.0
    N_EPOCHS = 50
    BATCH_SIZE = 128
    LEARNING_RATE = 1e-3

    # Load data
    print("\n1. Loading MNIST dataset...")
    train_data, test_data = load_mnist_data()

    # Create model
    print("\n2. Creating VAE model...")
    rngs = nnx.Rngs(SEED)
    model = create_vae_model(LATENT_DIM, BETA, rngs)

    # Train model
    print("\n3. Training model...")
    history = train_vae(
        model,
        train_data,
        test_data,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
    )

    # Visualize results (plt.savefig needs the output directory to exist)
    os.makedirs("outputs", exist_ok=True)
    print("\n4. Generating visualizations...")
    plot_training_history(history)
    visualize_reconstructions(model, test_data)
    visualize_generated_samples(model)
    visualize_latent_interpolation(model, test_data)

    print("\n" + "=" * 60)
    print("Training complete! Check the 'outputs/' directory for results.")
    print("=" * 60)


if __name__ == "__main__":
    main()

Running the Training

# Run the training script
python train_vae.py

Example output (exact loss values depend on Workshop's loss scaling, your hardware, and random initialization):

============================================================
VAE Training on MNIST
============================================================

1. Loading MNIST dataset...
✓ Loaded MNIST dataset:
  Train: (60000, 28, 28, 1)
  Test: (10000, 28, 28, 1)
  Value range: [0.00, 1.00]

2. Creating VAE model...
✓ Created VAE model:
  Latent dimension: 32
  Beta parameter: 1.0

3. Training model...
============================================================
Training VAE for 50 epochs
============================================================

Epoch   1/50 | Train Loss: 0.3124 (Recon: 0.2856, KL: 0.0268) | Test Loss: 0.2891
Epoch   5/50 | Train Loss: 0.2134 (Recon: 0.1923, KL: 0.0211) | Test Loss: 0.2089
Epoch  10/50 | Train Loss: 0.1856 (Recon: 0.1678, KL: 0.0178) | Test Loss: 0.1823
...
Epoch  50/50 | Train Loss: 0.1234 (Recon: 0.1089, KL: 0.0145) | Test Loss: 0.1256

✓ Training complete!
  Best test loss: 0.1256

4. Generating visualizations...
✓ Saved training history to outputs/training_history.png
✓ Saved reconstructions to outputs/reconstructions.png
✓ Saved generated samples to outputs/generated_samples.png
✓ Saved interpolation to outputs/interpolation.png

============================================================
Training complete! Check the 'outputs/' directory for results.
============================================================

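Part 8: Common Gotchas and Tips

1. KL Divergence Collapse (Posterior Collapse)

Problem: The KL term goes to zero because the encoder's posterior matches the prior for every input, so the decoder learns to ignore the latent code.

Solutions:

  • Start with a lower β (0.1-0.5) and gradually increase it (β-annealing)
  • Use free bits: kl_loss = max(kl_loss, free_bits * latent_dim), so the model gains nothing by pushing the KL below a minimum budget
  • Monitor both reconstruction and KL losses
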
# β-annealing implementation
def get_beta(epoch, warmup_epochs=10, final_beta=1.0):
    return min(final_beta, (epoch / warmup_epochs) * final_beta)
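The bullet above clamps the total KL. In its usual formulation, free bits instead clamps each latent dimension's KL, averaged over the minibatch, so every dimension keeps a minimum budget; dimensions under the threshold contribute a constant and receive no gradient. A sketch of that per-dimension variant, assuming you compute the loss terms yourself rather than through `model.loss_fn`; `kl_with_free_bits` is a hypothetical helper.

```python
import jax.numpy as jnp


def kl_with_free_bits(mean, log_var, free_bits=0.5):
    """KL to N(0, I) where each latent dimension contributes at least free_bits nats."""
    kl_per_dim = -0.5 * (1.0 + log_var - mean**2 - jnp.exp(log_var))  # (batch, latent_dim)
    kl_per_dim = jnp.mean(kl_per_dim, axis=0)             # average over the minibatch
    return jnp.sum(jnp.maximum(kl_per_dim, free_bits))    # clamp per dim, then sum


# Combined with annealing (get_beta as defined above):
# loss = recon_loss + get_beta(epoch) * kl_with_free_bits(mean, log_var)
```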

2. Blurry Reconstructions

Problem: Reconstructed images are blurry.

Solutions:

  • Use a perceptual loss instead of MSE
  • Increase decoder capacity
  • Try different reconstruction losses (L1, combined)

# L1 loss for sharper reconstructions
def l1_reconstruction_loss(x, x_recon):
    return jnp.mean(jnp.abs(x - x_recon))
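For the "combined" option, one simple choice is a weighted mix of L1 and MSE. The sketch below is illustrative (`combined_reconstruction_loss` and `l1_weight` are hypothetical names). Because it averages over pixels while the KL term is typically summed over latent dimensions, switching reconstruction losses also shifts the effective balance with β, which you may need to retune.

```python
import jax.numpy as jnp


def combined_reconstruction_loss(x, x_recon, l1_weight=0.5):
    """Weighted mix of L1 and MSE, each averaged over all pixels."""
    l1 = jnp.mean(jnp.abs(x - x_recon))
    mse = jnp.mean((x - x_recon) ** 2)
    return l1_weight * l1 + (1.0 - l1_weight) * mse
```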

3. Training Instability

Problem: Loss spikes or training diverges.

Solutions:

  • Lower the learning rate (try 1e-4)
  • Add gradient clipping
  • Use batch normalization in the encoder/decoder

# Gradient clipping
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate)
    )
)

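4. Memory Issues

Problem: Out-of-memory errors during training.

Solutions:

  • Reduce batch size
  • Use gradient accumulation
  • Enable mixed precision training

# Gradient accumulation inside train_vae: optax.MultiSteps averages the
# gradients of `accumulation_steps` consecutive calls and applies Adam once,
# so the effective batch size stays at batch_size
accumulation_steps = 4
optimizer = nnx.Optimizer(
    model,
    optax.MultiSteps(optax.adam(learning_rate), every_k_schedule=accumulation_steps),
)

# train_step is unchanged; feed it micro-batches that fit in memory
micro_batch_size = batch_size // accumulation_steps
for batch in create_batches(train_data, micro_batch_size, shuffle=True):
    rng, step_rng = jax.random.split(rng)
    loss, loss_dict = train_step(model, optimizer, batch, step_rng)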

5. RNG Management

Problem: KeyError or unexpected randomness behavior.

Solutions:

  • Always create fresh nnx.Rngs for sampling
  • Split RNG keys properly
  • Check which keys your model expects

# Correct RNG usage
rng = jax.random.key(42)
for step in range(n_steps):
    rng, step_rng = jax.random.split(rng)
    sample_rngs = nnx.Rngs(sample=step_rng)
    outputs = model(batch, rngs=sample_rngs)
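The reason to split keys is that JAX random functions are deterministic in their key: reusing a key reproduces the same numbers, while split subkeys give independent draws. A quick demonstration:

```python
import jax

key = jax.random.key(42)

# Reusing the same key reproduces the same "random" numbers
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting gives independent subkeys, so these draws differ
key, k1, k2 = jax.random.split(key, 3)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))
```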

Next Steps

Congratulations! You've successfully built, trained, and evaluated a VAE from scratch. Here are some ways to extend your knowledge:

  • Explore Other Models: Try GANs, Diffusion Models, or Flows (Model Guides)
  • Advanced Training: Learn distributed training, mixed precision, and optimization (Training Guide)
  • Evaluation & Metrics: Measure FID, Inception Score, and other metrics (Evaluation Guide)
  • More Examples: Explore ready-to-run examples for various models (Examples)

Further Reading

Questions or Issues? Open an issue on GitHub.