Simple Diffusion Example
Level: Beginner | Runtime: ~30 seconds (CPU) / ~10 seconds (GPU) | Format: Python + Jupyter
Overview
This example demonstrates the fundamentals of diffusion models using a simple implementation. It covers the core concepts of the forward diffusion process, reverse denoising, and sample generation.
What You'll Learn
- How to create and configure a basic diffusion model
- Understanding noise schedules (beta schedules)
- Forward diffusion process (adding noise)
- Reverse process (denoising)
- Generating samples from random noise
- Visualizing diffusion model outputs
Files
- Python Script: examples/generative_models/diffusion/simple_diffusion_example.py
- Jupyter Notebook: examples/generative_models/diffusion/simple_diffusion_example.ipynb
Quick Start
Run the Python Script
# Activate environment
source activate.sh
# Run the example
python examples/generative_models/diffusion/simple_diffusion_example.py
Run the Jupyter Notebook
# Activate environment
source activate.sh
# Launch Jupyter
jupyter lab examples/generative_models/diffusion/simple_diffusion_example.ipynb
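Key Concepts
Forward Diffusion Process
The forward process gradually adds Gaussian noise to data according to a variance schedule:
\[
q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1 - \beta_t}\, x_{t-1}, \beta_t \mathbf{I}\right)
\]
where \(\beta_t\) is the noise variance added at timestep \(t\), as set by the noise schedule.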
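Because a chain of Gaussian steps is itself Gaussian, \(x_t\) can also be sampled directly from \(x_0\) in one shot: \(q(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t)\mathbf{I})\), where \(\alpha_t = 1 - \beta_t\) and \(\bar\alpha_t = \prod_{s=1}^{t} \alpha_s\) (the alphas and alpha_bars in the example code). The JAX sketch below is a minimal illustration of that closed form; the function name q_sample and the 0-indexed timesteps are illustrative choices, not part of the Workshop API.

```python
import jax
import jax.numpy as jnp


def q_sample(x0, t, alpha_bars, key):
    """Sample x_t ~ q(x_t | x_0) directly via the closed form (hypothetical helper).

    x0: clean data, shape (batch, ...)
    t: integer timesteps, shape (batch,), 0-indexed into alpha_bars
    alpha_bars: cumulative products of (1 - beta), shape (T,)
    """
    noise = jax.random.normal(key, x0.shape)
    # Look up alpha_bar_t per example and reshape it to broadcast over data dims.
    a_bar = alpha_bars[t].reshape((-1,) + (1,) * (x0.ndim - 1))
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    return x_t, noise


# Linear schedule with the same range as the example model.
betas = jnp.linspace(1e-4, 0.02, 1000)
alpha_bars = jnp.cumprod(1.0 - betas)

key = jax.random.PRNGKey(0)
x0 = jnp.ones((4, 28, 28, 1))
t = jnp.array([0, 250, 500, 999])
x_t, noise = q_sample(x0, t, alpha_bars, key)
```

At the last timestep \(\bar\alpha_t\) is close to zero for this schedule, so \(x_T\) is almost pure noise, which is why generation can start from \(\mathcal{N}(0, \mathbf{I})\).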
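Reverse Process
The model learns to reverse this process, removing noise step by step:
\[
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\right)
\]
where the mean \(\mu_\theta\) (and, in some variants, the covariance \(\Sigma_\theta\)) is computed by a network with parameters \(\theta\).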
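For reference, DDPM (Ho et al., 2020) trains the network to predict the added noise \(\epsilon_\theta(x_t, t)\), fixes the covariance to \(\sigma_t^2 \mathbf{I}\) (with \(\sigma_t^2 = \beta_t\) as one of its choices), and so a single reverse step becomes:

\[
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, \mathbf{I}),
\]

where \(\alpha_t = 1 - \beta_t\), \(\bar\alpha_t = \prod_{s=1}^{t} \alpha_s\), and \(z = 0\) at the final step. Whether this example's model uses exactly this parameterization is an assumption; its stub __call__ only states that it predicts the noise at timestep \(t\).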
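Noise Schedules
The example demonstrates different beta schedules:
- Linear schedule: \(\beta_t\) increases linearly from \(\beta_{\min}\) to \(\beta_{\max}\) (1e-4 to 0.02 in the example code)
- Cosine schedule: \(\bar\alpha_t = \prod_{s=1}^{t} (1 - \beta_s)\) follows a cosine curve, so it changes slowly near \(t = 0\) and \(t = T\) and destroys the signal less quickly than the linear schedule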
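Both schedules can be written in a few lines of JAX. This is a sketch under stated assumptions: the linear defaults mirror the example code, while the cosine version follows the formulation of Nichol and Dhariwal (2021), with their offset s = 0.008 and a clip of \(\beta_t\) at 0.999. The function names are hypothetical helpers, not part of the Workshop library.

```python
import math

import jax.numpy as jnp


def linear_beta_schedule(timesteps, beta_min=1e-4, beta_max=0.02):
    # Evenly spaced betas, same as jnp.linspace in the example model.
    return jnp.linspace(beta_min, beta_max, timesteps)


def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    # alpha_bar(t) = f(t) / f(0), with f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)
    steps = jnp.arange(timesteps + 1, dtype=jnp.float32) / timesteps
    f = jnp.cos((steps + s) / (1.0 + s) * math.pi / 2) ** 2
    alpha_bars = f / f[0]
    # Recover per-step betas from ratios of consecutive alpha_bars;
    # clip so the final beta does not reach 1.
    betas = 1.0 - alpha_bars[1:] / alpha_bars[:-1]
    return jnp.clip(betas, 0.0, max_beta)


T = 1000
linear_betas = linear_beta_schedule(T)
cosine_betas = cosine_beta_schedule(T)
# Fraction of signal variance remaining after each step.
linear_ab = jnp.cumprod(1.0 - linear_betas)
cosine_ab = jnp.cumprod(1.0 - cosine_betas)
```

Comparing linear_ab and cosine_ab at the same index (for example, halfway through) shows that the cosine schedule keeps noticeably more of the original signal in the middle of the process.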
Code Structure
The example is organized into clear sections:
- Model Definition: Creating a SimpleDiffusionModel class
- Configuration: Setting up model parameters and noise schedule
- Model Instantiation: Creating the model with proper RNG handling
- Sample Generation: Generating samples from random noise
- Visualization: Displaying generated samples
Example Code
from flax import nnx
import jax
import jax.numpy as jnp
from workshop.generative_models.core.base import GenerativeModel
# Create the RNG container used for initialization and sampling
rngs = nnx.Rngs(42)
# Define a simple diffusion model
class SimpleDiffusionModel(GenerativeModel):
    def __init__(self, input_dim, timesteps=1000, *, rngs):
        super().__init__()
        self.input_dim = input_dim
        self.timesteps = timesteps
        # Linear noise schedule: beta_t rises from 1e-4 to 0.02
        self.betas = jnp.linspace(1e-4, 0.02, timesteps)
        # alpha_t = 1 - beta_t
        self.alphas = 1.0 - self.betas
        # alpha_bar_t = alpha_1 * ... * alpha_t (signal remaining after t steps)
        self.alpha_bars = jnp.cumprod(self.alphas)
    def __call__(self, x, t, *, rngs=None):
        # Predict the noise added at timestep t
        # ... model implementation ...
        pass
# Create the model
model = SimpleDiffusionModel(
    input_dim=(28, 28, 1),
    timesteps=1000,
    rngs=rngs,
)
# Generate 16 samples. generate() is not defined in this snippet;
# this assumes the GenerativeModel base class provides it.
samples = model.generate(num_samples=16, rngs=rngs)
Features Demonstrated
SimpleDiffusionModel Creation
- Custom model class extending GenerativeModel
- Proper initialization with RNG handling
- Beta schedule setup for noise control
Noise Schedule
- Linear schedule implementation
- Alpha and alpha_bar calculations
- Understanding variance schedules
Sample Generation
- Starting from random noise
- Iterative denoising process
- Controlling generation quality
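The generation steps above (start from noise, then denoise iteratively) can be sketched as DDPM ancestral sampling. This is a minimal illustration, not the Workshop generate() implementation: eps_fn is a hypothetical stand-in for a trained noise-prediction model, the variance choice \(\sigma_t^2 = \beta_t\) is assumed, and the loop is plain Python for readability (a jax.lax.fori_loop would be faster under jit).

```python
import jax
import jax.numpy as jnp


def ddpm_sample(eps_fn, shape, betas, key):
    """Ancestral DDPM sampling from N(0, I) (hypothetical helper).

    eps_fn(x, t) returns the predicted noise for batch x at 0-indexed step t.
    """
    alphas = 1.0 - betas
    alpha_bars = jnp.cumprod(alphas)
    key, init_key = jax.random.split(key)
    x = jax.random.normal(init_key, shape)  # x_T: pure Gaussian noise

    for t in range(betas.shape[0] - 1, -1, -1):
        eps = eps_fn(x, t)
        # Mean of p(x_{t-1} | x_t) under the epsilon parameterization.
        mean = (x - betas[t] / jnp.sqrt(1.0 - alpha_bars[t]) * eps) / jnp.sqrt(alphas[t])
        if t > 0:
            key, step_key = jax.random.split(key)
            x = mean + jnp.sqrt(betas[t]) * jax.random.normal(step_key, shape)
        else:
            x = mean  # no noise is added on the final step
    return x


# Placeholder predictor (always zero noise) just to exercise the loop;
# a short 50-step schedule keeps the demo fast.
samples = ddpm_sample(
    lambda x, t: jnp.zeros_like(x),
    (16, 28, 28, 1),
    jnp.linspace(1e-4, 0.02, 50),
    jax.random.PRNGKey(0),
)
```

With the zero placeholder the outputs are not meaningful images; a trained noise predictor is needed for real samples. Using more timesteps generally trades speed for sample quality, which is one of the knobs behind "controlling generation quality".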
Visualization
- Displaying generated samples
- Comparing different timesteps
- Analyzing generation quality
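One way to display a batch of generated samples is to tile them into a single image. The make_grid helper below is a hypothetical sketch (not a Workshop function) that assumes samples are shaped (N, H, W, C), like the (28, 28, 1) inputs in the example, and lie roughly in [-1, 1].

```python
import jax
import jax.numpy as jnp


def make_grid(images, nrow):
    """Tile images of shape (N, H, W, C) into one (rows * H, nrow * W, C) array."""
    n, h, w, c = images.shape
    rows = -(-n // nrow)  # ceiling division
    # Pad with blank images so the batch fills a complete rows x nrow grid.
    pad = rows * nrow - n
    images = jnp.concatenate([images, jnp.zeros((pad, h, w, c), images.dtype)], axis=0)
    grid = images.reshape(rows, nrow, h, w, c)
    # Reorder to (rows, h, nrow, w, c) so each grid row's images sit side by side.
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, nrow * w, c)


# Stand-in for generated samples; rescale from [-1, 1] to [0, 1] for display.
fake_samples = jax.random.normal(jax.random.PRNGKey(0), (16, 28, 28, 1))
grid = make_grid(jnp.clip((fake_samples + 1.0) / 2.0, 0.0, 1.0), nrow=4)
```

The resulting array can be shown with matplotlib, for example plt.imshow(grid[..., 0], cmap="gray") for single-channel images. Tiling samples taken at several timesteps the same way gives a quick side-by-side comparison of the denoising trajectory.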
Experiments to Try
- Modify the noise schedule: Try different beta ranges or cosine schedules
- Change timesteps: Experiment with different numbers of diffusion steps
- Vary sample size: Generate different numbers of samples
- Add conditioning: Extend to conditional generation
- Custom architecture: Implement different denoising networks
Next Steps
After understanding this basic example:
- DiT Demo: Learn about Diffusion Transformers for more advanced architectures
- Training: Implement a full training loop for your own dataset
- Advanced Schedules: Explore more sophisticated noise schedules
- Conditional Generation: Add class or text conditioning
Troubleshooting
Import Errors
Make sure you've activated the Workshop environment:
source activate.sh
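CUDA Issues
If you encounter GPU errors, try forcing JAX to run on the CPU by setting the JAX_PLATFORMS environment variable:
JAX_PLATFORMS=cpu python examples/generative_models/diffusion/simple_diffusion_example.py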
Memory Issues
If you run out of memory, reduce the batch size (for example, num_samples in model.generate) or the number of timesteps.
Additional Resources
- Paper: Denoising Diffusion Probabilistic Models (DDPM), Ho, Jain, and Abbeel, 2020
- Workshop Diffusion Guide: Diffusion Models Guide
- API Reference: Diffusion API
Related Examples
- DiT Demo - Advanced diffusion with transformers
- Simple EBM - Energy-based models with MCMC sampling