Training a Diffusion Model on MNIST¤

Level: Beginner | Runtime: ~30-60 minutes (GPU), ~2-3 hours (CPU) | Format: Python + Jupyter

This tutorial provides a complete, end-to-end example of training a DDPM (Denoising Diffusion Probabilistic Model) on the MNIST dataset. By the end, you'll have trained a diffusion model from scratch that generates recognizable handwritten digits.

Files¤

Dual-Format Implementation

This example is available in two synchronized formats:

  • Python Script (.py) - For version control, IDE development, and CI/CD integration
  • Jupyter Notebook (.ipynb) - For interactive learning, experimentation, and exploration

Both formats contain identical content and can be used interchangeably. Choose the format that best suits your workflow.

Quick Start¤

# Activate Workshop environment
source activate.sh

# Run the Python script (recommended for first run)
python examples/generative_models/image/diffusion/diffusion_mnist_training.py

# Or launch Jupyter notebook for interactive exploration
jupyter lab examples/generative_models/image/diffusion/diffusion_mnist_training.ipynb

Overview¤

Learning Objectives:

  • Load and preprocess MNIST dataset for diffusion training
  • Configure and create a DDPM model using Workshop APIs
  • Implement a complete training loop with monitoring
  • Generate samples using DDPM (1000 steps) and DDIM (50 steps)
  • Compare sampling speed and quality tradeoffs
  • Save and load model checkpoints
  • Visualize training progress and sample quality

Prerequisites:

  • Basic understanding of neural networks and diffusion models
  • Familiarity with JAX and Flax NNX basics
  • Understanding of denoising diffusion probabilistic models (DDPM)
  • Workshop installed with CUDA support (recommended)

Estimated Time: 45-60 minutes (including training time)

What's Covered¤

  • Data Pipeline: Loading MNIST, preprocessing to the [-1, 1] range, creating data loaders
  • Model Configuration: Setting up DDPM with 1000 timesteps and a linear beta schedule
  • Training Loop: Complete training with optimizer, learning rate schedule, and monitoring
  • Sample Generation: DDPM (1000 steps) vs DDIM (50 steps, 20x faster)
  • Visualization: Training curves, progressive denoising, sample quality
  • Model Persistence: Saving and loading trained model checkpoints

Expected Results:

  • Training time: ~30-60 minutes on GPU (RTX 4090), ~2-3 hours on CPU
  • Final loss: ~0.03-0.05 (2 epochs)
  • Generated samples: Recognizable handwritten digits
  • DDIM speedup: ~20x faster than DDPM

Prerequisites¤

Installation¤

# Install Workshop with CUDA support (recommended)
uv sync --extra cuda-dev

# Or CPU-only
uv sync

Setup and Imports¤

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax import nnx
from tqdm import tqdm
import numpy as np

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.diffusion import DDPMModel, DDIMModel
from workshop.generative_models.core.device_manager import DeviceManager

# Set up device
device_manager = DeviceManager()
device = device_manager.get_device()
print(f"Using device: {device}")

# Initialize RNGs
seed = 42
rngs = nnx.Rngs(seed, params=seed+1, noise=seed+2, sample=seed+3, timestep=seed+4)

Data Loading and Preprocessing¤

Load MNIST Dataset¤

def load_mnist_data():
    """Load MNIST dataset.

    Returns:
        train_images: Training images (60000, 28, 28, 1)
        test_images: Test images (10000, 28, 28, 1)
    """
    # Download MNIST using torchvision or tensorflow
    try:
        # Using torchvision
        from torchvision import datasets

        train_dataset = datasets.MNIST(
            root="./data",
            train=True,
            download=True
        )
        test_dataset = datasets.MNIST(
            root="./data",
            train=False,
            download=True
        )

        # Convert to numpy arrays
        train_images = train_dataset.data.numpy()
        test_images = test_dataset.data.numpy()

    except ImportError:
        # Using tensorflow
        import tensorflow as tf

        (train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

    # Add channel dimension
    train_images = train_images[..., np.newaxis]
    test_images = test_images[..., np.newaxis]

    print(f"Train images shape: {train_images.shape}")
    print(f"Test images shape: {test_images.shape}")

    return train_images, test_images

# Load data
train_images, test_images = load_mnist_data()

Preprocess Data¤

def preprocess_mnist(images):
    """Preprocess MNIST images.

    Normalizes to [-1, 1] range as expected by diffusion models.

    Args:
        images: Images in [0, 255] range

    Returns:
        Preprocessed images in [-1, 1] range
    """
    # Convert to float32
    images = images.astype(np.float32)

    # Normalize to [-1, 1]
    images = (images / 127.5) - 1.0

    return images

# Preprocess
train_images = preprocess_mnist(train_images)
test_images = preprocess_mnist(test_images)

print(f"Data range: [{train_images.min():.2f}, {train_images.max():.2f}]")

Create DataLoader¤

class NumpyDataLoader:
    """Simple DataLoader for numpy arrays."""

    def __init__(self, data, batch_size, shuffle=True):
        """Initialize DataLoader.

        Args:
            data: Numpy array of data
            batch_size: Batch size
            shuffle: Whether to shuffle data
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_samples = len(data)
        self.n_batches = (self.n_samples + batch_size - 1) // batch_size

    def __iter__(self):
        """Iterate over batches."""
        indices = np.arange(self.n_samples)

        if self.shuffle:
            np.random.shuffle(indices)

        for i in range(self.n_batches):
            batch_indices = indices[i * self.batch_size: (i + 1) * self.batch_size]
            batch = self.data[batch_indices]
            yield jnp.array(batch)

    def __len__(self):
        """Number of batches."""
        return self.n_batches

# Create dataloader
batch_size = 128
train_loader = NumpyDataLoader(train_images, batch_size, shuffle=True)

print(f"Number of batches: {len(train_loader)}")

Visualize Data¤

def visualize_samples(images, title="Samples", n_cols=8):
    """Visualize a grid of images.

    Args:
        images: Images to visualize (N, H, W, C)
        title: Plot title
        n_cols: Number of columns in grid
    """
    n_images = len(images)
    n_rows = (n_images + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2))
    axes = axes.flatten()

    for i, (ax, img) in enumerate(zip(axes, images)):
        # Denormalize from [-1, 1] to [0, 1]
        img = (img + 1.0) / 2.0
        img = np.clip(img, 0, 1)

        # Display
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")

    # Hide unused subplots
    for i in range(n_images, len(axes)):
        axes[i].axis("off")

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Visualize some training samples
sample_batch = next(iter(train_loader))
visualize_samples(sample_batch[:16], title="Training Samples")

Model Creation¤

Configure the Model¤

# DDPM configuration
config = ModelConfiguration(
    name="ddpm_mnist",
    model_class="DDPMModel",
    input_dim=(28, 28, 1),  # MNIST dimensions
    parameters={
        "noise_steps": 1000,        # Number of diffusion timesteps
        "beta_start": 1e-4,          # Initial noise level
        "beta_end": 0.02,            # Final noise level
        "beta_schedule": "linear",   # Linear noise schedule
    }
)

print(f"Model configuration:")
print(f"  Name: {config.name}")
print(f"  Input dimension: {config.input_dim}")
print(f"  Noise steps: {config.parameters['noise_steps']}")

Create the Model¤

# Create DDPM model
model = DDPMModel(config, rngs=rngs)

print(f"Model created successfully!")
print(f"Model type: {type(model).__name__}")

# Test forward pass
test_x = jax.random.normal(rngs.sample(), (4, 28, 28, 1))
test_t = jnp.array([100, 200, 300, 400])
test_outputs = model(test_x, test_t, rngs=rngs)

print(f"Test forward pass:")
print(f"  Input shape: {test_x.shape}")
print(f"  Output shape: {test_outputs['predicted_noise'].shape}")

Training Setup¤

Create Optimizer¤

# Learning rate schedule with warmup
warmup_steps = 1000
total_steps = 50000

schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-6,
    peak_value=1e-4,
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
    end_value=1e-5
)

# Optimizer with gradient clipping
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients
        optax.adam(schedule)
    )
)

print(f"Optimizer created with warmup schedule")

Training Step¤

@nnx.jit
def train_step(model, optimizer, batch, rngs):
    """Single training step.

    Args:
        model: Diffusion model
        optimizer: Optimizer
        batch: Batch of images
        rngs: Random number generators

    Returns:
        Dictionary of metrics
    """
    def loss_fn(model):
        # Sample random timesteps
        batch_size = batch.shape[0]
        t = jax.random.randint(
            model.rngs.timestep(),
            (batch_size,),
            0,
            config.parameters["noise_steps"]
        )

        # Add noise to images (forward_diffusion returns the noise it used)
        noisy_images, noise = model.forward_diffusion(batch, t)

        # Predict noise
        outputs = model(noisy_images, t)
        predicted_noise = outputs["predicted_noise"]

        # Compute MSE loss (compare to the ACTUAL noise that was used)
        loss = jnp.mean((predicted_noise - noise) ** 2)

        return loss

    # Compute loss and gradients
    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Update parameters (NEW API: requires model as first argument)
    optimizer.update(model, grads)

    return {"loss": loss}

Training Loop¤

Train the Model¤

# Training configuration
num_epochs = 10
log_interval = 100
sample_interval = 400  # keep below ~469 (batches per epoch at batch_size=128) so in-training sampling actually triggers

# Training history
history = {
    "loss": [],
    "epoch_loss": []
}

print("Starting training...")

for epoch in range(num_epochs):
    epoch_losses = []

    # Progress bar for this epoch
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for step, batch in enumerate(pbar):
        # Training step
        metrics = train_step(model, optimizer, batch, rngs)

        # Record loss
        loss = float(metrics["loss"])
        epoch_losses.append(loss)
        history["loss"].append(loss)

        # Update progress bar
        pbar.set_postfix({"loss": f"{loss:.4f}"})

        # Log
        if step % log_interval == 0:
            avg_loss = np.mean(epoch_losses[-log_interval:])
            print(f"  Step {step}/{len(train_loader)}, Loss: {avg_loss:.4f}")

        # Generate samples during training
        if step % sample_interval == 0 and step > 0:
            print(f"  Generating samples at step {step}...")
            # Use DDIM for faster sampling during training
            samples = model.sample(n_samples_or_shape=16, scheduler="ddim", steps=50)
            visualize_samples(samples, title=f"Epoch {epoch+1}, Step {step}")

    # Epoch summary
    avg_epoch_loss = np.mean(epoch_losses)
    history["epoch_loss"].append(avg_epoch_loss)
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Average Loss: {avg_epoch_loss:.4f}")
    print()

print("Training complete!")

Plot Training Curve¤

def plot_training_curve(history):
    """Plot training loss curve."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Plot step loss
    ax1.plot(history["loss"], alpha=0.3)
    ax1.plot(
        np.convolve(history["loss"], np.ones(100)/100, mode="valid"),
        label="Smoothed"
    )
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Loss")
    ax1.set_title("Training Loss (per step)")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot epoch loss
    ax2.plot(history["epoch_loss"], marker="o")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Average Loss")
    ax2.set_title("Training Loss (per epoch)")
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot training curve
plot_training_curve(history)

Sampling and Generation¤

Basic Sampling (DDPM)¤

print("Generating samples with DDPM (1000 steps)...")

# Generate samples
n_samples = 16
samples_ddpm = model.sample(n_samples_or_shape=n_samples, scheduler="ddpm")

print(f"Generated {n_samples} samples with shape {samples_ddpm.shape}")

# Visualize
visualize_samples(samples_ddpm, title="DDPM Samples (1000 steps)")

Fast Sampling (DDIM)¤

print("Generating samples with DDIM (50 steps)...")

# Generate with DDIM sampling (much faster!)
samples_ddim = model.sample(
    n_samples_or_shape=n_samples,
    scheduler="ddim",
    steps=50,  # Only 50 steps instead of 1000!
    rngs=rngs
)

print(f"Generated {n_samples} samples in 50 steps")

# Visualize
visualize_samples(samples_ddim, title="DDIM Samples (50 steps)")
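DDIM gets its speed by denoising along a short subsequence of timesteps with a deterministic update (eta = 0): at each step it estimates the clean image implied by the noise prediction and jumps directly to the next noise level. A minimal sketch of a single update, assuming the same alphas_cumprod as in the schedule illustration (the Workshop scheduler handles this internally):

# Sketch only: one deterministic DDIM step from timestep t to an earlier timestep t_prev
def ddim_step_sketch(x_t, predicted_noise, t, t_prev, alphas_cumprod):
    a_t = alphas_cumprod[t]
    a_prev = alphas_cumprod[t_prev] if t_prev >= 0 else 1.0

    # Clean-image estimate implied by the predicted noise
    x0_pred = (x_t - jnp.sqrt(1.0 - a_t) * predicted_noise) / jnp.sqrt(a_t)

    # Jump straight to the noise level of t_prev, reusing the same noise direction
    return jnp.sqrt(a_prev) * x0_pred + jnp.sqrt(1.0 - a_prev) * predicted_noise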

Compare Sampling Speeds¤

import time

def time_sampling(model, method="ddpm", steps=None, n_trials=3):
    """Time the sampling process."""
    times = []

    for _ in range(n_trials):
        start = time.time()

        if method == "ddpm":
            out = model.sample(n_samples_or_shape=16, scheduler="ddpm")
        elif method == "ddim":
            out = model.sample(n_samples_or_shape=16, scheduler="ddim", steps=steps)

        # JAX dispatches work asynchronously; wait for the result before stopping the timer
        jax.block_until_ready(out)
        elapsed = time.time() - start
        times.append(elapsed)

    return np.mean(times), np.std(times)

# Time DDPM
ddpm_time, ddpm_std = time_sampling(model, "ddpm")
print(f"DDPM (1000 steps): {ddpm_time:.2f}s ± {ddpm_std:.2f}s")

# Time DDIM
ddim_time, ddim_std = time_sampling(model, "ddim", steps=50)
print(f"DDIM (50 steps): {ddim_time:.2f}s ± {ddim_std:.2f}s")

# Speedup
speedup = ddpm_time / ddim_time
print(f"DDIM is {speedup:.1f}x faster!")

Progressive Sampling (Visualize Denoising)¤

def progressive_sampling(model, n_samples=4, save_every=100):
    """Visualize the progressive denoising process.

    Args:
        model: Diffusion model
        n_samples: Number of samples
        save_every: Save every N steps

    Returns:
        Trajectory of denoising process
    """
    trajectory = []
    shape = model._get_sample_shape()

    # Start from noise
    x = jax.random.normal(rngs.sample(), (n_samples, *shape))
    trajectory.append(x.copy())

    # Denoise step by step
    for t in tqdm(range(model.noise_steps - 1, -1, -1), desc="Denoising"):
        t_batch = jnp.full((n_samples,), t, dtype=jnp.int32)

        # Model prediction
        outputs = model(x, t_batch, rngs=rngs)
        predicted_noise = outputs["predicted_noise"]

        # Denoising step
        x = model.p_sample(predicted_noise, x, t_batch, rngs=rngs)

        # Save
        if t % save_every == 0 or t == 0:
            trajectory.append(x.copy())

    return trajectory

# Generate progressive samples
print("Generating progressive samples...")
trajectory = progressive_sampling(model, n_samples=4, save_every=100)

# Visualize progression
n_steps = len(trajectory)
fig, axes = plt.subplots(4, n_steps, figsize=(n_steps * 2, 8))

for sample_idx in range(4):
    for step_idx, snapshot in enumerate(trajectory):
        ax = axes[sample_idx, step_idx]

        # Get image for this sample
        img = snapshot[sample_idx]

        # Denormalize
        img = (img + 1.0) / 2.0
        img = np.clip(img, 0, 1)

        # Display
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")

        if sample_idx == 0:
            step = (n_steps - step_idx - 1) * 100
            ax.set_title(f"t={step}", fontsize=10)

plt.suptitle("Progressive Denoising", fontsize=14)
plt.tight_layout()
plt.show()

Evaluation¤

Compute FID Score¤

def compute_inception_features(images):
    """Compute InceptionV3 features for FID.

    Note: This requires a pre-trained InceptionV3 model.
    For this tutorial, we'll use a simplified metric.
    """
    # This would use a pre-trained InceptionV3
    # For now, we'll use simple statistics
    return images.reshape(len(images), -1)

def compute_fid_simplified(real_images, fake_images):
    """Simplified FID computation.

    Uses pixel statistics instead of Inception features.
    For demonstration purposes only.
    """
    from scipy.linalg import sqrtm

    # Flatten images and compute mean and covariance of each set
    real_flat = real_images.reshape(len(real_images), -1)
    fake_flat = fake_images.reshape(len(fake_images), -1)

    mu_real = np.mean(real_flat, axis=0)
    mu_fake = np.mean(fake_flat, axis=0)

    sigma_real = np.cov(real_flat, rowvar=False)
    sigma_fake = np.cov(fake_flat, rowvar=False)

    # FID = ||mu_r - mu_f||^2 + Tr(sigma_r + sigma_f - 2 * (sigma_r @ sigma_f)^(1/2))
    # Note: this requires the matrix square root, not an element-wise sqrt
    diff = mu_real - mu_fake
    covmean = sqrtm(sigma_real @ sigma_fake)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error

    fid = np.dot(diff, diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)

    return fid

# Generate many samples for evaluation
print("Generating 1000 samples for evaluation...")
eval_samples = []
n_eval = 1000
for _ in tqdm(range((n_eval + 15) // 16), desc="Generating"):
    batch_samples = model.sample(16, scheduler="ddim", steps=50, rngs=rngs)
    eval_samples.append(np.array(batch_samples))

eval_samples = np.concatenate(eval_samples, axis=0)[:n_eval]

# Compute simplified FID
fid = compute_fid_simplified(test_images[:1000], eval_samples)
print(f"Simplified FID Score: {fid:.2f}")

Sample Diversity¤

def compute_diversity(samples):
    """Compute sample diversity using pairwise distances.

    Args:
        samples: Generated samples

    Returns:
        Mean pairwise distance
    """
    flat_samples = samples.reshape(len(samples), -1)

    # Compute pairwise distances
    distances = []
    n_samples = len(flat_samples)

    for i in range(n_samples):
        for j in range(i + 1, n_samples):
            dist = np.linalg.norm(flat_samples[i] - flat_samples[j])
            distances.append(dist)

    return np.mean(distances)

# Compute diversity
diversity = compute_diversity(eval_samples[:100])
print(f"Sample Diversity: {diversity:.2f}")

Advanced Techniques¤

Latent Space Interpolation¤

def interpolate_noise(model, n_steps=10):
    """Interpolate in the noise space.

    Args:
        model: Diffusion model
        n_steps: Number of interpolation steps

    Returns:
        Interpolated samples
    """
    shape = model._get_sample_shape()

    # Generate two random noise vectors
    noise1 = jax.random.normal(rngs.sample(), (1, *shape))
    noise2 = jax.random.normal(rngs.sample(), (1, *shape))

    # Interpolate
    alphas = np.linspace(0, 1, n_steps)
    interpolated = []

    for alpha in tqdm(alphas, desc="Interpolating"):
        # Linear interpolation
        noise = (1 - alpha) * noise1 + alpha * noise2

        # Denoise from this noise
        x = noise.copy()

        for t in range(model.noise_steps - 1, -1, -1):
            t_batch = jnp.full((1,), t, dtype=jnp.int32)
            outputs = model(x, t_batch, rngs=rngs)
            x = model.p_sample(outputs["predicted_noise"], x, t_batch, rngs=rngs)

        interpolated.append(x[0])

    return jnp.stack(interpolated)

# Interpolate
print("Generating interpolation...")
interpolated = interpolate_noise(model, n_steps=10)

# Visualize
fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, (ax, img) in enumerate(zip(axes, interpolated)):
    img = (img + 1.0) / 2.0
    ax.imshow(img.squeeze(), cmap="gray")
    ax.axis("off")
    ax.set_title(f"α={i/9:.1f}")

plt.suptitle("Noise Space Interpolation")
plt.tight_layout()
plt.show()
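Linear interpolation can pull intermediate noise vectors off the typical norm of a high-dimensional Gaussian, which sometimes makes the middle frames look washed out. A common alternative is spherical interpolation (slerp); here is a small sketch you could swap in for the linear mix above:

def slerp(noise1, noise2, alpha):
    """Spherical interpolation between two noise tensors (preserves Gaussian-like norm)."""
    v1, v2 = noise1.reshape(-1), noise2.reshape(-1)
    cos_omega = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2))
    omega = jnp.arccos(jnp.clip(cos_omega, -1.0, 1.0))
    so = jnp.sin(omega)
    mixed = (jnp.sin((1 - alpha) * omega) / so) * v1 + (jnp.sin(alpha * omega) / so) * v2
    return mixed.reshape(noise1.shape)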

Inpainting¤

def inpaint_image(model, image, mask, n_steps=1000):
    """Inpaint masked regions of an image.

    Args:
        model: Diffusion model
        image: Original image (1, H, W, C)
        mask: Binary mask (1, H, W, 1), 1=inpaint, 0=keep
        n_steps: Number of denoising steps

    Returns:
        Inpainted image
    """
    # Start from noise
    shape = image.shape
    x = jax.random.normal(rngs.sample(), shape)

    # Denoise with constraint
    for t in tqdm(range(n_steps - 1, -1, -1), desc="Inpainting"):
        t_batch = jnp.full((1,), t, dtype=jnp.int32)

        # Predict noise
        outputs = model(x, t_batch, rngs=rngs)
        predicted_noise = outputs["predicted_noise"]

        # Denoising step
        x = model.p_sample(predicted_noise, x, t_batch, rngs=rngs)

        # Replace known regions with noisy version of original
        if t > 0:
            x_noisy_orig, _ = model.forward_diffusion(image, t_batch)
            x = mask * x + (1 - mask) * x_noisy_orig

    return x

# Create a test image and mask
test_image = test_images[0:1]  # Take first test image

# Create mask (inpaint the center region; 1 = inpaint, 0 = keep, matching the docstring above)
mask = np.zeros((1, 28, 28, 1))
mask[:, 10:18, 10:18, :] = 1  # Inpaint the center 8x8 region

# Inpaint
print("Inpainting image...")
inpainted = inpaint_image(model, test_image, mask, n_steps=200)

# Visualize
fig, axes = plt.subplots(1, 4, figsize=(12, 3))

# Original
axes[0].imshow((test_image[0, :, :, 0] + 1) / 2, cmap="gray")
axes[0].set_title("Original")
axes[0].axis("off")

# Masked
masked_img = test_image * (1 - mask)
axes[1].imshow((masked_img[0, :, :, 0] + 1) / 2, cmap="gray")
axes[1].set_title("Masked")
axes[1].axis("off")

# Mask
axes[2].imshow(mask[0, :, :, 0], cmap="gray")
axes[2].set_title("Mask")
axes[2].axis("off")

# Inpainted
axes[3].imshow((inpainted[0, :, :, 0] + 1) / 2, cmap="gray")
axes[3].set_title("Inpainted")
axes[3].axis("off")

plt.suptitle("Image Inpainting")
plt.tight_layout()
plt.show()

Saving and Loading¤

Save Model¤

def save_model(model, path="checkpoints/diffusion_mnist.pkl"):
    """Save model checkpoint.

    Args:
        model: Diffusion model
        path: Save path
    """
    import os
    import pickle

    # Create directory
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # Get model state
    state = nnx.state(model)

    # Save
    with open(path, "wb") as f:
        pickle.dump(state, f)

    print(f"Model saved to {path}")

# Save the model
save_model(model, "checkpoints/diffusion_mnist.pkl")

Load Model¤

def load_model(config, path="checkpoints/diffusion_mnist.pkl", rngs=None):
    """Load model checkpoint.

    Args:
        config: Model configuration
        path: Checkpoint path
        rngs: Random number generators

    Returns:
        Loaded model
    """
    import pickle

    # Create model
    model = DDPMModel(config, rngs=rngs)

    # Load state
    with open(path, "rb") as f:
        state = pickle.load(f)

    # Update model
    nnx.update(model, state)

    print(f"Model loaded from {path}")
    return model

# Load the model
loaded_model = load_model(config, "checkpoints/diffusion_mnist.pkl", rngs=rngs)

# Test loaded model
test_samples = loaded_model.sample(n_samples_or_shape=16, scheduler="ddim", steps=50)
visualize_samples(test_samples, title="Samples from Loaded Model")

Troubleshooting¤

Issue 1: Blurry Samples¤

If your samples are blurry:

# Solution 1: Train longer
num_epochs = 20  # Instead of 10

# Solution 2: Lower learning rate
schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,
    peak_value=5e-5,  # Lower peak
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
)

# Solution 3: Use cosine schedule (recreate the model afterwards so the change takes effect)
config.parameters["beta_schedule"] = "cosine"

Issue 2: Training Instability¤

If training is unstable:

# Solution 1: Stronger gradient clipping
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(0.5),  # Stronger clipping
        optax.adam(schedule)
    )
)

# Solution 2: Reduce learning rate
schedule = optax.constant_schedule(5e-5)

# Solution 3: Reduce batch size
batch_size = 64  # Instead of 128

Issue 3: Out of Memory¤

If you run out of memory:

# Solution 1: Reduce batch size
batch_size = 32

# Solution 2: Generate samples in smaller batches
def generate_many_samples(model, n_total, batch_size=16):
    all_samples = []
    for _ in range(n_total // batch_size):
        batch = model.sample(n_samples_or_shape=batch_size, scheduler="ddim", steps=50)
        all_samples.append(batch)
    return jnp.concatenate(all_samples, axis=0)

# Solution 3: Use DDIM with fewer steps
samples = model.sample(n_samples_or_shape=16, scheduler="ddim", steps=20)

Next Steps and Variations¤

Try Different Architectures¤

# Use DDIM for faster sampling
ddim_config = ModelConfiguration(
    name="ddim_mnist",
    model_class="DDIMModel",
    input_dim=(28, 28, 1),
    parameters={
        "noise_steps": 1000,
        "ddim_steps": 50,
        "eta": 0.0,  # Deterministic
    }
)

ddim_model = DDIMModel(ddim_config, rngs=rngs)

Conditional Generation¤

# Add class conditioning (requires conditional diffusion model)
# This would require modifying the model to accept class labels

# Example usage:
# conditional_model = ConditionalDiffusionModel(config, num_classes=10, rngs=rngs)
# samples = conditional_model.sample(n_samples_or_shape=16, labels=class_labels)

Try on Other Datasets¤

# Fashion-MNIST
from torchvision import datasets

fashion_dataset = datasets.FashionMNIST(root="./data", train=True, download=True)

# CIFAR-10 (requires larger model)
cifar_config = ModelConfiguration(
    name="ddpm_cifar",
    model_class="DDPMModel",
    input_dim=(32, 32, 3),
    parameters={"noise_steps": 1000}
)

Summary¤

In this tutorial, you learned:

Key Achievements:

  1. ✅ Loaded and preprocessed MNIST dataset
  2. ✅ Created and configured a DDPM model
  3. ✅ Trained the model with proper monitoring
  4. ✅ Generated realistic handwritten digits
  5. ✅ Used DDIM for fast sampling (20x speedup)
  6. ✅ Visualized the denoising process
  7. ✅ Evaluated sample quality
  8. ✅ Performed interpolation and inpainting
  9. ✅ Saved and loaded model checkpoints

What You Can Do Next:

  • Experiment with different noise schedules (cosine vs linear)
  • Try larger models with more parameters
  • Add class conditioning for controlled generation
  • Apply to color datasets (Fashion-MNIST, CIFAR-10)
  • Implement advanced sampling techniques
  • Explore latent diffusion for higher resolutions

Complete Code¤

Here's the complete code in one place:

# [Complete code would go here, combining all snippets above]
# See the full tutorial above for the complete implementation

Additional Resources¤