
Building Your First VAE from Scratch¤

This comprehensive tutorial walks you through building, training, and evaluating a Variational Autoencoder (VAE) from scratch using Artifex. By the end, you'll have a working VAE trained on MNIST that can generate new digit images.

What You'll Learn¤

  • VAE Architecture: Understand the encoder-decoder structure and latent space representation
  • Data Loading: Load and preprocess the MNIST dataset with JAX-compatible pipelines
  • Training Loop: Implement a complete training loop with loss monitoring and checkpointing
  • Evaluation: Generate samples, reconstruct images, and visualize the latent space

Prerequisites¤

Time Estimate¤

  • Reading time: 20-30 minutes
  • Hands-on coding: 45-60 minutes
  • Training time: 10-15 minutes (CPU) or 2-3 minutes (GPU)

Part 1: Understanding VAEs¤

What is a VAE?¤

A Variational Autoencoder (VAE) is a generative model that learns to compress data into a lower-dimensional latent space and reconstruct it. Unlike standard autoencoders, VAEs learn a probabilistic latent space, enabling:

  • Generation: Sample new data from the learned distribution
  • Interpolation: Smoothly transition between data points
  • Disentanglement: Separate meaningful factors of variation

Architecture Overview¤

graph LR
    A[Input Image<br/>28×28] --> B[Encoder<br/>Neural Network]
    B --> C[Latent Mean μ]
    B --> D[Latent Log-Var log σ²]
    C --> E[Sampling<br/>z ~ N(μ, σ²)]
    D --> E
    E --> F[Decoder<br/>Neural Network]
    F --> G[Reconstructed Image<br/>28×28]

    style A fill:#e1f5ff
    style G fill:#e1f5ff
    style E fill:#fff4e1

Loss Function¤

VAEs optimize the Evidence Lower Bound (ELBO):

\[ \mathcal{L}_{\text{VAE}} = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\beta \cdot D_{\text{KL}}(q(z|x) || p(z))}_{\text{KL Divergence}} \]

Where:

  • Reconstruction term: How well the model reconstructs its input
  • KL divergence term: How close the latent distribution is to the prior (usually \(\mathcal{N}(0, I)\))
  • β: A weight balancing reconstruction quality against regularization (β = 1 recovers the standard VAE)
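
In the code later in this tutorial we minimize the negative of this objective: reconstruction error plus β times the KL term. For a diagonal Gaussian encoder and a standard normal prior, both terms have simple closed forms. Here is a minimal JAX sketch; x, x_recon, mean, and log_var stand in for the inputs, reconstructions, and encoder outputs:

import jax.numpy as jnp

def elbo_loss(x, x_recon, mean, log_var, beta=1.0):
    """Negative ELBO: reconstruction error + beta * KL(q(z|x) || N(0, I))."""
    # Reconstruction term (mean squared error, as used later in this tutorial)
    recon = jnp.mean((x - x_recon) ** 2)
    # Closed-form KL divergence between N(mean, exp(log_var)) and N(0, I)
    kl = -0.5 * jnp.mean(1 + log_var - mean**2 - jnp.exp(log_var))
    return recon + beta * kl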

Part 2: Setting Up Your Environment¤

Step 1: Create Project Directory¤

# Create project structure
mkdir vae_tutorial && cd vae_tutorial
mkdir -p {data,checkpoints,outputs,logs}

# Create main training script
touch train_vae.py

Step 2: Import Dependencies¤

Create train_vae.py and add imports:

"""Train a VAE on MNIST from scratch."""

import os
from pathlib import Path
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax import nnx
from flax.training import checkpoints
import numpy as np

# Note: This tutorial builds a VAE from scratch without using Artifex's
# pre-built VAE class. For production use, see the quickstart guide for
# how to use Artifex's VAEConfig-based approach.

# Random seed for reproducibility
SEED = 42

# Force CPU execution; remove this line (or set "gpu") if a GPU is available
jax.config.update("jax_platform_name", "cpu")

Part 3: Loading MNIST Data¤

Understanding the Dataset¤

MNIST contains 70,000 grayscale images of handwritten digits:

  • Training set: 60,000 images
  • Test set: 10,000 images
  • Image shape: 28×28 pixels
  • Values: 0-255 (we'll normalize to 0-1)

Implementation¤

def load_mnist_data() -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Load and preprocess MNIST dataset.

    Returns:
        Tuple of (train_images, test_images) as JAX arrays
    """
    try:
        # Try using TensorFlow Datasets (recommended)
        import tensorflow_datasets as tfds

        # Load data
        ds_train = tfds.load('mnist', split='train', as_supervised=True)
        ds_test = tfds.load('mnist', split='test', as_supervised=True)

        # Convert to numpy arrays
        train_images = np.array([np.array(img) for img, _ in tfds.as_numpy(ds_train)])
        test_images = np.array([np.array(img) for img, _ in tfds.as_numpy(ds_test)])

    except ImportError:
        print("TensorFlow Datasets not found. Using sklearn MNIST...")
        # Fallback to sklearn
        from sklearn.datasets import fetch_openml

        mnist = fetch_openml('mnist_784', version=1, as_frame=False)
        data = mnist.data.reshape(-1, 28, 28, 1)

        # Split train/test
        train_images = data[:60000]
        test_images = data[60000:]

    # Normalize to [0, 1]
    train_images = train_images.astype(np.float32) / 255.0
    test_images = test_images.astype(np.float32) / 255.0

    # Ensure correct shape: (N, 28, 28, 1)
    if train_images.ndim == 3:
        train_images = train_images[..., None]
        test_images = test_images[..., None]

    # Convert to JAX arrays
    train_images = jnp.array(train_images)
    test_images = jnp.array(test_images)

    print(f"✓ Loaded MNIST dataset:")
    print(f"  Train: {train_images.shape}")
    print(f"  Test: {test_images.shape}")
    print(f"  Value range: [{train_images.min():.2f}, {train_images.max():.2f}]")

    return train_images, test_images


def create_batches(data: jnp.ndarray, batch_size: int, shuffle: bool = True):
    """Create batches from data.

    Args:
        data: Input data array
        batch_size: Batch size
        shuffle: Whether to shuffle data

    Yields:
        Batches of data
    """
    n_samples = data.shape[0]
    indices = np.arange(n_samples)

    if shuffle:
        np.random.shuffle(indices)

    for start_idx in range(0, n_samples, batch_size):
        batch_indices = indices[start_idx:start_idx + batch_size]
        yield data[batch_indices]
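
A quick way to exercise the data pipeline once these helpers are defined; the batch size of 128 matches the training configuration used later in this tutorial:

# Sanity-check the data pipeline
train_images, test_images = load_mnist_data()
first_batch = next(create_batches(train_images, batch_size=128))
print(first_batch.shape)  # (128, 28, 28, 1)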

Alternative: Synthetic Data for Quick Testing¤

For rapid iteration during development:

def create_synthetic_mnist(n_train=1000, n_test=200):
    """Create synthetic MNIST-like data for quick testing."""
    key = jax.random.key(42)
    train_key, test_key = jax.random.split(key)

    train_images = jax.random.uniform(train_key, (n_train, 28, 28, 1))
    test_images = jax.random.uniform(test_key, (n_test, 28, 28, 1))

    return train_images, test_images
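
To use it, swap the synthetic loader in for the real one (for example inside main() in Part 7) and train for only a couple of epochs:

# Quick smoke test: swap synthetic data in for MNIST (e.g. inside main(), Part 7)
train_data, test_data = create_synthetic_mnist(n_train=1000, n_test=200)
# ...then train briefly, e.g. train_vae(model, train_data, test_data, n_epochs=2)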

Part 4: Building the VAE Model¤

Step 1: Define Encoder¤

The encoder maps input images to latent parameters:

class Encoder(nnx.Module):
    """Encoder network for VAE.

    Maps input images to latent distribution parameters (mean and log-variance).
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: list[int] = [256, 128],
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize encoder.

        Args:
            latent_dim: Dimension of latent space
            hidden_dims: Hidden layer dimensions
            rngs: Random number generators
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Convolutional layers for spatial feature extraction
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), strides=(2, 2), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), strides=(2, 2), rngs=rngs)
        self.conv3 = nnx.Conv(64, 128, kernel_size=(3, 3), strides=(2, 2), rngs=rngs)

        # Dense layers for distribution parameters
        # After 3 conv layers with stride 2 (SAME padding): 28 -> 14 -> 7 -> 4
        self.dense1 = nnx.Linear(128 * 4 * 4, hidden_dims[0], rngs=rngs)
        self.dense2 = nnx.Linear(hidden_dims[0], hidden_dims[1], rngs=rngs)

        # Output layers for mean and log-variance
        self.mean_layer = nnx.Linear(hidden_dims[1], latent_dim, rngs=rngs)
        self.logvar_layer = nnx.Linear(hidden_dims[1], latent_dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Encode input to latent parameters.

        Args:
            x: Input images (batch_size, 28, 28, 1)

        Returns:
            Tuple of (mean, log_var) with shape (batch_size, latent_dim)
        """
        # Convolutional feature extraction
        h = nnx.relu(self.conv1(x))
        h = nnx.relu(self.conv2(h))
        h = nnx.relu(self.conv3(h))

        # Flatten
        h = jnp.reshape(h, (h.shape[0], -1))

        # Dense layers
        h = nnx.relu(self.dense1(h))
        h = nnx.relu(self.dense2(h))

        # Output parameters
        mean = self.mean_layer(h)
        log_var = self.logvar_layer(h)

        return mean, log_var

Step 2: Define Decoder¤

The decoder reconstructs images from latent codes:

class Decoder(nnx.Module):
    """Decoder network for VAE.

    Maps latent codes back to reconstructed images.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dims: list[int] = [128, 256],
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize decoder.

        Args:
            latent_dim: Dimension of latent space
            hidden_dims: Hidden layer dimensions
            rngs: Random number generators
        """
        super().__init__()

        # Dense layers to expand latent code
        self.dense1 = nnx.Linear(latent_dim, hidden_dims[0], rngs=rngs)
        self.dense2 = nnx.Linear(hidden_dims[0], hidden_dims[1], rngs=rngs)
        self.dense3 = nnx.Linear(hidden_dims[1], 128 * 4 * 4, rngs=rngs)

        # Transposed convolutions for upsampling
        self.conv_transpose1 = nnx.ConvTranspose(
            128, 64, kernel_size=(3, 3), strides=(2, 2), rngs=rngs
        )
        self.conv_transpose2 = nnx.ConvTranspose(
            64, 32, kernel_size=(3, 3), strides=(2, 2), rngs=rngs
        )
        self.conv_transpose3 = nnx.ConvTranspose(
            32, 1, kernel_size=(3, 3), strides=(2, 2), rngs=rngs
        )

    def __call__(self, z: jax.Array) -> jax.Array:
        """Decode latent code to reconstructed image.

        Args:
            z: Latent codes (batch_size, latent_dim)

        Returns:
            Reconstructed images (batch_size, 28, 28, 1)
        """
        # Expand through dense layers
        h = nnx.relu(self.dense1(z))
        h = nnx.relu(self.dense2(h))
        h = nnx.relu(self.dense3(h))

        # Reshape to spatial
        h = jnp.reshape(h, (h.shape[0], 4, 4, 128))

        # Transposed convolutions for upsampling
        h = nnx.relu(self.conv_transpose1(h))  # 4x4 -> 8x8
        h = nnx.relu(self.conv_transpose2(h))  # 8x8 -> 16x16
        h = self.conv_transpose3(h)             # 16x16 -> 32x32

        # Crop the 32x32 output back to 28x28
        h = h[:, :28, :28, :]

        # Apply sigmoid to get values in [0, 1]
        reconstruction = nnx.sigmoid(h)

        return reconstruction
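
Before combining the two networks, it can help to confirm that the shapes line up. This is a small, optional check using only the classes defined above:

# Optional shape check for the encoder/decoder pair
rngs = nnx.Rngs(params=0)
enc = Encoder(latent_dim=32, rngs=rngs)
dec = Decoder(latent_dim=32, rngs=rngs)

dummy = jnp.zeros((8, 28, 28, 1))
mean, log_var = enc(dummy)
print(mean.shape, log_var.shape)  # (8, 32) (8, 32)
print(dec(mean).shape)            # (8, 28, 28, 1)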

Step 3: Create Complete VAE¤

Now we combine encoder and decoder with a simple VAE wrapper class:

class SimpleVAE(nnx.Module):
    """Simple VAE combining encoder and decoder.

    This is a standalone implementation for the 'from scratch' tutorial.
    For production use, consider using Artifex's built-in VAE with VAEConfig.
    """

    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        latent_dim: int,
        beta: float = 1.0,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.beta = beta
        self.rngs = rngs

    def reparameterize(self, mean: jax.Array, log_var: jax.Array) -> jax.Array:
        """Apply the reparameterization trick."""
        sample_key = self.rngs.sample()
        std = jnp.exp(0.5 * log_var)
        eps = jax.random.normal(sample_key, mean.shape)
        return mean + eps * std

    def __call__(self, x: jax.Array) -> dict[str, jax.Array]:
        """Forward pass through the VAE."""
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        reconstructed = self.decoder(z)
        return {
            "reconstructed": reconstructed,
            "mean": mean,
            "log_var": log_var,
            "z": z,
        }

    def loss_fn(self, x: jax.Array, outputs: dict[str, jax.Array]) -> dict[str, jax.Array]:
        """Compute the VAE training loss (negative ELBO: reconstruction + β·KL)."""
        reconstructed = outputs["reconstructed"]
        mean = outputs["mean"]
        log_var = outputs["log_var"]

        # Reconstruction loss (mean squared error)
        recon_loss = jnp.mean((x - reconstructed) ** 2)

        # KL divergence to N(0, I): -0.5 * mean(1 + log σ² - μ² - σ²)
        kl_loss = -0.5 * jnp.mean(1 + log_var - mean ** 2 - jnp.exp(log_var))

        total_loss = recon_loss + self.beta * kl_loss

        return {
            "loss": total_loss,
            "reconstruction_loss": recon_loss,
            "kl_loss": kl_loss,
        }

    def sample(self, n_samples: int) -> jax.Array:
        """Generate samples from random latent codes."""
        sample_key = self.rngs.sample()
        z = jax.random.normal(sample_key, (n_samples, self.latent_dim))
        return self.decoder(z)

    def reconstruct(self, x: jax.Array, deterministic: bool = False) -> jax.Array:
        """Reconstruct input images."""
        mean, log_var = self.encoder(x)
        if deterministic:
            z = mean  # Use mean for deterministic reconstruction
        else:
            z = self.reparameterize(mean, log_var)
        return self.decoder(z)

    def interpolate(self, x1: jax.Array, x2: jax.Array, steps: int = 10) -> jax.Array:
        """Interpolate between two images in latent space."""
        # Encode both images
        mean1, _ = self.encoder(x1)
        mean2, _ = self.encoder(x2)

        # Create interpolation weights
        alphas = jnp.linspace(0, 1, steps)

        # Interpolate in latent space
        z_interp = jnp.array([
            (1 - alpha) * mean1[0] + alpha * mean2[0]
            for alpha in alphas
        ])

        # Decode interpolated latents
        return self.decoder(z_interp)


def create_vae_model(
    latent_dim: int = 32,
    beta: float = 1.0,
    rngs: nnx.Rngs | None = None,
) -> SimpleVAE:
    """Create VAE model with encoder and decoder.

    Args:
        latent_dim: Dimension of latent space
        beta: Beta parameter for β-VAE (1.0 = standard VAE)
        rngs: Random number generators

    Returns:
        SimpleVAE model instance
    """
    if rngs is None:
        rngs = nnx.Rngs(params=0, dropout=1, sample=2)

    # Create encoder and decoder
    encoder = Encoder(latent_dim, hidden_dims=[256, 128], rngs=rngs)
    decoder = Decoder(latent_dim, hidden_dims=[128, 256], rngs=rngs)

    # Create VAE with our custom wrapper
    vae = SimpleVAE(encoder, decoder, latent_dim, beta=beta, rngs=rngs)

    print(f"✓ Created VAE model:")
    print(f"  Latent dimension: {latent_dim}")
    print(f"  Beta parameter: {beta}")

    return vae
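
Before training, an optional smoke test confirms the model runs end to end; it uses only the pieces defined above:

# Optional smoke test: run a dummy batch through the VAE and inspect the losses
vae = create_vae_model(latent_dim=32)
dummy = jnp.zeros((8, 28, 28, 1))

outputs = vae(dummy)
losses = vae.loss_fn(x=dummy, outputs=outputs)
print({k: float(v) for k, v in losses.items()})
print(vae.sample(4).shape)  # (4, 28, 28, 1)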

Part 5: Training the VAE¤

Step 1: Define Training Step¤

@nnx.jit  # JIT-compile the training step; typically a large speedup over eager execution
def train_step(
    model: SimpleVAE,
    optimizer: nnx.Optimizer,
    batch: jax.Array,
) -> Tuple[float, Dict[str, float]]:
    """Single training step (JIT-compiled for performance).

    Args:
        model: VAE model
        optimizer: NNX optimizer
        batch: Batch of images

    Returns:
        Tuple of (total_loss, loss_dict)
    """
    def loss_fn(model):
        # Forward pass - model uses internal RNGs
        outputs = model(batch)

        # Compute losses
        loss_dict = model.loss_fn(x=batch, outputs=outputs)
        return loss_dict['loss'], loss_dict

    # Compute gradients
    (loss, loss_dict), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

    # Update parameters (NNX 0.11.0+ API)
    optimizer.update(model, grads)

    return loss, loss_dict

Step 2: Evaluation Function¤

def evaluate_model(
    model: SimpleVAE,
    test_data: jnp.ndarray,
    batch_size: int = 128,
) -> Dict[str, float]:
    """Evaluate model on test data.

    Args:
        model: VAE model
        test_data: Test dataset
        batch_size: Batch size for evaluation

    Returns:
        Dictionary of average metrics
    """
    total_losses = {'loss': 0.0, 'reconstruction_loss': 0.0, 'kl_loss': 0.0}
    n_batches = 0

    for batch in create_batches(test_data, batch_size, shuffle=False):
        # Forward pass - model uses internal RNGs
        outputs = model(batch)

        # Compute losses
        loss_dict = model.loss_fn(x=batch, outputs=outputs)

        # Accumulate
        for key in total_losses:
            if key in loss_dict:
                total_losses[key] += float(loss_dict[key])
        n_batches += 1

    # Average losses
    avg_losses = {k: v / n_batches for k, v in total_losses.items()}

    return avg_losses

Step 3: Training Loop¤

def train_vae(
    model: SimpleVAE,
    train_data: jnp.ndarray,
    test_data: jnp.ndarray,
    n_epochs: int = 50,
    batch_size: int = 128,
    learning_rate: float = 1e-3,
    checkpoint_dir: str = "checkpoints",
):
    """Train VAE model.

    Args:
        model: VAE model
        train_data: Training dataset
        test_data: Test dataset
        n_epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        checkpoint_dir: Directory to save checkpoints
    """
    # Create optimizer (wrt=nnx.Param required in Flax NNX 0.11.0+)
    optimizer = nnx.Optimizer(model, optax.adam(learning_rate), wrt=nnx.Param)

    # Training state
    history = {
        'train_loss': [],
        'train_recon': [],
        'train_kl': [],
        'test_loss': [],
        'test_recon': [],
        'test_kl': [],
    }

    print(f"\n{'='*60}")
    print(f"Training VAE for {n_epochs} epochs")
    print(f"{'='*60}\n")

    best_test_loss = float('inf')

    for epoch in range(n_epochs):
        # Training
        epoch_losses = {'loss': 0.0, 'reconstruction_loss': 0.0, 'kl_loss': 0.0}
        n_batches = 0

        for batch in create_batches(train_data, batch_size, shuffle=True):
            # Training step (JIT-compiled for speed)
            loss, loss_dict = train_step(model, optimizer, batch)

            # Accumulate losses
            for key in epoch_losses:
                if key in loss_dict:
                    epoch_losses[key] += float(loss_dict[key])
            n_batches += 1

        # Average training losses
        avg_train_losses = {k: v / n_batches for k, v in epoch_losses.items()}

        # Evaluation
        avg_test_losses = evaluate_model(model, test_data, batch_size)

        # Store history
        history['train_loss'].append(float(avg_train_losses['loss']))
        history['train_recon'].append(float(avg_train_losses['reconstruction_loss']))
        history['train_kl'].append(float(avg_train_losses['kl_loss']))
        history['test_loss'].append(float(avg_test_losses['loss']))
        history['test_recon'].append(float(avg_test_losses['reconstruction_loss']))
        history['test_kl'].append(float(avg_test_losses['kl_loss']))

        # Print progress
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{n_epochs} | "
                  f"Train Loss: {avg_train_losses['loss']:.4f} "
                  f"(Recon: {avg_train_losses['reconstruction_loss']:.4f}, "
                  f"KL: {avg_train_losses['kl_loss']:.4f}) | "
                  f"Test Loss: {avg_test_losses['loss']:.4f}")

        # Save best model
        if avg_test_losses['loss'] < best_test_loss:
            best_test_loss = avg_test_losses['loss']
            save_checkpoint(model, checkpoint_dir, epoch, is_best=True)

        # Save regular checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_checkpoint(model, checkpoint_dir, epoch)

    print(f"\n✓ Training complete!")
    print(f"  Best test loss: {best_test_loss:.4f}")

    return history


def save_checkpoint(model: SimpleVAE, checkpoint_dir: str, epoch: int, is_best: bool = False):
    """Save model checkpoint."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Get model state (parameters and other variables)
    state = nnx.state(model)

    # `prefix` is a filename prefix, not a full path: save_checkpoint joins it
    # with the checkpoint directory and appends the step number itself.
    prefix = "best_model_" if is_best else "checkpoint_epoch_"

    # Use an absolute path for compatibility with the Orbax-backed checkpointing
    # in newer Flax versions; overwrite=True lets you rerun training in the same directory.
    checkpoints.save_checkpoint(
        os.path.abspath(checkpoint_dir), state, epoch, prefix=prefix, overwrite=True
    )
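
To load a saved checkpoint back into a model, a sketch like the following should work with the same flax.training.checkpoints API used above (the helper name load_checkpoint is ours; recreate a model with the same architecture first, then merge the restored state into it):

def load_checkpoint(model: SimpleVAE, checkpoint_dir: str, prefix: str = "checkpoint_epoch_") -> SimpleVAE:
    """Restore the latest checkpoint with the given prefix into `model` in place."""
    state = nnx.state(model)
    restored = checkpoints.restore_checkpoint(
        os.path.abspath(checkpoint_dir), target=state, prefix=prefix
    )
    nnx.update(model, restored)
    return model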

Part 6: Visualization and Analysis¤

Visualize Training Progress¤

def plot_training_history(history: Dict[str, list], save_path: str = "outputs/training_history.png"):
    """Plot training and validation losses."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    epochs = range(1, len(history['train_loss']) + 1)

    # Total loss
    axes[0].plot(epochs, history['train_loss'], label='Train', linewidth=2)
    axes[0].plot(epochs, history['test_loss'], label='Test', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Total Loss')
    axes[0].set_title('Total Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Reconstruction loss
    axes[1].plot(epochs, history['train_recon'], label='Train', linewidth=2)
    axes[1].plot(epochs, history['test_recon'], label='Test', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Reconstruction Loss')
    axes[1].set_title('Reconstruction Loss')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # KL divergence
    axes[2].plot(epochs, history['train_kl'], label='Train', linewidth=2)
    axes[2].plot(epochs, history['test_kl'], label='Test', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('KL Divergence')
    axes[2].set_title('KL Divergence')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved training history to {save_path}")
    plt.close()

Visualize Reconstructions¤

def visualize_reconstructions(
    model: SimpleVAE,
    test_data: jnp.ndarray,
    n_samples: int = 10,
    save_path: str = "outputs/reconstructions.png",
):
    """Visualize original vs reconstructed images."""
    # Get random samples
    indices = np.random.choice(len(test_data), n_samples, replace=False)
    samples = test_data[indices]

    # Reconstruct (model uses internal RNGs)
    reconstructed = model.reconstruct(samples, deterministic=True)

    # Plot
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 1.5, 3))

    for i in range(n_samples):
        # Original
        axes[0, i].imshow(samples[i, :, :, 0], cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=10)

        # Reconstructed
        axes[1, i].imshow(reconstructed[i, :, :, 0], cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=10)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved reconstructions to {save_path}")
    plt.close()

Generate New Samples¤

def visualize_generated_samples(
    model: SimpleVAE,
    n_samples: int = 20,
    save_path: str = "outputs/generated_samples.png",
):
    """Visualize generated samples from random latent codes."""
    # Generate samples (model uses internal RNGs)
    samples = model.sample(n_samples)

    # Plot
    rows = 4
    cols = n_samples // rows
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))

    for idx, ax in enumerate(axes.flat):
        if idx < n_samples:
            ax.imshow(samples[idx, :, :, 0], cmap='gray')
            ax.axis('off')

    plt.suptitle('Generated Samples from VAE', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved generated samples to {save_path}")
    plt.close()

Visualize Latent Space Interpolation¤

def visualize_latent_interpolation(
    model: SimpleVAE,
    test_data: jnp.ndarray,
    n_steps: int = 10,
    save_path: str = "outputs/interpolation.png",
):
    """Visualize interpolation between two images in latent space."""
    # Get two random samples
    idx1, idx2 = np.random.choice(len(test_data), 2, replace=False)
    img1 = test_data[idx1:idx1+1]  # Keep batch dimension
    img2 = test_data[idx2:idx2+1]

    # Interpolate in latent space (model uses internal RNGs)
    interpolated = model.interpolate(img1, img2, steps=n_steps)

    # Plot
    fig, axes = plt.subplots(1, n_steps, figsize=(n_steps * 1.5, 2))

    for i in range(n_steps):
        axes[i].imshow(interpolated[i, :, :, 0], cmap='gray')
        axes[i].axis('off')

    plt.suptitle('Latent Space Interpolation', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"✓ Saved interpolation to {save_path}")
    plt.close()

Part 7: Putting It All Together¤

Complete Training Script¤

def main():
    """Main training function."""
    print("=" * 60)
    print("VAE Training on MNIST")
    print("=" * 60)

    # Configuration
    LATENT_DIM = 32
    BETA = 1.0
    N_EPOCHS = 50
    BATCH_SIZE = 128
    LEARNING_RATE = 1e-3

    # Load data
    print("\n1. Loading MNIST dataset...")
    train_data, test_data = load_mnist_data()

    # Create model
    print("\n2. Creating VAE model...")
    rngs = nnx.Rngs(SEED)
    model = create_vae_model(LATENT_DIM, BETA, rngs)

    # Train model
    print("\n3. Training model...")
    history = train_vae(
        model,
        train_data,
        test_data,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
    )

    # Visualize results
    print("\n4. Generating visualizations...")
    plot_training_history(history)
    visualize_reconstructions(model, test_data)
    visualize_generated_samples(model)
    visualize_latent_interpolation(model, test_data)

    print("\n" + "=" * 60)
    print("Training complete! Check the 'outputs/' directory for results.")
    print("=" * 60)


if __name__ == "__main__":
    main()

Running the Training¤

# Run the training script
python train_vae.py

Expected output:

============================================================
VAE Training on MNIST
============================================================

1. Loading MNIST dataset...
✓ Loaded MNIST dataset:
  Train: (60000, 28, 28, 1)
  Test: (10000, 28, 28, 1)
  Value range: [0.00, 1.00]

2. Creating VAE model...
✓ Created VAE model:
  Latent dimension: 32
  Beta parameter: 1.0

3. Training model...
============================================================
Training VAE for 50 epochs
============================================================

Epoch   1/50 | Train Loss: 0.3124 (Recon: 0.2856, KL: 0.0268) | Test Loss: 0.2891
Epoch   5/50 | Train Loss: 0.2134 (Recon: 0.1923, KL: 0.0211) | Test Loss: 0.2089
Epoch  10/50 | Train Loss: 0.1856 (Recon: 0.1678, KL: 0.0178) | Test Loss: 0.1823
...
Epoch  50/50 | Train Loss: 0.1234 (Recon: 0.1089, KL: 0.0145) | Test Loss: 0.1256

✓ Training complete!
  Best test loss: 0.1256

4. Generating visualizations...
✓ Saved training history to outputs/training_history.png
✓ Saved reconstructions to outputs/reconstructions.png
✓ Saved generated samples to outputs/generated_samples.png
✓ Saved interpolation to outputs/interpolation.png

============================================================
Training complete! Check the 'outputs/' directory for results.
============================================================

Part 8: Common Gotchas and Tips¤

1. KL Divergence Collapse¤

Problem: KL divergence goes to zero, model ignores latent space.

Solutions:

  • Start with lower β (0.1-0.5) and gradually increase (β-annealing)
  • Use free bits: kl_loss = jnp.maximum(kl_loss, free_bits * latent_dim)
  • Monitor both reconstruction and KL losses

# β-annealing implementation
def get_beta(epoch, warmup_epochs=10, final_beta=1.0):
    return min(final_beta, (epoch / warmup_epochs) * final_beta)
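
One possible way to wire the schedule into train_vae's epoch loop; this is a sketch, assuming beta stays a plain attribute on SimpleVAE as defined above:

# Inside the epoch loop of train_vae (sketch): anneal beta before each epoch.
# beta is a plain Python attribute, so under @nnx.jit each new value triggers
# a retrace of train_step; with ~10 warmup epochs this cost is negligible.
for epoch in range(n_epochs):
    model.beta = get_beta(epoch, warmup_epochs=10, final_beta=1.0)
    for batch in create_batches(train_data, batch_size, shuffle=True):
        loss, loss_dict = train_step(model, optimizer, batch)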

2. Blurry Reconstructions¤

Problem: Reconstructed images are blurry.

Solutions:

  • Use perceptual loss instead of MSE
  • Increase decoder capacity
  • Try different reconstruction losses (L1, combined)

# L1 loss for sharper reconstructions
def l1_reconstruction_loss(x, x_recon):
    return jnp.mean(jnp.abs(x - x_recon))
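
A blended reconstruction term is just as easy to try; the 50/50 weighting below is an arbitrary starting point, not a tuned value:

# Combined L1 + MSE reconstruction loss
def combined_reconstruction_loss(x, x_recon, l1_weight=0.5):
    mse = jnp.mean((x - x_recon) ** 2)
    l1 = jnp.mean(jnp.abs(x - x_recon))
    return (1.0 - l1_weight) * mse + l1_weight * l1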

3. Training Instability¤

Problem: Loss spikes or training diverges.

Solutions:

  • Lower learning rate (try 1e-4)
  • Add gradient clipping
  • Use batch normalization in encoder/decoder

# Gradient clipping (wrt=nnx.Param required in NNX 0.11.0+)
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate)
    ),
    wrt=nnx.Param
)

4. Memory Issues¤

Problem: Out of memory errors during training.

Solutions:

  • Reduce batch size
  • Use gradient accumulation
  • Enable mixed precision training

# Gradient accumulation: average gradients over several mini-batches before updating
accumulation_steps = 4
accumulated_grads = None

for i, batch in enumerate(create_batches(train_data, batch_size)):
    def loss_fn(m):
        return m.loss_fn(x=batch, outputs=m(batch))["loss"]

    grads = nnx.grad(loss_fn)(model)
    accumulated_grads = grads if accumulated_grads is None else jax.tree_util.tree_map(jnp.add, accumulated_grads, grads)

    if (i + 1) % accumulation_steps == 0:
        # Average the accumulated gradients and apply a single update
        optimizer.update(model, jax.tree_util.tree_map(lambda g: g / accumulation_steps, accumulated_grads))
        accumulated_grads = None

5. RNG Management¤

Problem: KeyError or unexpected randomness behavior.

Solutions:

  • Always create fresh nnx.Rngs for sampling
  • Split RNG keys properly
  • Check which keys your model expects

# Correct RNG usage with Flax NNX: models manage their own RNG streams internally
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
model = create_vae_model(latent_dim=32, rngs=rngs)

# During training/inference, just call the model; no key passing required
for batch in create_batches(train_data, batch_size):
    outputs = model(batch)  # Model draws from its internal 'sample' stream
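
If you need raw JAX keys outside an nnx module (for example, to sample latent codes by hand), split them explicitly rather than reusing a single key:

# Explicit key splitting for code that lives outside nnx modules
key = jax.random.key(SEED)
key, sample_key = jax.random.split(key)
z = jax.random.normal(sample_key, (16, 32))  # 16 latent codes of dimension 32
samples = model.decoder(z)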

Next Steps¤

Congratulations! You've successfully built, trained, and evaluated a VAE from scratch. Here are some ways to extend your knowledge:

  • Explore Other Models: Try GANs, Diffusion Models, or Flows (see the Model Guides)
  • Advanced Training: Learn distributed training, mixed precision, and optimization (see the Training Guide)
  • Evaluation & Metrics: Measure FID, Inception Score, and other metrics (see the Evaluation Guide)
  • More Examples: Explore ready-to-run examples for various models (see the Examples)

Further Reading¤


Questions or Issues? Open an issue on GitHub.