# Diffusion Models User Guide
This guide covers practical usage of diffusion models in Workshop, from basic DDPM to advanced techniques like latent diffusion and guidance.
## Quick Start
Here's a minimal example to get started with diffusion models:
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.diffusion import DDPMModel
# Initialize RNGs
rngs = nnx.Rngs(0, params=1, noise=2, sample=3)
# Configure the model
config = ModelConfiguration(
name="my_diffusion",
model_class="DDPMModel",
input_dim=(28, 28, 1), # MNIST dimensions
parameters={
"noise_steps": 1000,
"beta_start": 1e-4,
"beta_end": 0.02,
}
)
# Create model
model = DDPMModel(config, rngs=rngs)
# Generate samples
samples = model.generate(n_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}") # (16, 28, 28, 1)
## Creating Diffusion Models
### DDPM (Denoising Diffusion Probabilistic Models)
DDPM is the foundational diffusion model with stable training and excellent quality.
from workshop.generative_models.models.diffusion import DDPMModel
# Standard DDPM configuration
config = ModelConfiguration(
name="ddpm_model",
model_class="DDPMModel",
input_dim=(32, 32, 3),
parameters={
"noise_steps": 1000, # Number of diffusion steps
"beta_start": 1e-4, # Starting noise level
"beta_end": 0.02, # Ending noise level
"beta_schedule": "linear", # Noise schedule type
}
)
# Create model
model = DDPMModel(config, rngs=rngs)
# Forward diffusion (add noise)
x_clean = jnp.ones((4, 32, 32, 3))
t = jnp.array([100, 200, 300, 400]) # Different timesteps
x_noisy, noise = model.forward_diffusion(x_clean, t, rngs=rngs)
print(f"Clean shape: {x_clean.shape}")
print(f"Noisy shape: {x_noisy.shape}")
print(f"Noise shape: {noise.shape}")
Key Parameters:
| Parameter | Default | Description |
|---|---|---|
| `noise_steps` | 1000 | Number of diffusion timesteps |
| `beta_start` | 1e-4 | Initial noise variance |
| `beta_end` | 0.02 | Final noise variance |
| `beta_schedule` | "linear" | Schedule type (linear/cosine) |
### DDIM (Faster Sampling)
DDIM enables much faster sampling with fewer steps while maintaining quality.
from workshop.generative_models.models.diffusion import DDIMModel
# DDIM configuration
config = ModelConfiguration(
name="ddim_model",
model_class="DDIMModel",
input_dim=(32, 32, 3),
parameters={
"noise_steps": 1000, # Training steps
"ddim_steps": 50, # Sampling steps (much fewer!)
"eta": 0.0, # 0 = deterministic, 1 = stochastic
"skip_type": "uniform", # How to select timesteps
"beta_start": 1e-4,
"beta_end": 0.02,
}
)
# Create DDIM model
model = DDIMModel(config, rngs=rngs)
# Fast sampling with only 50 steps
samples = model.ddim_sample(
n_samples=16,
steps=50, # Much faster than 1000!
eta=0.0, # Deterministic
rngs=rngs
)
print(f"Generated {samples.shape[0]} samples in only 50 steps!")
DDIM vs DDPM:
| Aspect | DDPM | DDIM |
|---|---|---|
| Sampling Steps | 1000 | 50-100 |
| Speed | Slow | 10-20x faster |
| Stochasticity | Stochastic | Deterministic (η=0) |
| Quality | Excellent | Very good |
| Use Case | Training, quality | Inference, speed |
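For reference, a single DDIM update with noise prediction $\epsilon_\theta$ is:

$$
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,
\underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}}}_{\text{predicted } x_0}
+ \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t,t)
+ \sigma_t z,
\qquad
\sigma_t = \eta\sqrt{\tfrac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}}\sqrt{1-\tfrac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}
$$

With the `eta` parameter set to 0, $\sigma_t = 0$ and the update is fully deterministic, which is what makes DDIM inversion (next section) possible.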
### DDIM Inversion (Image Editing)
DDIM's deterministic nature enables image editing through inversion:
# Encode a real image to noise
real_image = load_image("path/to/image.png") # Shape: (1, 32, 32, 3)
# DDIM reverse (image → noise)
noise_code = model.ddim_reverse(
real_image,
ddim_steps=50,
rngs=rngs
)
# Now you can edit the noise and decode back
edited_noise = noise_code + 0.1 * modification_vector  # modification_vector: placeholder edit direction

# DDIM forward (noise → image), starting from the edited latent.
# Note: the edited noise must be passed back in as the starting point;
# the x_init argument name here is illustrative and may differ in your API.
edited_image = model.ddim_sample(
    n_samples=1,
    steps=50,
    x_init=edited_noise,
    rngs=rngs,
)
### Score-Based Diffusion Models
Score-based models predict the score function (the gradient of the data log-density) and are formulated with continuous-time SDEs.
from workshop.generative_models.models.diffusion import ScoreDiffusionModel
# Score-based configuration
config = ModelConfiguration(
name="score_model",
model_class="ScoreDiffusionModel",
input_dim=(32, 32, 3),
parameters={
"sigma_min": 0.01, # Minimum noise level
"sigma_max": 1.0, # Maximum noise level
"score_scaling": 1.0, # Score scaling factor
"noise_steps": 1000,
}
)
# Create model
model = ScoreDiffusionModel(config=config, rngs=rngs)
# Generate samples using reverse SDE
samples = model.sample(
num_samples=16,
num_steps=1000,
return_trajectory=False,
rngs=rngs
)
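For context, the continuous-time view behind score-based models pairs a forward SDE that gradually destroys the data with a reverse SDE that generates it back; the network is trained to approximate the score $\nabla_x \log p_t(x)$ appearing in the reverse process:

$$
\mathrm{d}x = f(x,t)\,\mathrm{d}t + g(t)\,\mathrm{d}w
\qquad\Longrightarrow\qquad
\mathrm{d}x = \left[f(x,t) - g(t)^2\,\nabla_x \log p_t(x)\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{w}
$$

The `sigma_min`/`sigma_max` parameters above set the range of noise levels swept by this process.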
Score-Based Features:
- Continuous-time formulation
- Flexible noise schedules
- Connection to score matching theory
- Can use various SDE solvers
### Latent Diffusion Models (Efficient High-Res)
Latent diffusion applies diffusion in a compressed latent space for efficiency.
from workshop.generative_models.models.diffusion import LDMModel
# Latent diffusion configuration
config = ModelConfiguration(
name="ldm_model",
model_class="LDMModel",
input_dim=(64, 64, 3), # High resolution input
parameters={
"latent_dim": 16, # Compressed latent dimension
"encoder_hidden_dims": [64, 128],
"decoder_hidden_dims": [128, 64],
"encoder_type": "simple", # or "vae" for pretrained
"scale_factor": 0.18215, # Latent scaling
"noise_steps": 1000,
"beta_start": 1e-4,
"beta_end": 0.02,
}
)
# Create latent diffusion model
model = LDMModel(config=config, rngs=rngs)
# The model automatically encodes to latent space
# Training happens in latent space (much faster!)
samples = model.sample(
num_samples=16,
rngs=rngs
)
# Samples are automatically decoded to pixel space
print(f"High-res samples: {samples.shape}") # (16, 64, 64, 3)
LDM Advantages:
- 8x faster training than pixel-space diffusion
- Lower memory requirements
- Enables high-resolution generation
- Foundation of Stable Diffusion
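Conceptually, a latent-diffusion training step encodes images into the latent space, runs the usual diffusion objective there, and only decodes back to pixels at sampling time. A minimal sketch under those assumptions (the `encode` call and its signature are illustrative, not the exact `LDMModel` API):

```python
import jax
import jax.numpy as jnp

def ldm_loss_sketch(model, x, rngs, scale_factor=0.18215, noise_steps=1000):
    """Illustrative LDM loss: diffuse in the (much smaller) latent space."""
    z = model.encode(x) * scale_factor                    # compress pixels -> latents
    t = jax.random.randint(rngs.timestep(), (x.shape[0],), 0, noise_steps)
    noise = jax.random.normal(rngs.noise(), z.shape)
    z_noisy = model.q_sample(z, t, noise, rngs=rngs)      # forward diffusion on latents
    pred = model(z_noisy, t, training=True, rngs=rngs)["predicted_noise"]
    return jnp.mean((pred - noise) ** 2)                  # same MSE objective, cheaper tensors
```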
### Diffusion Transformers (DiT)
DiT uses a Vision Transformer backbone for better scalability.
from workshop.generative_models.models.diffusion import DiTModel
# DiT configuration
config = ModelConfiguration(
name="dit_model",
model_class="DiTModel",
input_dim=(32, 32, 3),
parameters={
"img_size": 32, # Image size
"patch_size": 4, # Patch size (32/4 = 8 patches per side)
"hidden_size": 512, # Transformer hidden dimension
"depth": 12, # Number of transformer layers
"num_heads": 8, # Number of attention heads
"mlp_ratio": 4.0, # MLP expansion ratio
"num_classes": 10, # For conditional generation
"dropout_rate": 0.1,
"learn_sigma": False, # Learn variance
"cfg_scale": 2.0, # Classifier-free guidance scale
"noise_steps": 1000,
}
)
# Create DiT model
model = DiTModel(config, rngs=rngs)
# Generate with class conditioning
class_labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # One of each class
samples = model.generate(
n_samples=10,
y=class_labels,
cfg_scale=2.0, # Classifier-free guidance
num_steps=1000,
rngs=rngs
)
DiT Architecture:
graph TD
Input[Image 32×32×3] --> Patch[Patchify<br/>8×8 patches]
Patch --> Embed[Linear Projection]
Embed --> PE[+ Position Embedding]
Time[Timestep t] --> TEmb[Time MLP]
Class[Class y] --> CEmb[Class Embedding]
PE --> T1[Transformer<br/>Block 1]
TEmb --> T1
CEmb --> T1
T1 --> T2[...]
T2 --> T12[Transformer<br/>Block 12]
T12 --> Final[Final Layer Norm]
Final --> Linear[Linear<br/>Projection]
Linear --> Reshape[Reshape to Image]
Reshape --> Output[Predicted Noise]
style T1 fill:#9C27B0
style T12 fill:#9C27B0
## Training Diffusion Models
### Basic Training Loop
import jax
import optax
from flax import nnx
# Create model
model = DDPMModel(config, rngs=rngs)
# Create optimizer
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))
# Training step
@nnx.jit
def train_step(model, optimizer, batch, rngs):
"""Single training step."""
def loss_fn(model):
# Sample random timesteps
batch_size = batch.shape[0]
t = jax.random.randint(
rngs.timestep(),
(batch_size,),
0,
config.parameters["noise_steps"]
)
# Add noise to batch
noise = jax.random.normal(rngs.noise(), batch.shape)
x_noisy = model.q_sample(batch, t, noise, rngs=rngs)
# Predict noise
outputs = model(x_noisy, t, training=True, rngs=rngs)
predicted_noise = outputs["predicted_noise"]
# MSE loss
loss = jnp.mean((predicted_noise - noise) ** 2)
return loss
# Compute loss and gradients
loss, grads = nnx.value_and_grad(loss_fn)(model)
# Update parameters
optimizer.update(grads)
return {"loss": loss}
# Training loop
step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        metrics = train_step(model, optimizer, batch, rngs)
        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {metrics['loss']:.4f}")
        step += 1
### Training with EMA (Exponential Moving Average)
EMA improves sample quality by maintaining a moving average of parameters:
from workshop.generative_models.core.training import EMAModel
# Create model and EMA
model = DDPMModel(config, rngs=rngs)
ema_model = EMAModel(model, decay=0.9999)
# Training step with EMA
@nnx.jit
def train_step_with_ema(model, ema_model, optimizer, batch, rngs):
# Compute loss and update (same as before)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
# Update EMA
ema_model.update(model)
return {"loss": loss}
# Use EMA model for sampling
samples = ema_model.generate(n_samples=16, rngs=rngs)
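Under the hood, an EMA update is just an exponential moving average over every parameter. A minimal sketch over a parameter pytree (not the actual `EMAModel` implementation):

```python
import jax

def ema_update(ema_params, params, decay=0.9999):
    """Blend current params into the running average: ema <- decay * ema + (1 - decay) * params."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )
```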
### Mixed Precision Training
Use mixed precision to speed up training and reduce memory:
# Configure for mixed precision
config = ModelConfiguration(
name="ddpm_fp16",
model_class="DDPMModel",
input_dim=(32, 32, 3),
parameters={
"noise_steps": 1000,
"beta_start": 1e-4,
"beta_end": 0.02,
"use_fp16": True, # Enable mixed precision
}
)
# Create model with mixed precision
model = DDPMModel(config, rngs=rngs)
# Use dynamic loss scaling
loss_scale = 2 ** 15
@nnx.jit
def train_step_fp16(model, optimizer, batch, rngs):
def loss_fn(model):
# ... compute loss ...
return loss * loss_scale # Scale loss
loss, grads = nnx.value_and_grad(loss_fn)(model)
# Unscale gradients
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, grads)
optimizer.update(grads)
return {"loss": loss / loss_scale}
## Sampling Strategies
### DDPM Sampling (High Quality)
Standard DDPM sampling with all 1000 steps:
# Generate with full DDPM sampling
samples = model.generate(
n_samples=16,
shape=(32, 32, 3),
clip_denoised=True, # Clip to [-1, 1]
rngs=rngs
)
# This takes all 1000 steps - highest quality but slow
### DDIM Sampling (Fast)
Use DDIM for 10-20x faster sampling:
# Generate with DDIM (50 steps instead of 1000)
samples = model.sample(
n_samples_or_shape=16,
scheduler="ddim",
steps=50, # Only 50 steps!
rngs=rngs
)
# Quality vs Speed tradeoff:
# - 20 steps: Fast but lower quality
# - 50 steps: Good balance
# - 100 steps: High quality, still 10x faster than DDPM
### Progressive Sampling (Visualize Process)
Visualize the denoising process:
def progressive_sampling(model, n_samples, save_every=100, rngs=None):
"""Generate samples and save intermediate steps."""
trajectory = []
# Start from noise
shape = model._get_sample_shape()
x = jax.random.normal(rngs.sample(), (n_samples, *shape))
# Denoise step by step
for t in range(model.noise_steps - 1, -1, -1):
t_batch = jnp.full((n_samples,), t, dtype=jnp.int32)
# Model prediction
outputs = model(x, t_batch, rngs=rngs)
predicted_noise = outputs["predicted_noise"]
# Denoising step
x = model.p_sample(predicted_noise, x, t_batch, rngs=rngs)
# Save intermediate results
if t % save_every == 0 or t == 0:
trajectory.append(x.copy())
print(f"Step {1000-t}/{1000}")
return jnp.stack(trajectory)
# Generate and visualize
trajectory = progressive_sampling(model, n_samples=4, save_every=100, rngs=rngs)
# trajectory shape: (10, 4, 32, 32, 3) - 10 snapshots of 4 images (t = 900, 800, ..., 100, 0)
## Conditional Sampling with Guidance
### Classifier-Free Guidance
from workshop.generative_models.models.diffusion.guidance import ClassifierFreeGuidance
# Create the guidance helper (the manual loop below makes the same logic explicit)
cfg = ClassifierFreeGuidance(
guidance_scale=7.5, # Higher = stronger conditioning
unconditional_conditioning=None # Null token
)
# Sample with guidance
def sample_with_cfg(model, class_labels, guidance_scale=7.5, rngs=None):
"""Generate samples with classifier-free guidance."""
n_samples = len(class_labels)
shape = model._get_sample_shape()
# Start from noise
x = jax.random.normal(rngs.sample(), (n_samples, *shape))
# Denoise with guidance
for t in range(model.noise_steps - 1, -1, -1):
t_batch = jnp.full((n_samples,), t)
# Get conditional prediction
cond_output = model(x, t_batch, conditioning=class_labels, rngs=rngs)
cond_noise = cond_output["predicted_noise"]
# Get unconditional prediction
uncond_output = model(x, t_batch, conditioning=None, rngs=rngs)
uncond_noise = uncond_output["predicted_noise"]
# Apply guidance
guided_noise = uncond_noise + guidance_scale * (cond_noise - uncond_noise)
# Denoising step with guided noise
x = model.p_sample(guided_noise, x, t_batch, rngs=rngs)
return x
# Generate class-conditional samples
class_labels = jnp.array([0, 1, 2, 3]) # Classes to generate
samples = sample_with_cfg(model, class_labels, guidance_scale=7.5, rngs=rngs)
Guidance Scale Effects:
| Scale | Effect |
|---|---|
| `w = 1.0` | No extra guidance (plain conditional prediction) |
| `w = 2.0` | Mild conditioning |
| `w = 7.5` | Strong conditioning (common default) |
| `w = 15.0` | Very strong, may reduce diversity |
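In the table's notation, the guided noise estimate computed in the loop above is

$$
\tilde{\epsilon} = \epsilon_\theta(x_t) + w\,\bigl(\epsilon_\theta(x_t, y) - \epsilon_\theta(x_t)\bigr),
$$

where $\epsilon_\theta(x_t)$ is the unconditional prediction, $\epsilon_\theta(x_t, y)$ the conditional one, and $w$ is `guidance_scale`; $w = 1$ recovers the plain conditional prediction.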
### Classifier Guidance
from workshop.generative_models.models.diffusion.guidance import ClassifierGuidance
# Assuming you have a trained classifier
classifier = load_pretrained_classifier()
# Create classifier guidance
cg = ClassifierGuidance(
classifier=classifier,
guidance_scale=1.0,
class_label=5 # Generate class 5
)
# Sample with classifier guidance
guided_samples = cg(
model=model,
x=initial_noise,
t=timesteps,
rngs=rngs
)
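For reference, classifier guidance shifts the diffusion model's noise prediction along the classifier's gradient (Dhariwal & Nichol, 2021):

$$
\hat{\epsilon} = \epsilon_\theta(x_t, t) - s\,\sqrt{1-\bar{\alpha}_t}\;\nabla_{x_t}\log p_\phi(y \mid x_t),
$$

where $p_\phi$ is the classifier and $s$ the guidance scale; unlike classifier-free guidance, this requires a classifier trained on noisy inputs.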
### Temperature Sampling
Control sample diversity with temperature:
def sample_with_temperature(model, n_samples, temperature=1.0, rngs=None):
"""Sample with temperature control.
Args:
temperature: Higher = more diverse, Lower = more conservative
"""
shape = model._get_sample_shape()
x = jax.random.normal(rngs.sample(), (n_samples, *shape))
for t in range(model.noise_steps - 1, -1, -1):
t_batch = jnp.full((n_samples,), t)
# Model prediction
outputs = model(x, t_batch, rngs=rngs)
predicted_noise = outputs["predicted_noise"]
# Get mean and variance
out = model.p_mean_variance(predicted_noise, x, t_batch)
# Sample with temperature-scaled variance
if t > 0:
noise = jax.random.normal(rngs.noise(), x.shape)
scaled_std = jnp.exp(0.5 * out["log_variance"]) * temperature
x = out["mean"] + scaled_std * noise
else:
x = out["mean"]
return x
# Different temperatures
conservative = sample_with_temperature(model, 16, temperature=0.8, rngs=rngs)
diverse = sample_with_temperature(model, 16, temperature=1.2, rngs=rngs)
## Common Patterns
### Pattern 1: Custom Noise Schedules
Implement a custom noise schedule:
def cosine_beta_schedule(timesteps, s=0.008):
"""Cosine schedule as proposed in Improved DDPM."""
steps = timesteps + 1
t = jnp.linspace(0, timesteps, steps)
alphas_cumprod = jnp.cos(((t / timesteps) + s) / (1 + s) * jnp.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return jnp.clip(betas, 0.0001, 0.9999)
# Use custom schedule
class DDPMWithCosineSchedule(DDPMModel):
def setup_noise_schedule(self):
"""Override to use cosine schedule."""
params = self.config.parameters or {}
num_timesteps = params.get("noise_steps", 1000)
# Use cosine schedule
self.betas = cosine_beta_schedule(num_timesteps)
# Compute alpha values (same as parent)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas)
# ... rest of alpha computations ...
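The `cosine_beta_schedule` above follows the Improved DDPM formulation:

$$
\bar{\alpha}_t = \frac{f(t)}{f(0)},
\qquad
f(t) = \cos^{2}\!\left(\frac{t/T + s}{1+s}\cdot\frac{\pi}{2}\right),
\qquad
\beta_t = 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},
$$

with the betas clipped so that no single step destroys too much signal near $t = T$.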
### Pattern 2: Multi-Scale Diffusion
Apply diffusion at multiple resolutions:
class MultiScaleDiffusion:
"""Diffusion at multiple resolutions for better quality."""
def __init__(self, scales=[1.0, 0.5, 0.25], rngs=None):
self.models = {}
for scale in scales:
size = int(32 * scale)
config = ModelConfiguration(
name=f"ddpm_{size}x{size}",
model_class="DDPMModel",
input_dim=(size, size, 3),
parameters={"noise_steps": 1000},
)
self.models[scale] = DDPMModel(config, rngs=rngs)
def generate(self, n_samples, rngs=None):
"""Generate using coarse-to-fine approach."""
# Generate at coarsest scale
x = self.models[0.25].generate(n_samples, rngs=rngs)
# Upsample and refine at each scale
for scale in [0.5, 1.0]:
# Upsample
x = jax.image.resize(x, (n_samples, int(32*scale), int(32*scale), 3), "bilinear")
# Refine with diffusion at this scale
# Add noise and denoise for refinement
t = jnp.full((n_samples,), 200) # Partial noise
x_noisy = self.models[scale].q_sample(x, t, rngs=rngs)
# Denoise
for step in range(200, 0, -1):
t = jnp.full((n_samples,), step)
outputs = self.models[scale](x_noisy, t, rngs=rngs)
x_noisy = self.models[scale].p_sample(
outputs["predicted_noise"], x_noisy, t, rngs=rngs
)
x = x_noisy
return x
### Pattern 3: Inpainting
Use diffusion for image inpainting:
def inpaint(model, image, mask, rngs=None):
"""Inpaint masked regions using diffusion.
Args:
image: Original image (1, H, W, C)
mask: Binary mask (1, H, W, 1), 1 = inpaint, 0 = keep
rngs: Random number generators
Returns:
Inpainted image
"""
# Start from noise
x = jax.random.normal(rngs.sample(), image.shape)
# Denoise with guidance from known pixels
for t in range(model.noise_steps - 1, -1, -1):
t_batch = jnp.full((1,), t)
# Predict noise
outputs = model(x, t_batch, rngs=rngs)
predicted_noise = outputs["predicted_noise"]
# Denoising step
x = model.p_sample(predicted_noise, x, t_batch, rngs=rngs)
# Replace known regions with noisy version of original
x_noisy_orig = model.q_sample(image, t_batch, rngs=rngs)
x = mask * x + (1 - mask) * x_noisy_orig
return x
# Usage
image = load_image("photo.png")
mask = create_mask(image, region="center") # Mask out center
inpainted = inpaint(model, image, mask, rngs=rngs)
### Pattern 4: Image Interpolation
Interpolate between images in latent space:
def interpolate_images(model, img1, img2, steps=10, rngs=None):
"""Interpolate between two images using DDIM inversion.
Args:
img1, img2: Images to interpolate (1, H, W, C)
steps: Number of interpolation steps
rngs: Random number generators
Returns:
Interpolated images (steps, H, W, C)
"""
# Encode both images to noise using DDIM
noise1 = model.ddim_reverse(img1, ddim_steps=50, rngs=rngs)
noise2 = model.ddim_reverse(img2, ddim_steps=50, rngs=rngs)
# Interpolate in noise space
alphas = jnp.linspace(0, 1, steps)
interpolated = []
for alpha in alphas:
# Spherical interpolation (better than linear)
noise_interp = slerp(noise1, noise2, alpha)
        # Decode the interpolated noise back to an image. As in the DDIM
        # inversion example, the latent must be the starting point; the
        # x_init argument name is illustrative.
        img = model.ddim_sample(n_samples=1, steps=50, x_init=noise_interp, rngs=rngs)
        interpolated.append(img[0])
return jnp.stack(interpolated)
def slerp(v1, v2, alpha):
"""Spherical linear interpolation."""
v1_norm = v1 / jnp.linalg.norm(v1)
v2_norm = v2 / jnp.linalg.norm(v2)
dot = jnp.sum(v1_norm * v2_norm)
theta = jnp.arccos(jnp.clip(dot, -1.0, 1.0))
if theta < 1e-6:
return (1 - alpha) * v1 + alpha * v2
sin_theta = jnp.sin(theta)
w1 = jnp.sin((1 - alpha) * theta) / sin_theta
w2 = jnp.sin(alpha * theta) / sin_theta
return w1 * v1 + w2 * v2
## Common Issues and Solutions
### Issue 1: Blurry Samples
Symptoms:
- Generated images lack detail
- Samples are smooth but not sharp
Solutions:
# Solution 1: Increase model capacity
config.parameters.update({
"hidden_dims": [128, 256, 512], # Larger network
})
# Solution 2: Use cosine schedule
config.parameters["beta_schedule"] = "cosine"
# Solution 3: Train longer
num_epochs = 500 # More training
# Solution 4: Use more diffusion steps
config.parameters["noise_steps"] = 2000 # More steps
### Issue 2: Training Instability
Symptoms:
- Loss spikes or diverges
- NaN values in gradients
Solutions:
# Solution 1: Lower learning rate
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-5))
# Solution 2: Gradient clipping
optimizer = nnx.Optimizer(
model,
optax.chain(
optax.clip_by_global_norm(1.0), # Clip gradients
optax.adam(1e-4),
)
)
# Solution 3: Warmup learning rate
schedule = optax.warmup_cosine_decay_schedule(
init_value=1e-6,
peak_value=1e-4,
warmup_steps=1000,
decay_steps=100000,
)
optimizer = nnx.Optimizer(model, optax.adam(schedule))
# Solution 4: Mixed precision with loss scaling
# (See mixed precision training section above)
### Issue 3: Slow Sampling
Symptoms:
- Generating samples takes too long
- Inference is impractical for real-time use
Solutions:
# Solution 1: Use DDIM sampling
samples = model.sample(16, scheduler="ddim", steps=50, rngs=rngs) # 20x faster
# Solution 2: Use fewer sampling steps
samples = model.sample(16, scheduler="ddim", steps=20, rngs=rngs) # Even faster
# Solution 3: Use Latent Diffusion
ldm = LDMModel(config, rngs=rngs) # Operates in compressed space
# Solution 4: Distillation (train student model)
# Train a model to match DDPM in fewer steps
# (Advanced technique, requires separate training)
### Issue 4: Mode Collapse (Repetitive Samples)
Symptoms:
- Generated samples look too similar
- Lack of diversity
Solutions:
# Solution 1: Increase temperature
samples = sample_with_temperature(model, 16, temperature=1.2, rngs=rngs)
# Solution 2: Decrease guidance scale
samples = model.generate(16, guidance_scale=2.0, rngs=rngs) # Lower than 7.5
# Solution 3: More training data
# Ensure diverse training set
# Solution 4: Data augmentation
# Apply augmentations during training
### Issue 5: Out of Memory
Symptoms:
- GPU/TPU runs out of memory during training or sampling
Solutions:
# Solution 1: Reduce batch size
batch_size = 32 # Instead of 128
# Solution 2: Use gradient accumulation (sum grads over several micro-batches, then update once)
accumulated_grads = None
for i in range(accumulation_steps):
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    accumulated_grads = grads if accumulated_grads is None else jax.tree_util.tree_map(
        lambda a, b: a + b, accumulated_grads, grads
    )
accumulated_grads = jax.tree_util.tree_map(lambda g: g / accumulation_steps, accumulated_grads)
optimizer.update(accumulated_grads)
# Solution 3: Use Latent Diffusion
# Operate in compressed latent space (8x less memory)
# Solution 4: Enable mixed precision
config.parameters["use_fp16"] = True
## Best Practices
### Do's ✅
- Use EMA for sampling: Exponential moving average improves quality
- Start with DDPM: Master the basics before advanced techniques
- Try DDIM for speed: 10-20x faster with minimal quality loss
- Use cosine schedule for high-res: Better than linear for large images
- Implement proper data preprocessing: Scale inputs to the [-1, 1] range (see the snippet after this list)
- Monitor sample quality: Generate samples during training
- Use classifier-free guidance: Better than classifier guidance usually
- Save checkpoints frequently: Long training requires safety nets
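A minimal sketch of the preprocessing mentioned above, assuming uint8 images in [0, 255]:

```python
import jax.numpy as jnp

def preprocess(images_uint8):
    """Scale images from [0, 255] to the [-1, 1] range the model expects."""
    return images_uint8.astype(jnp.float32) / 127.5 - 1.0

def postprocess(samples):
    """Map generated samples back to [0, 1] for visualization, clipping outliers."""
    return jnp.clip((samples + 1.0) / 2.0, 0.0, 1.0)
```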
### Don'ts ❌
- Don't skip EMA: Samples will be lower quality
- Don't use too few steps: DDIM needs at least 20-50 steps
- Don't forget to clip outputs: Keeps samples in valid range
- Don't train without augmentation: Especially for small datasets
- Don't use batch size 1: Larger batches stabilize training
- Don't ignore timestep sampling: Uniform random timesteps are a solid default
- Don't use same RNG for everything: Separate RNGs for different operations
- Don't expect instant results: Diffusion training takes time
### Hyperparameter Guidelines
| Parameter | Typical Range | Notes |
|---|---|---|
| Learning Rate | 1e-5 to 1e-4 | Lower for large models |
| Batch Size | 64-512 | Larger is better (if memory allows) |
| Noise Steps | 1000-2000 | 1000 is standard |
| DDIM Steps | 20-100 | 50 is good balance |
| EMA Decay | 0.999-0.9999 | Higher for slower updates |
| Guidance Scale | 1.0-15.0 | 7.5 is common default |
| Beta Start | 1e-5 to 1e-4 | 1e-4 is standard |
| Beta End | 0.02-0.05 | 0.02 is standard |
## Summary
This guide covered practical usage of diffusion models:
Key Takeaways:
- DDPM: Foundation model, excellent quality, slow sampling
- DDIM: Fast sampling (50 steps), deterministic, enables editing
- Score-Based: Continuous-time formulation, flexible schedules
- Latent Diffusion: Efficient high-resolution generation
- DiT: Transformer backbone, better scalability
- Guidance: Classifier-free guidance for conditional generation
- Training: Use EMA, proper preprocessing, and patience
- Sampling: DDIM for speed, temperature for diversity
Quick Reference:
# Standard training
model = DDPMModel(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(1e-4))
# ... train ...
# Fast inference
samples = model.sample(16, scheduler="ddim", steps=50, rngs=rngs)
# Conditional generation
samples = model.generate(16, guidance_scale=7.5, conditioning=labels, rngs=rngs)
## Next Steps
- Understand the theory behind diffusion models
- Complete API documentation for all classes
- Hands-on tutorial with a complete working example
- Explore distillation, super-resolution, and more