Simple Diffusion Example¤

Level: Beginner | Runtime: ~30 seconds (CPU) / ~10 seconds (GPU) | Format: Python + Jupyter

Overview¤

This example demonstrates the fundamentals of diffusion models using a simple implementation. It covers the core concepts of the forward diffusion process, reverse denoising, and sample generation.

What You'll Learn¤

  • How to create and configure a basic diffusion model
  • How noise schedules (beta schedules) control the noise added at each step
  • How the forward diffusion process adds noise
  • How the reverse process removes it (denoising)
  • How to generate samples from random noise
  • How to visualize diffusion model outputs

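Files¤

  • Python script: examples/generative_models/diffusion/simple_diffusion_example.py
  • Jupyter notebook: examples/generative_models/diffusion/simple_diffusion_example.ipynb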

Quick Start¤

Run the Python Script¤

# Activate environment
source activate.sh

# Run the example
python examples/generative_models/diffusion/simple_diffusion_example.py

Run the Jupyter Notebook¤

# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/diffusion/simple_diffusion_example.ipynb

Key Concepts¤

Forward Diffusion Process¤

The forward process gradually adds Gaussian noise to data according to a variance schedule:

\[q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I)\]

where \(\beta_t\) is the variance of the noise added at timestep \(t\), as set by the noise schedule.
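The forward step can be sketched in a few lines of JAX. This is a minimal illustration, not the example's own code: the function names `forward_step` and `q_sample` are hypothetical, and the linear schedule is assumed to match the one used in the Example Code section. `q_sample` relies on the standard closed form \(x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon\), which follows from chaining the single-step Gaussians.

```python
import jax
import jax.numpy as jnp

# Linear schedule with the same endpoints as the example code (assumed here).
timesteps = 1000
betas = jnp.linspace(1e-4, 0.02, timesteps)
alpha_bars = jnp.cumprod(1.0 - betas)


def forward_step(key, x_prev, t):
    """Sample x_t ~ q(x_t | x_{t-1}) for a 0-based step index t."""
    noise = jax.random.normal(key, x_prev.shape)
    return jnp.sqrt(1.0 - betas[t]) * x_prev + jnp.sqrt(betas[t]) * noise


def q_sample(key, x0, t):
    """Jump directly from x_0 to x_t using the cumulative product of alphas."""
    noise = jax.random.normal(key, x0.shape)
    return jnp.sqrt(alpha_bars[t]) * x0 + jnp.sqrt(1.0 - alpha_bars[t]) * noise


key = jax.random.PRNGKey(0)
x0 = jnp.ones((4, 28, 28, 1))               # stand-in batch, not real data
x_mid = q_sample(key, x0, 500)              # partially noised
x_last = q_sample(key, x0, timesteps - 1)   # close to pure Gaussian noise
```

By the last timestep almost none of the original signal remains, which is why sampling can start from pure noise.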

Reverse Process¤

The model learns to reverse this process, removing noise step by step:

\[p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))\]
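
One common parameterization, used in DDPM-style models and assumed here rather than confirmed by this example, has the network predict the noise \(\epsilon_\theta(x_t, t)\) and fixes the covariance to the schedule value:

\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right), \qquad \Sigma_\theta(x_t, t) = \beta_t I\]

Here \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\), matching the `alphas` and `alpha_bars` arrays in the example code.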

Noise Schedules¤

The example demonstrates different beta schedules:

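  • Linear schedule: \(\beta_t\) increases linearly from \(\beta_{\min}\) to \(\beta_{\max}\) (1e-4 to 0.02 in the example code)
  • Cosine schedule: Noise is added along a cosine curve, which changes the signal level more gradually near the start and end of the process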
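
Both schedules can be sketched as small functions. The function names are hypothetical, and the cosine version follows the commonly used squared-cosine formulation (offset `s=0.008`, betas clipped at 0.999); whether the example's own cosine schedule matches it exactly is an assumption.

```python
import jax.numpy as jnp


def linear_beta_schedule(timesteps, beta_min=1e-4, beta_max=0.02):
    """Betas that rise linearly from beta_min to beta_max."""
    return jnp.linspace(beta_min, beta_max, timesteps)


def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine alpha_bar curve."""
    steps = jnp.arange(timesteps + 1) / timesteps
    alpha_bar = jnp.cos((steps + s) / (1.0 + s) * jnp.pi / 2.0) ** 2
    # Each beta is one minus the ratio of consecutive alpha_bar values.
    betas = 1.0 - alpha_bar[1:] / alpha_bar[:-1]
    # Clip so the final steps do not produce a beta of (almost) exactly 1.
    return jnp.clip(betas, 0.0, max_beta)


linear_betas = linear_beta_schedule(1000)
cosine_betas = cosine_beta_schedule(1000)
```

Either array can stand in for `self.betas` in a model that derives `alphas` and `alpha_bars` from it.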

Code Structure¤

The example is organized into five sections:

  1. Model Definition: Creating a SimpleDiffusionModel class
  2. Configuration: Setting up model parameters and noise schedule
  3. Model Instantiation: Creating the model with proper RNG handling
  4. Sample Generation: Generating samples from random noise
  5. Visualization: Displaying generated samples

Example Code¤

from flax import nnx
import jax
import jax.numpy as jnp
from workshop.generative_models.core.base import GenerativeModel

# Create RNG
rngs = nnx.Rngs(42)

# Define simple diffusion model
class SimpleDiffusionModel(GenerativeModel):
    def __init__(self, input_dim, timesteps=1000, *, rngs):
        super().__init__()
        self.input_dim = input_dim
        self.timesteps = timesteps

        # Initialize noise schedule (beta values)
        self.betas = jnp.linspace(1e-4, 0.02, timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = jnp.cumprod(self.alphas)

    def __call__(self, x, t, *, rngs=None):
        # Predict noise at timestep t
        # ... model implementation ...
        pass

# Create model
model = SimpleDiffusionModel(
    input_dim=(28, 28, 1),
    timesteps=1000,
    rngs=rngs
)

# Generate samples (generate() is not defined above; presumably inherited from GenerativeModel)
samples = model.generate(num_samples=16, rngs=rngs)

Features Demonstrated¤

SimpleDiffusionModel Creation¤

  • Custom model class extending GenerativeModel
  • Proper initialization with RNG handling
  • Beta schedule setup for noise control

Noise Schedule¤

  • Linear schedule implementation
  • Alpha and alpha_bar calculations
  • Understanding variance schedules
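
To see what the alpha and alpha_bar values mean in practice, the sketch below prints how much of the original signal and how much noise is present at a few timesteps. It assumes the same linear schedule as the example code; the variable names `signal_scale` and `noise_scale` are illustrative.

```python
import jax.numpy as jnp

betas = jnp.linspace(1e-4, 0.02, 1000)    # same linear schedule as the example
alphas = 1.0 - betas                       # per-step signal retention
alpha_bars = jnp.cumprod(alphas)           # cumulative retention up to step t

signal_scale = jnp.sqrt(alpha_bars)        # multiplies x_0 in x_t
noise_scale = jnp.sqrt(1.0 - alpha_bars)   # multiplies the Gaussian noise in x_t

for t in (0, 249, 499, 999):
    print(f"t={t:4d}  signal={float(signal_scale[t]):.4f}  noise={float(noise_scale[t]):.4f}")
```

The signal scale shrinks toward zero as \(t\) grows while the noise scale approaches one, and the squares of the two always sum to one.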

Sample Generation¤

  • Starting from random noise
  • Iterative denoising process
  • Controlling generation quality
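
The iterative denoising loop can be sketched as follows. This is a DDPM-style ancestral sampler written under stated assumptions, not the example's `generate` method: `predict_noise` is a hypothetical placeholder that returns zeros where a trained network would go, and the per-step variance is fixed to \(\beta_t\).

```python
import jax
import jax.numpy as jnp

timesteps = 1000
betas = jnp.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)


def predict_noise(x, t):
    """Hypothetical stand-in for a trained denoising network."""
    return jnp.zeros_like(x)


def sample(key, shape):
    """Ancestral sampling: start from noise and denoise one step at a time."""
    key, init_key = jax.random.split(key)
    x = jax.random.normal(init_key, shape)

    def body(i, carry):
        x, key = carry
        t = timesteps - 1 - i                  # walk from the last step down to 0
        key, step_key = jax.random.split(key)
        eps = predict_noise(x, t)
        mean = (x - betas[t] / jnp.sqrt(1.0 - alpha_bars[t]) * eps) / jnp.sqrt(alphas[t])
        noise = jax.random.normal(step_key, shape)
        # No fresh noise is added on the final step (t == 0).
        x = mean + jnp.where(t > 0, jnp.sqrt(betas[t]), 0.0) * noise
        return x, key

    x, _ = jax.lax.fori_loop(0, timesteps, body, (x, key))
    return x


samples = sample(jax.random.PRNGKey(0), (16, 28, 28, 1))
```

With the zero placeholder the output is still noise; meaningful images require replacing `predict_noise` with a trained model.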

Visualization¤

  • Displaying generated samples
  • Comparing different timesteps
  • Analyzing generation quality
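
For display, a batch of samples is often tiled into a single image. The helper below is a hypothetical sketch (`make_grid` is not a Workshop function) and uses random data in place of real model output.

```python
import jax
import jax.numpy as jnp


def make_grid(images, ncols=4):
    """Tile a batch of shape (N, H, W, C) into one (rows*H, ncols*W, C) image."""
    n, h, w, c = images.shape
    nrows = -(-n // ncols)                       # ceiling division
    pad = nrows * ncols - n                      # blank tiles to fill the last row
    images = jnp.pad(images, ((0, pad), (0, 0), (0, 0), (0, 0)))
    grid = images.reshape(nrows, ncols, h, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4)         # (rows, H, cols, W, C)
    return grid.reshape(nrows * h, ncols * w, c)


# Random stand-in for model output; replace with real samples.
samples = jax.random.normal(jax.random.PRNGKey(0), (16, 28, 28, 1))
raw_grid = make_grid(samples, ncols=4)
# Rescale to [0, 1] so the result can be passed to an image viewer.
grid = (raw_grid - raw_grid.min()) / (raw_grid.max() - raw_grid.min())
```

If matplotlib is installed, `plt.imshow(grid[..., 0], cmap="gray")` shows the grid as a grayscale image.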

Experiments to Try¤

  1. Modify the noise schedule: Try different beta ranges or cosine schedules
  2. Change timesteps: Experiment with different numbers of diffusion steps
  3. Vary sample size: Generate different numbers of samples
  4. Add conditioning: Extend to conditional generation
  5. Custom architecture: Implement different denoising networks

Next Steps¤

After understanding this basic example:

  1. DiT Demo: Learn about Diffusion Transformers for more advanced architectures
  2. Training: Implement a full training loop for your own dataset
  3. Advanced Schedules: Explore more sophisticated noise schedules
  4. Conditional Generation: Add class or text conditioning
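
As a starting point for the training step, the sketch below shows the noise-prediction objective that DDPM-style models commonly minimize: noise a clean batch to a random timestep, then penalize the squared error between the true and the predicted noise. Everything here is an assumption for illustration. `predict_noise` is a hypothetical one-parameter placeholder, the batch is all zeros, and the update is plain gradient descent rather than a real optimizer.

```python
import jax
import jax.numpy as jnp

timesteps = 1000
betas = jnp.linspace(1e-4, 0.02, timesteps)
alpha_bars = jnp.cumprod(1.0 - betas)


def predict_noise(params, x_t, t):
    """Hypothetical one-parameter 'network'; a real model would be an nnx.Module."""
    return params["scale"] * x_t


def loss_fn(params, key, x0):
    """Mean squared error between the true and the predicted noise."""
    t_key, noise_key = jax.random.split(key)
    t = jax.random.randint(t_key, (x0.shape[0],), 0, timesteps)
    noise = jax.random.normal(noise_key, x0.shape)
    a_bar = alpha_bars[t].reshape(-1, 1, 1, 1)   # broadcast over H, W, C
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    return jnp.mean((predict_noise(params, x_t, t) - noise) ** 2)


params = {"scale": jnp.array(0.0)}
x0 = jnp.zeros((8, 28, 28, 1))                   # stand-in batch, not real data
key = jax.random.PRNGKey(0)

# One plain gradient-descent step; a real loop would use an optimizer.
loss, grads = jax.value_and_grad(loss_fn)(params, key, x0)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

A full training loop would repeat this step over batches from a dataset, drawing a fresh key each time.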

Troubleshooting¤

Import Errors¤

Make sure you've activated the Workshop environment:

source activate.sh

CUDA Issues¤

If you encounter GPU errors, try running on CPU:

export JAX_PLATFORMS=cpu
python examples/generative_models/diffusion/simple_diffusion_example.py
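
To confirm which backend JAX actually picked up, a short check like the following can help. It is a sketch that only uses the standard `jax.default_backend()` and `jax.devices()` calls; setting the variable from Python is an alternative to the `export` line above.

```python
import os

# Set before importing JAX so the setting is picked up in this process.
# setdefault leaves any value already exported in the shell untouched.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax

print("backend:", jax.default_backend())
print("devices:", jax.devices())
```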

Memory Issues¤

If you run out of memory, reduce the number of samples generated at once (num_samples) or the number of timesteps.

Additional Resources¤

  • DiT Demo - Advanced diffusion with transformers
  • Simple EBM - Energy-based models with MCMC sampling