
Diffusion Models User Guide

This guide covers practical usage of diffusion models in Workshop, from basic DDPM to advanced techniques like latent diffusion and guidance.

Quick Start

Here's a minimal example to get started with diffusion models:

import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.diffusion import DDPMModel

# Initialize RNGs
rngs = nnx.Rngs(0, params=1, noise=2, sample=3)

# Configure the model
config = ModelConfiguration(
    name="my_diffusion",
    model_class="DDPMModel",
    input_dim=(28, 28, 1),  # MNIST dimensions
    parameters={
        "noise_steps": 1000,
        "beta_start": 1e-4,
        "beta_end": 0.02,
    }
)

# Create model
model = DDPMModel(config, rngs=rngs)

# Generate samples (the model is untrained here, so expect noise-like output)
samples = model.generate(n_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}")  # (16, 28, 28, 1)

Creating Diffusion Models

DDPM (Denoising Diffusion Probabilistic Models)

DDPM is the foundational diffusion model: training is stable and sample quality is excellent, but sampling must walk through every noise step, which makes it slow.

from workshop.generative_models.models.diffusion import DDPMModel

# Standard DDPM configuration
config = ModelConfiguration(
    name="ddpm_model",
    model_class="DDPMModel",
    input_dim=(32, 32, 3),
    parameters={
        "noise_steps": 1000,        # Number of diffusion steps
        "beta_start": 1e-4,          # Starting noise level
        "beta_end": 0.02,            # Ending noise level
        "beta_schedule": "linear",   # Noise schedule type
    }
)

# Create model
model = DDPMModel(config, rngs=rngs)

# Forward diffusion (add noise)
x_clean = jnp.ones((4, 32, 32, 3))
t = jnp.array([100, 200, 300, 400])  # Different timesteps
x_noisy, noise = model.forward_diffusion(x_clean, t, rngs=rngs)

print(f"Clean shape: {x_clean.shape}")
print(f"Noisy shape: {x_noisy.shape}")
print(f"Noise shape: {noise.shape}")

Key Parameters:

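| Parameter | Default | Description |
|---|---|---|
| noise_steps | 1000 | Number of diffusion timesteps T |
| beta_start | 1e-4 | Noise variance β at the first timestep |
| beta_end | 0.02 | Noise variance β at the last timestep |
| beta_schedule | "linear" | How β varies between them (linear or cosine) |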
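The three schedule parameters fully determine the forward (noising) process. The sketch below writes that process out in plain NumPy, so it runs without Workshop or JAX. `linear_beta_schedule` and `q_sample_np` are illustrative helpers, not Workshop APIs. The sketch also assumes that `forward_diffusion` computes the standard DDPM closed form x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 - ᾱ_t)·ε.

```python
import numpy as np

def linear_beta_schedule(noise_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances beta_1..beta_T, spaced linearly."""
    return np.linspace(beta_start, beta_end, noise_steps)

def q_sample_np(x0, t, alphas_cumprod, noise):
    """Closed-form forward diffusion: jump from x_0 straight to x_t.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    """
    # Reshape per-example alpha_bar_t from (B,) to (B, 1, 1, 1) to broadcast over pixels
    a_bar = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise

betas = linear_beta_schedule()
alphas_cumprod = np.cumprod(1.0 - betas)  # alpha_bar_t = product of (1 - beta_s) for s <= t

rng = np.random.default_rng(0)
x_clean = np.ones((4, 32, 32, 3))
t = np.array([100, 200, 300, 400])
noise = rng.standard_normal(x_clean.shape)
x_noisy = q_sample_np(x_clean, t, alphas_cumprod, noise)

print(x_noisy.shape)  # (4, 32, 32, 3)
print(alphas_cumprod[[0, 499, 999]])  # fraction of signal variance left: shrinks toward 0
```

With the default schedule, ᾱ_T falls below 1e-4, so x_T is essentially pure Gaussian noise. This is why sampling can start from N(0, I).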

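DDIM (Faster Sampling)

DDIM reuses a network trained with the standard DDPM objective but samples with a non-Markovian update that can skip timesteps. As a result, 50 to 100 steps often come close to the quality of 1000-step DDPM sampling.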

from workshop.generative_models.models.diffusion import DDIMModel

# DDIM configuration
config = ModelConfiguration(
    name="ddim_model",
    model_class="DDIMModel",
    input_dim=(32, 32, 3),
    parameters={
        "noise_steps": 1000,      # Training steps
        "ddim_steps": 50,          # Sampling steps (much fewer!)
        "eta": 0.0,                # 0 = deterministic, 1 = stochastic
        "skip_type": "uniform",    # How to select timesteps
        "beta_start": 1e-4,
        "beta_end": 0.02,
    }
)

# Create DDIM model
model = DDIMModel(config, rngs=rngs)

# Fast sampling with only 50 steps
samples = model.ddim_sample(
    n_samples=16,
    steps=50,      # Much faster than 1000!
    eta=0.0,       # Deterministic
    rngs=rngs
)

print(f"Generated {samples.shape[0]} samples in only 50 steps!")
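To make the speed-up concrete, here is a NumPy sketch of the standard DDIM update (Song et al., 2021) with `uniform` timestep selection. The helper names are illustrative. Workshop's `DDIMModel` may differ in details such as the exact timestep grid.

```python
import numpy as np

def ddim_timesteps(noise_steps=1000, ddim_steps=50):
    """'uniform' selection: every (noise_steps // ddim_steps)-th training timestep, descending."""
    stride = noise_steps // ddim_steps
    return np.arange(0, noise_steps, stride)[::-1]

def ddim_step(x_t, eps_pred, a_bar_t, a_bar_prev, eta=0.0, rng=None):
    """One DDIM update from noise level a_bar_t to the next (less noisy) level a_bar_prev."""
    # Clean sample implied by the noise prediction
    x0_pred = (x_t - np.sqrt(1.0 - a_bar_t) * eps_pred) / np.sqrt(a_bar_t)
    # Std of fresh noise: 0 when eta = 0 (deterministic), DDPM-like when eta = 1
    sigma = eta * np.sqrt((1.0 - a_bar_prev) / (1.0 - a_bar_t)) * np.sqrt(1.0 - a_bar_t / a_bar_prev)
    # Re-use the predicted noise as the "direction pointing to x_t" at the lower level
    dir_xt = np.sqrt(np.maximum(1.0 - a_bar_prev - sigma**2, 0.0)) * eps_pred
    x_prev = np.sqrt(a_bar_prev) * x0_pred + dir_xt
    if eta > 0:
        x_prev = x_prev + sigma * rng.standard_normal(x_t.shape)
    return x_prev

alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
steps = ddim_timesteps(1000, 50)
print(len(steps), steps[:3], steps[-1])  # 50 [980 960 940] 0

# Sanity check with a perfect noise prediction: the step lands on the same clean
# image, re-noised with the same noise at the lower noise level.
rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(2, 8, 8, 3))
eps = rng.standard_normal(x0.shape)
t, t_prev = steps[0], steps[1]
x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1 - alphas_cumprod[t]) * eps
x_prev = ddim_step(x_t, eps, alphas_cumprod[t], alphas_cumprod[t_prev])
expected = np.sqrt(alphas_cumprod[t_prev]) * x0 + np.sqrt(1 - alphas_cumprod[t_prev]) * eps
```

In a full sampling loop, you would:

- visit `steps` in order;
- take `a_bar_prev` from the next entry;
- use `a_bar_prev = 1` on the last step, so the output is the predicted clean image.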

DDIM vs DDPM:

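| Aspect | DDPM | DDIM |
|---|---|---|
| Sampling steps | 1000 | 50-100 |
| Speed | Slow | 10-20x faster (fewer network evaluations) |
| Stochasticity | Stochastic | Deterministic when η = 0 |
| Quality | Excellent | Very good |
| Best for | Maximum sample quality | Fast inference, inversion and editing |

In the DDIM formulation, both samplers use the same trained network. DDIM only changes how sampling steps through the timesteps, so there is no separate DDIM training procedure.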

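DDIM Inversion (Image Editing)

With η = 0, DDIM sampling is deterministic. It can therefore be run in reverse to map a real image to a noise code that decodes back to approximately the same image. Editing that code and decoding again edits the image: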

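# Encode a real image to noise
# (load_image is a placeholder for your own loading code; scale pixels to [-1, 1])
real_image = load_image("path/to/image.png")  # Shape: (1, 32, 32, 3)

# DDIM reverse (image → noise)
noise_code = model.ddim_reverse(
    real_image,
    ddim_steps=50,
    rngs=rngs
)

# Edit the noise code (modification_vector is a placeholder for your edit direction)
edited_noise = noise_code + 0.1 * modification_vector

# DDIM forward (noise → image)
# NOTE: for the edit to take effect, sampling must start from edited_noise rather
# than fresh random noise. The call below does not pass it; check the DDIMModel
# API for how to supply an initial latent.
edited_image = model.ddim_sample(
    n_samples=1,
    steps=50,
    rngs=rngs
)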
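Inversion works because each η = 0 DDIM move between two noise levels is invertible. The NumPy sketch below runs the same move up the timestep grid (image to noise) and back down (noise to image). `eps_model` is a hypothetical stand-in that ignores its input, which makes the round trip exact. With a real network, inversion is only approximate, because each upward step evaluates the model at the less noisy point.

```python
import numpy as np

alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
timesteps = np.arange(0, 1000, 20)  # 50 DDIM steps, ascending

def eps_model(x, t):
    """Hypothetical stand-in for the trained noise predictor (ignores its input)."""
    return 0.1 * np.ones_like(x)

def ddim_move(x, eps, a_from, a_to):
    """Deterministic (eta = 0) DDIM move between two noise levels, in either direction."""
    x0_pred = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_pred + np.sqrt(1.0 - a_to) * eps

def ddim_invert(image):
    """Image -> noise: walk the timestep grid upward."""
    x, a_prev = image, 1.0  # alpha_bar = 1 is the clean image
    for t in timesteps:
        x = ddim_move(x, eps_model(x, t), a_prev, alphas_cumprod[t])
        a_prev = alphas_cumprod[t]
    return x

def ddim_generate(x):
    """Noise -> image: walk the same grid downward."""
    for i in range(len(timesteps) - 1, -1, -1):
        t = timesteps[i]
        a_next = alphas_cumprod[timesteps[i - 1]] if i > 0 else 1.0
        x = ddim_move(x, eps_model(x, t), alphas_cumprod[t], a_next)
    return x

rng = np.random.default_rng(0)
image = rng.uniform(-1.0, 1.0, size=(1, 8, 8, 3))
latent = ddim_invert(image)
reconstruction = ddim_generate(latent)
print(np.max(np.abs(reconstruction - image)))  # floating-point error only
```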

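Score-Based Diffusion Models

Score-based models learn the score function ∇_x log p_t(x): the gradient of the log density of the noised data with respect to the input, not the parameters. They generate samples by numerically solving a continuous-time reverse SDE.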

from workshop.generative_models.models.diffusion import ScoreDiffusionModel

# Score-based configuration
config = ModelConfiguration(
    name="score_model",
    model_class="ScoreDiffusionModel",
    input_dim=(32, 32, 3),
    parameters={
        "sigma_min": 0.01,          # Minimum noise level
        "sigma_max": 1.0,            # Maximum noise level
        "score_scaling": 1.0,        # Score scaling factor
        "noise_steps": 1000,
    }
)

# Create model
model = ScoreDiffusionModel(config=config, rngs=rngs)

# Generate samples using reverse SDE
samples = model.sample(
    num_samples=16,
    num_steps=1000,
    return_trajectory=False,
    rngs=rngs
)

Score-Based Features:

  • Continuous-time formulation
  • Flexible noise schedules
  • Connection to score matching theory
  • Can use various SDE solvers
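Here is a sketch of what `sample` does under the hood, assuming the `sigma_min`/`sigma_max` parameters describe a variance-exploding (VE) SDE. The sampler walks a geometric grid of noise levels downward, alternating a score-driven drift with fresh noise; this is the reverse-diffusion discretization from Song et al. (2021). To keep the example self-contained, the trained network is replaced by the exact score of 1-D Gaussian toy data, -(x - μ)/(s² + σ²). The helper names and settings are illustrative.

```python
import numpy as np

DATA_MEAN, DATA_STD = 2.0, 0.5  # toy 1-D Gaussian "dataset"

def toy_score(x, sigma):
    """Exact score of the data convolved with N(0, sigma^2); stands in for a trained network."""
    return -(x - DATA_MEAN) / (DATA_STD**2 + sigma**2)

def ve_reverse_sde_sample(n_samples, sigma_min=0.01, sigma_max=10.0, num_steps=500, seed=0):
    rng = np.random.default_rng(seed)
    # Geometric noise levels from sigma_max down to sigma_min
    sigmas = np.geomspace(sigma_max, sigma_min, num_steps)
    # Start from the prior; the data spread is negligible next to sigma_max
    x = sigma_max * rng.standard_normal(n_samples)
    for i in range(num_steps - 1):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        step_var = sigma**2 - sigma_next**2  # g^2 * dt for the VE SDE
        x = x + step_var * toy_score(x, sigma)                      # drift toward high density
        x = x + np.sqrt(step_var) * rng.standard_normal(n_samples)  # reverse-SDE noise
    return x

samples = ve_reverse_sde_sample(20_000)
print(samples.mean(), samples.std())  # close to 2.0 and 0.5
```

Because the toy score is exact, the only errors are discretization and Monte Carlo noise. A learned score adds approximation error on top.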

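Latent Diffusion Models (Efficient High-Res)

Latent diffusion first compresses images with an autoencoder, then runs the diffusion process on the compressed latents instead of on pixels. This greatly reduces compute at high resolution.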

from workshop.generative_models.models.diffusion import LDMModel

# Latent diffusion configuration
config = ModelConfiguration(
    name="ldm_model",
    model_class="LDMModel",
    input_dim=(64, 64, 3),  # High resolution input
    parameters={
        "latent_dim": 16,              # Compressed latent dimension
        "encoder_hidden_dims": [64, 128],
        "decoder_hidden_dims": [128, 64],
        "encoder_type": "simple",       # or "vae" for pretrained
        "scale_factor": 0.18215,        # Rescales latents to ~unit variance (Stable Diffusion's VAE value)
        "noise_steps": 1000,
        "beta_start": 1e-4,
        "beta_end": 0.02,
    }
)

# Create latent diffusion model
model = LDMModel(config=config, rngs=rngs)

# Sampling runs the reverse diffusion in latent space (cheaper than pixel space);
# during training, images are encoded to latents before noise is added
samples = model.sample(
    num_samples=16,
    rngs=rngs
)
# Samples are automatically decoded to pixel space
print(f"High-res samples: {samples.shape}")  # (16, 64, 64, 3)

LDM Advantages:

  • Cheaper training and sampling than pixel-space diffusion, because the denoiser processes a much smaller latent tensor
  • Lower memory requirements
  • Enables high-resolution generation
  • Foundation of Stable Diffusion
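A quick size calculation shows where the savings come from. The sketch assumes a Stable-Diffusion-style autoencoder with 8x spatial downsampling and 4 latent channels; Workshop's `latent_dim` and encoder settings may produce a different latent shape. It also illustrates the role of `scale_factor`:

- latents are multiplied by a constant so they have roughly unit variance, matching the scale the noise schedule expects;
- they are divided by the same constant again before decoding.

`estimate_scale_factor` and the random stand-in latents are hypothetical.

```python
import numpy as np

def latent_shape(image_shape, downsample=8, latent_channels=4):
    """Latent shape for an autoencoder that downsamples each spatial side by `downsample`."""
    h, w, _ = image_shape
    return (h // downsample, w // downsample, latent_channels)

image_shape = (512, 512, 3)
z_shape = latent_shape(image_shape)
ratio = np.prod(image_shape) / np.prod(z_shape)
print(z_shape, ratio)  # (64, 64, 4) 48.0 -> 48x fewer values per example

def estimate_scale_factor(latents):
    """Hypothetical helper: 1 / std of encoder outputs, so scaled latents have ~unit variance."""
    return 1.0 / np.std(latents)

rng = np.random.default_rng(0)
fake_latents = 5.5 * rng.standard_normal((16, *z_shape))  # stand-in for encoder outputs
scale = estimate_scale_factor(fake_latents)
z_for_diffusion = fake_latents * scale   # noise is added and removed in this space
z_for_decoder = z_for_diffusion / scale  # undo the scaling before decoding
```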

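Diffusion Transformers (DiT)

DiT replaces the usual U-Net denoiser with a Vision Transformer that processes the image as a sequence of patches. The DiT paper reports that sample quality keeps improving as the transformer is scaled up.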

from workshop.generative_models.models.diffusion import DiTModel

# DiT configuration
config = ModelConfiguration(
    name="dit_model",
    model_class="DiTModel",
    input_dim=(32, 32, 3),
    parameters={
        "img_size": 32,              # Image size
        "patch_size": 4,             # Patch size (32/4 = 8 patches per side)
        "hidden_size": 512,          # Transformer hidden dimension
        "depth": 12,                 # Number of transformer layers
        "num_heads": 8,              # Number of attention heads
        "mlp_ratio": 4.0,            # MLP expansion ratio
        "num_classes": 10,           # For conditional generation
        "dropout_rate": 0.1,
        "learn_sigma": False,        # Whether to also predict the reverse-process variance
        "cfg_scale": 2.0,            # Classifier-free guidance scale
        "noise_steps": 1000,
    }
)

# Create DiT model
model = DiTModel(config, rngs=rngs)

# Generate with class conditioning
class_labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  # One of each class

samples = model.generate(
    n_samples=10,
    y=class_labels,
    cfg_scale=2.0,  # Classifier-free guidance
    num_steps=1000,
    rngs=rngs
)

DiT Architecture:

graph TD
    Input[Image 32×32×3] --> Patch[Patchify<br/>8×8 grid of 4×4 patches]
    Patch --> Embed[Linear Projection]
    Embed --> PE[+ Position Embedding]

    Time[Timestep t] --> TEmb[Time MLP]
    Class[Class y] --> CEmb[Class Embedding]

    PE --> T1[Transformer<br/>Block 1]
    TEmb --> T1
    CEmb --> T1

    T1 --> T2[...]
    T2 --> T12[Transformer<br/>Block 12]

    T12 --> Final[Final Layer Norm]
    Final --> Linear[Linear<br/>Projection]
    Linear --> Reshape[Reshape to Image]
    Reshape --> Output[Predicted Noise]

    style T1 fill:#9C27B0
    style T12 fill:#9C27B0

Training Diffusion Models

Basic Training Loop

import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Create model
model = DDPMModel(config, rngs=rngs)

# Create optimizer
# (Newer Flax releases require nnx.Optimizer(model, tx, wrt=nnx.Param) and
#  optimizer.update(model, grads); adjust both calls to your Flax version)
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))

# Training step
@nnx.jit
def train_step(model, optimizer, batch, rngs):
    """Single training step with the noise-prediction (epsilon) loss."""

    def loss_fn(model):
        # Sample one random timestep per example
        batch_size = batch.shape[0]
        t = jax.random.randint(
            rngs.timestep(),
            (batch_size,),
            0,
            config.parameters["noise_steps"]
        )

        # Add noise to batch
        noise = jax.random.normal(rngs.noise(), batch.shape)
        x_noisy = model.q_sample(batch, t, noise, rngs=rngs)

        # Predict noise
        outputs = model(x_noisy, t, training=True, rngs=rngs)
        predicted_noise = outputs["predicted_noise"]

        # MSE between the true and predicted noise
        loss = jnp.mean((predicted_noise - noise) ** 2)

        return loss

    # Compute loss and gradients
    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Update parameters
    optimizer.update(grads)

    return {"loss": loss}

# Training loop (num_epochs and dataloader are assumed to be defined elsewhere;
# batches should already be scaled to [-1, 1])
step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        metrics = train_step(model, optimizer, batch, rngs)

        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {metrics['loss']:.4f}")
        step += 1

Training with EMA (Exponential Moving Average)

EMA improves sample quality by maintaining a moving average of parameters:

from workshop.generative_models.core.training import EMAModel

# Create model and EMA
model = DDPMModel(config, rngs=rngs)
ema_model = EMAModel(model, decay=0.9999)

# Training step with EMA.
# loss_fn(model, batch, rngs) stands for the epsilon-prediction loss from
# train_step above, moved to module level so that it can be reused here.
@nnx.jit
def train_step_with_ema(model, ema_model, optimizer, batch, rngs):
    # Compute loss and update (same as before)
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, rngs)
    optimizer.update(grads)

    # Update the EMA weights after every optimizer step
    ema_model.update(model)

    return {"loss": loss}

# Use the EMA model (not the raw model) for sampling
samples = ema_model.generate(n_samples=16, rngs=rngs)
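This guide doesn't show `EMAModel`'s internals. The update rule it is named after is sketched below on a plain dict of NumPy arrays; the `ema_update` helper is illustrative.

```python
import numpy as np

def ema_update(ema_params, params, decay):
    """ema <- decay * ema + (1 - decay) * params, applied to every parameter array."""
    return {name: decay * ema_params[name] + (1.0 - decay) * params[name] for name in params}

# Toy setup: the EMA starts at the initial weights (zeros); training then moves the
# weights to ones and keeps them there.
ema = {"w": np.zeros(3), "b": np.zeros(1)}
params = {"w": np.ones(3), "b": np.ones(1)}

decay = 0.999
for step in range(5_000):
    ema = ema_update(ema, params, decay)

# After n updates toward a fixed target, the EMA has closed 1 - decay**n of the gap
print(ema["w"], 1 - decay**5_000)
```

The averaging window is roughly 1/(1 - decay) steps: about 1,000 steps for 0.999 and 10,000 for 0.9999. With decay = 0.9999, the EMA has closed only about 10% of the gap after 1,000 steps (1 - 0.9999^1000 ≈ 0.095). Early EMA samples therefore lag far behind the trained weights, which is why some implementations ramp the decay up over the first steps.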

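Mixed Precision Training

Use mixed precision to speed up training and reduce memory. With float16, scale the loss so that small gradients don't underflow to zero. bfloat16 has the same exponent range as float32 and usually needs no loss scaling: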

# Configure for mixed precision
config = ModelConfiguration(
    name="ddpm_fp16",
    model_class="DDPMModel",
    input_dim=(32, 32, 3),
    parameters={
        "noise_steps": 1000,
        "beta_start": 1e-4,
        "beta_end": 0.02,
        "use_fp16": True,  # Enable mixed precision
    }
)

# Create model with mixed precision
model = DDPMModel(config, rngs=rngs)

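import jax

# Use a fixed (static) loss scale; dynamic schemes lower it after overflows
# and raise it again after a run of stable steps
loss_scale = 2 ** 15

@nnx.jit
def train_step_fp16(model, optimizer, batch, rngs):
    def loss_fn(model):
        # ... compute the epsilon-prediction loss as in train_step ...
        return loss * loss_scale  # Scale loss so small float16 gradients don't underflow

    scaled_loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Unscale gradients before the optimizer sees them
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, grads)

    optimizer.update(grads)

    return {"loss": scaled_loss / loss_scale}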
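Why loss scaling matters in float16: its smallest positive (subnormal) value is about 6e-8, so smaller gradient values flush to zero. This NumPy check models a single gradient value rather than a real backward pass:

```python
import numpy as np

loss_scale = np.float32(2**15)
true_grad = np.float32(1e-8)  # a small but meaningful gradient component

# Without scaling: the value is below float16's range and flushes to zero
unscaled_fp16 = np.float16(true_grad)

# With scaling: the scaled value fits in float16; unscale in float32 afterwards
scaled_fp16 = np.float16(true_grad * loss_scale)
recovered = np.float32(scaled_fp16) / loss_scale

print(unscaled_fp16)  # 0.0
print(recovered)      # approximately 1e-08
```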

Sampling Strategies

DDPM Sampling (High Quality)

Standard DDPM sampling with all 1000 steps:

# Generate with full DDPM sampling
samples = model.generate(
    n_samples=16,
    shape=(32, 32, 3),
    clip_denoised=True,  # Clip to [-1, 1]
    rngs=rngs
)

# This takes all 1000 steps - highest quality but slow
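Under the hood, each of those 1000 steps is one ancestral update, x_{t-1} = (x_t - β_t / sqrt(1 - ᾱ_t) · ε̂) / sqrt(α_t) + σ_t·z. Here σ_t² = β_t, as in the original DDPM paper. The NumPy sketch below runs the full loop on 1-D Gaussian toy data. It replaces the trained network with the exact optimal noise prediction E[ε | x_t], so the result can be checked. `oracle_eps` and `p_sample_np` are illustrative, not Workshop's `p_sample`.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

DATA_MEAN, DATA_STD = 2.0, 0.5  # toy 1-D Gaussian "dataset"

def oracle_eps(x_t, t):
    """Exact E[eps | x_t] for the toy data; stands in for the trained network."""
    a_bar = alphas_cumprod[t]
    var_t = a_bar * DATA_STD**2 + (1.0 - a_bar)  # variance of x_t
    return np.sqrt(1.0 - a_bar) * (x_t - np.sqrt(a_bar) * DATA_MEAN) / var_t

def p_sample_np(x_t, t, eps_pred, rng):
    """One ancestral step x_t -> x_{t-1} with sigma_t^2 = beta_t."""
    mean = (x_t - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added on the final step
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

def ddpm_sample(n_samples, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)  # x_T ~ N(0, I)
    for t in range(T - 1, -1, -1):
        x = p_sample_np(x, t, oracle_eps(x, t), rng)
    return x

samples = ddpm_sample(10_000)
print(samples.mean(), samples.std())  # close to 2.0 and 0.5
```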

DDIM Sampling (Fast)

Use DDIM for 10-20x faster sampling:

# Generate with DDIM (50 steps instead of 1000)
samples = model.sample(
    n_samples_or_shape=16,
    scheduler="ddim",
    steps=50,  # Only 50 steps!
    rngs=rngs
)

# Quality vs Speed tradeoff:
# - 20 steps: Fast but lower quality
# - 50 steps: Good balance
# - 100 steps: High quality, still 10x faster than DDPM

Progressive Sampling (Visualize Process)

Visualize the denoising process:

import jax
import jax.numpy as jnp

def progressive_sampling(model, n_samples, save_every=100, rngs=None):
    """Generate samples and save intermediate steps."""
    trajectory = []

    # Start from pure noise
    shape = model._get_sample_shape()
    x = jax.random.normal(rngs.sample(), (n_samples, *shape))

    # Denoise step by step, from t = T - 1 down to t = 0
    for t in range(model.noise_steps - 1, -1, -1):
        t_batch = jnp.full((n_samples,), t, dtype=jnp.int32)

        # Model prediction
        outputs = model(x, t_batch, rngs=rngs)
        predicted_noise = outputs["predicted_noise"]

        # Denoising step
        x = model.p_sample(predicted_noise, x, t_batch, rngs=rngs)

        # Save intermediate results (JAX arrays are immutable, so no copy is needed;
        # t = 0 always passes the check, so the final sample is included)
        if t % save_every == 0:
            trajectory.append(x)
            print(f"Step {model.noise_steps - t}/{model.noise_steps}")

    return jnp.stack(trajectory)

# Generate and visualize
trajectory = progressive_sampling(model, n_samples=4, save_every=100, rngs=rngs)
# With 1000 steps, snapshots are taken at t = 900, 800, ..., 100, 0:
# trajectory shape: (10, 4, 32, 32, 3) - 10 snapshots of 4 images

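Conditional Sampling with Guidance

Classifier-Free Guidance

Classifier-free guidance needs a model trained with its conditioning randomly dropped (replaced by a null token) for a fraction of training examples, so one network provides both conditional and unconditional predictions. At each sampling step, both predictions are evaluated and the result is pushed away from the unconditional one: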

import jax
import jax.numpy as jnp
from workshop.generative_models.models.diffusion.guidance import ClassifierFreeGuidance

# Create guidance (sample_with_cfg below spells out the same computation by hand)
cfg = ClassifierFreeGuidance(
    guidance_scale=7.5,  # Higher = stronger conditioning
    unconditional_conditioning=None  # Null token
)

# Sample with guidance
def sample_with_cfg(model, class_labels, guidance_scale=7.5, rngs=None):
    """Generate samples with classifier-free guidance."""

    n_samples = len(class_labels)
    shape = model._get_sample_shape()

    # Start from noise
    x = jax.random.normal(rngs.sample(), (n_samples, *shape))

    # Denoise with guidance
    for t in range(model.noise_steps - 1, -1, -1):
        t_batch = jnp.full((n_samples,), t)

        # Get conditional prediction
        cond_output = model(x, t_batch, conditioning=class_labels, rngs=rngs)
        cond_noise = cond_output["predicted_noise"]

        # Get unconditional prediction
        uncond_output = model(x, t_batch, conditioning=None, rngs=rngs)
        uncond_noise = uncond_output["predicted_noise"]

        # Apply guidance
        guided_noise = uncond_noise + guidance_scale * (cond_noise - uncond_noise)

        # Denoising step with guided noise
        x = model.p_sample(guided_noise, x, t_batch, rngs=rngs)

    return x

# Generate class-conditional samples
class_labels = jnp.array([0, 1, 2, 3])  # Classes to generate
samples = sample_with_cfg(model, class_labels, guidance_scale=7.5, rngs=rngs)

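Guidance Scale Effects (with guided = uncond + w · (cond - uncond)):

| Scale | Effect |
|---|---|
| w = 0.0 | Unconditional prediction only |
| w = 1.0 | Plain conditional prediction, no extra guidance |
| w = 2.0 | Mild guidance |
| w = 7.5 | Strong guidance (common default) |
| w = 15.0 | Very strong; may reduce diversity |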
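The guidance rule itself is one line. This NumPy check, with random arrays standing in for the two model outputs, confirms the table's endpoints:

```python
import numpy as np

def cfg_combine(eps_uncond, eps_cond, guidance_scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

rng = np.random.default_rng(0)
eps_uncond = rng.standard_normal((4, 8, 8, 3))  # stand-in for model(x, t, conditioning=None)
eps_cond = rng.standard_normal((4, 8, 8, 3))    # stand-in for model(x, t, conditioning=labels)

print(np.allclose(cfg_combine(eps_uncond, eps_cond, 0.0), eps_uncond))  # True: w = 0 is unconditional
print(np.allclose(cfg_combine(eps_uncond, eps_cond, 1.0), eps_cond))    # True: w = 1 is plain conditional
guided = cfg_combine(eps_uncond, eps_cond, 7.5)  # w > 1 pushes past the conditional prediction
```

Implementations often evaluate the conditional and unconditional branches in a single forward pass. They concatenate the batch with itself and pair one half with the null conditioning, which halves the number of model calls.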

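Classifier Guidance

Classifier guidance instead steers sampling with the gradient of a separate classifier trained on noisy images x_t: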

from workshop.generative_models.models.diffusion.guidance import ClassifierGuidance

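# Assuming you have a classifier trained on noisy images x_t
# (load_pretrained_classifier is a placeholder for your own loading code)
classifier = load_pretrained_classifier()

# Create classifier guidance
cg = ClassifierGuidance(
    classifier=classifier,
    guidance_scale=1.0,
    class_label=5  # Generate class 5
)

# Sample with classifier guidance
# (initial_noise and timesteps are placeholders for the current x_t and timestep batch)
guided_samples = cg(
    model=model,
    x=initial_noise,
    t=timesteps,
    rngs=rngs
)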
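Following Dhariwal & Nichol (2021), classifier guidance shifts the noise prediction by the classifier gradient: ε̂ = ε_θ(x_t) - s · sqrt(1 - ᾱ_t) · ∇_x log p_φ(y | x_t). In the sketch below, a hypothetical linear softmax classifier stands in for the pretrained one, so the gradient has a closed form. With a neural classifier you would get the gradient from `jax.grad`. This sketch does not describe how `ClassifierGuidance` is implemented.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def classifier_grad(x, W, y):
    """grad_x log p(y | x) for a toy linear softmax classifier p(y | x) = softmax(x @ W)."""
    probs = np.exp(log_softmax(x @ W))  # (B, K)
    # d/dx [x . W[:, y] - logsumexp(x @ W)] = W[:, y] - sum_k p_k W[:, k]
    return W[:, y] - probs @ W.T        # (B, D)

def classifier_guided_eps(eps_pred, x_t, a_bar_t, W, y, guidance_scale):
    """eps_hat = eps - guidance_scale * sqrt(1 - a_bar_t) * grad_x log p(y | x_t)."""
    return eps_pred - guidance_scale * np.sqrt(1.0 - a_bar_t) * classifier_grad(x_t, W, y)

rng = np.random.default_rng(0)
D, K = 5, 3
W = rng.standard_normal((D, K))  # hypothetical classifier weights
x_t = rng.standard_normal((1, D))
eps = rng.standard_normal((1, D))
guided = classifier_guided_eps(eps, x_t, a_bar_t=0.5, W=W, y=2, guidance_scale=1.0)

# Check the analytic gradient against central finite differences
h = 1e-6
fd_grad = np.array([
    (log_softmax((x_t + h * e) @ W)[0, 2] - log_softmax((x_t - h * e) @ W)[0, 2]) / (2 * h)
    for e in np.eye(D)
])
print(np.allclose(classifier_grad(x_t, W, 2)[0], fd_grad, atol=1e-6))  # True
```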

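Temperature Sampling

Control sample diversity with a temperature that scales the standard deviation of the noise injected at each reverse step. This is a heuristic knob rather than a true temperature on the model's distribution:

import jax
import jax.numpy as jnp

def sample_with_temperature(model, n_samples, temperature=1.0, rngs=None):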
    """Sample with temperature control.

    Args:
        temperature: Higher = more diverse, Lower = more conservative
    """
    shape = model._get_sample_shape()
    x = jax.random.normal(rngs.sample(), (n_samples, *shape))

    for t in range(model.noise_steps - 1, -1, -1):
        t_batch = jnp.full((n_samples,), t)

        # Model prediction
        outputs = model(x, t_batch, rngs=rngs)
        predicted_noise = outputs["predicted_noise"]

        # Get mean and variance
        out = model.p_mean_variance(predicted_noise, x, t_batch)

        # Sample with temperature-scaled variance
        if t > 0:
            noise = jax.random.normal(rngs.noise(), x.shape)
            scaled_std = jnp.exp(0.5 * out["log_variance"]) * temperature
            x = out["mean"] + scaled_std * noise
        else:
            x = out["mean"]

    return x

# Different temperatures
conservative = sample_with_temperature(model, 16, temperature=0.8, rngs=rngs)
diverse = sample_with_temperature(model, 16, temperature=1.2, rngs=rngs)

Common Patterns

Pattern 1: Custom Noise Schedules

Implement a custom noise schedule:

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule as proposed in Improved DDPM (Nichol & Dhariwal, 2021)."""
    steps = timesteps + 1
    t = jnp.linspace(0, timesteps, steps)
    alphas_cumprod = jnp.cos(((t / timesteps) + s) / (1 + s) * jnp.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # The paper caps beta_t at 0.999 to avoid singularities near t = T
    return jnp.clip(betas, 0.0, 0.999)

# Use custom schedule
# (assumes DDPMModel builds its schedule in an overridable setup_noise_schedule() method)
class DDPMWithCosineSchedule(DDPMModel):
    def setup_noise_schedule(self):
        """Override to use cosine schedule."""
        params = self.config.parameters or {}
        num_timesteps = params.get("noise_steps", 1000)

        # Use cosine schedule
        self.betas = cosine_beta_schedule(num_timesteps)

        # Compute alpha values (same as parent)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas)
        # ... rest of alpha computations ...
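To see what the cosine schedule changes, compare how much signal ᾱ_t remains halfway through the forward process. A NumPy sketch (the helper name is illustrative):

```python
import numpy as np

def cosine_alphas_cumprod(T, s=0.008):
    """alpha_bar_t for t = 0..T from the Improved DDPM cosine formula."""
    t = np.linspace(0, T, T + 1)
    f = np.cos(((t / T) + s) / (1 + s) * np.pi * 0.5) ** 2
    return f / f[0]

T = 1000
a_bar_full = cosine_alphas_cumprod(T)
betas_cos = np.clip(1 - a_bar_full[1:] / a_bar_full[:-1], 0.0, 0.999)
a_bar_cos = np.cumprod(1 - betas_cos)                   # alpha_bar after clipping
a_bar_lin = np.cumprod(1 - np.linspace(1e-4, 0.02, T))  # default linear schedule

# Signal variance left halfway through the forward process (index 499 = step 500)
print(f"linear: {a_bar_lin[499]:.3f}, cosine: {a_bar_cos[499]:.3f}")
```

The linear schedule has destroyed most of the signal by the midpoint, so many late timesteps contribute little. The cosine schedule spreads the information loss more evenly, which Improved DDPM found matters most at low resolutions.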

Pattern 2: Multi-Scale Diffusion

Apply diffusion at multiple resolutions:

import jax
import jax.numpy as jnp

class MultiScaleDiffusion:
    """Cascade of DDPMs at several resolutions, refined coarse-to-fine.

    Each scale's model must be trained separately on images of that resolution.
    """

    def __init__(self, scales=(1.0, 0.5, 0.25), base_size=32, rngs=None):
        self.base_size = base_size
        self.models = {}
        for scale in scales:
            size = int(base_size * scale)
            config = ModelConfiguration(
                name=f"ddpm_{size}x{size}",
                model_class="DDPMModel",
                input_dim=(size, size, 3),
                parameters={"noise_steps": 1000},
            )
            self.models[scale] = DDPMModel(config, rngs=rngs)

    def generate(self, n_samples, refine_t=200, rngs=None):
        """Generate using a coarse-to-fine approach."""
        # Generate at the coarsest scale
        x = self.models[0.25].generate(n_samples, rngs=rngs)

        # Upsample and refine at each finer scale
        for scale in [0.5, 1.0]:
            size = int(self.base_size * scale)
            x = jax.image.resize(x, (n_samples, size, size, 3), "bilinear")

            # Partially re-noise the upsampled image to timestep refine_t ...
            t = jnp.full((n_samples,), refine_t)
            x_noisy = self.models[scale].q_sample(x, t, rngs=rngs)

            # ... then denoise from refine_t all the way down to t = 0
            for step in range(refine_t, -1, -1):
                t = jnp.full((n_samples,), step)
                outputs = self.models[scale](x_noisy, t, rngs=rngs)
                x_noisy = self.models[scale].p_sample(
                    outputs["predicted_noise"], x_noisy, t, rngs=rngs
                )

            x = x_noisy

        return x

Pattern 3: Inpainting

Use diffusion for image inpainting:

import jax
import jax.numpy as jnp

def inpaint(model, image, mask, rngs=None):
    """Inpaint masked regions using diffusion.

    Args:
        image: Original image (1, H, W, C), scaled to [-1, 1]
        mask: Binary mask (1, H, W, 1), 1 = inpaint, 0 = keep
        rngs: Random number generators

    Returns:
        Inpainted image
    """
    # Start from noise
    x = jax.random.normal(rngs.sample(), image.shape)

    # Denoise, re-imposing the known pixels after every step
    for t in range(model.noise_steps - 1, -1, -1):
        t_batch = jnp.full((1,), t)

        # Predict noise
        outputs = model(x, t_batch, rngs=rngs)
        predicted_noise = outputs["predicted_noise"]

        # Denoising step: x is now at noise level t - 1
        x = model.p_sample(predicted_noise, x, t_batch, rngs=rngs)

        # Known regions must match that same noise level: noise the original
        # to t - 1, or use it unchanged after the final step
        if t > 0:
            known = model.q_sample(image, jnp.full((1,), t - 1), rngs=rngs)
        else:
            known = image
        x = mask * x + (1 - mask) * known

    return x

# Usage (load_image and create_mask are placeholders for your own helpers)
image = load_image("photo.png")
mask = create_mask(image, region="center")  # Mask out center
inpainted = inpaint(model, image, mask, rngs=rngs)

Pattern 4: Image Interpolation

Interpolate between images in DDIM noise space:

def interpolate_images(model, img1, img2, steps=10, rngs=None):
    """Interpolate between two images using DDIM inversion.

    Args:
        img1, img2: Images to interpolate (1, H, W, C)
        steps: Number of interpolation steps
        rngs: Random number generators

    Returns:
        Interpolated images (steps, H, W, C)
    """
    # Encode both images to noise using DDIM
    noise1 = model.ddim_reverse(img1, ddim_steps=50, rngs=rngs)
    noise2 = model.ddim_reverse(img2, ddim_steps=50, rngs=rngs)

    # Interpolation weights (named to avoid confusion with the schedule's alphas)
    weights = jnp.linspace(0, 1, steps)
    interpolated = []

    for w in weights:
        # Spherical interpolation (keeps a typical noise norm, unlike linear)
        noise_interp = slerp(noise1, noise2, w)

        # Decode back to an image.
        # NOTE: sampling must start from noise_interp; the call below does not pass
        # it, so check the DDIMModel API for how to supply an initial latent.
        img = model.ddim_sample(n_samples=1, steps=50, rngs=rngs)
        interpolated.append(img[0])

    return jnp.stack(interpolated)

def slerp(v1, v2, alpha):
    """Spherical linear interpolation."""
    v1_norm = v1 / jnp.linalg.norm(v1)
    v2_norm = v2 / jnp.linalg.norm(v2)

    dot = jnp.sum(v1_norm * v2_norm)
    theta = jnp.arccos(jnp.clip(dot, -1.0, 1.0))

    if theta < 1e-6:  # Nearly parallel: fall back to linear (Python branch, not jit-compatible)
        return (1 - alpha) * v1 + alpha * v2

    sin_theta = jnp.sin(theta)
    w1 = jnp.sin((1 - alpha) * theta) / sin_theta
    w2 = jnp.sin(alpha * theta) / sin_theta

    return w1 * v1 + w2 * v2
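Why interpolate spherically rather than linearly? Gaussian noise in d dimensions concentrates near a shell of radius about sqrt(d). A linear blend of two independent noise vectors falls well inside that shell, where the model saw little training data. The NumPy check below, a port of `slerp` above for illustration, measures this for a 32×32×3 tensor:

```python
import numpy as np

def slerp(v1, v2, alpha):
    """Spherical linear interpolation (NumPy port of the function above)."""
    v1_norm = v1 / np.linalg.norm(v1)
    v2_norm = v2 / np.linalg.norm(v2)
    theta = np.arccos(np.clip(np.sum(v1_norm * v2_norm), -1.0, 1.0))
    if theta < 1e-6:
        return (1 - alpha) * v1 + alpha * v2
    return (np.sin((1 - alpha) * theta) * v1 + np.sin(alpha * theta) * v2) / np.sin(theta)

rng = np.random.default_rng(0)
d = 32 * 32 * 3
z1, z2 = rng.standard_normal(d), rng.standard_normal(d)

def relative_norm(v):
    return np.linalg.norm(v) / np.sqrt(d)

print(relative_norm(z1))                   # ~1.0: typical norm of Gaussian noise
print(relative_norm(0.5 * z1 + 0.5 * z2))  # ~0.71: linear midpoint is too small
print(relative_norm(slerp(z1, z2, 0.5)))   # ~1.0: slerp stays on the shell
```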

Common Issues and Solutions

Issue 1: Blurry Samples

Symptoms:

  • Generated images lack detail
  • Samples are smooth but not sharp

Solutions:

# Note: config changes take effect only when you rebuild the model and retrain

# Solution 1: Increase model capacity (the exact parameter depends on the backbone)
config.parameters.update({
    "hidden_dims": [128, 256, 512],  # Larger network
})

# Solution 2: Use cosine schedule
config.parameters["beta_schedule"] = "cosine"

# Solution 3: Train longer
num_epochs = 500  # More training

# Solution 4: Use more diffusion steps
config.parameters["noise_steps"] = 2000

Issue 2: Training Instability

Symptoms:

  • Loss spikes or diverges
  • NaN values in gradients

Solutions:

# Solution 1: Lower learning rate
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-5))

# Solution 2: Gradient clipping
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients
        optax.adam(1e-4),
    )
)

# Solution 3: Warmup learning rate
schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-6,
    peak_value=1e-4,
    warmup_steps=1000,
    decay_steps=100000,
)
optimizer = nnx.Optimizer(model, optax.adam(schedule))

# Solution 4: If you train in float16, use loss scaling so small gradients
# don't underflow to zero (see the mixed precision training section above)

Issue 3: Slow Sampling

Symptoms:

  • Generating samples takes too long
  • Inference is impractical for real-time use

Solutions:

# Solution 1: Use DDIM sampling
samples = model.sample(16, scheduler="ddim", steps=50, rngs=rngs)  # 20x faster

# Solution 2: Use fewer sampling steps
samples = model.sample(16, scheduler="ddim", steps=20, rngs=rngs)  # Even faster

# Solution 3: Use Latent Diffusion
ldm = LDMModel(config, rngs=rngs)  # Operates in compressed space

# Solution 4: Distillation (train student model)
# Train a model to match DDPM in fewer steps
# (Advanced technique, requires separate training)

Issue 4: Mode Collapse (Repetitive Samples)

Symptoms:

  • Generated samples look too similar
  • Lack of diversity

Solutions:

# Solution 1: Increase temperature
samples = sample_with_temperature(model, 16, temperature=1.2, rngs=rngs)

# Solution 2: Decrease guidance scale
samples = model.generate(16, guidance_scale=2.0, rngs=rngs)  # Lower than 7.5

# Solution 3: More training data
# Ensure diverse training set

# Solution 4: Data augmentation
# Apply augmentations during training

Issue 5: Out of Memory

Symptoms:

  • GPU/TPU runs out of memory during training or sampling

Solutions:

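# Solution 1: Reduce batch size
batch_size = 32  # Instead of 128

# Solution 2: Use gradient accumulation
# (loss_fn(model, batch, rngs) is your per-batch loss; micro_batches is one
#  large batch split into accumulation_steps equal parts)
accumulated_grads = None
for micro_batch in micro_batches:
    loss, grads = nnx.value_and_grad(loss_fn)(model, micro_batch, rngs)
    if accumulated_grads is None:
        accumulated_grads = grads
    else:
        accumulated_grads = jax.tree_util.tree_map(lambda a, b: a + b, accumulated_grads, grads)

# Average, then apply a single optimizer update
accumulated_grads = jax.tree_util.tree_map(lambda g: g / len(micro_batches), accumulated_grads)
optimizer.update(accumulated_grads)
# (optax.MultiSteps(optax.adam(1e-4), every_k_schedule=4) can do this accumulation for you)

# Solution 3: Use Latent Diffusion
# Operate in a compressed latent space (activation memory scales with the much smaller latent tensor)

# Solution 4: Enable mixed precision
config.parameters["use_fp16"] = True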
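Gradient accumulation is exact when the loss is a per-example mean and the micro-batches are equal-sized: averaging their gradients reproduces the full-batch gradient. Below is a NumPy check that uses a linear least-squares loss as a hypothetical stand-in for the diffusion loss:

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

rng = np.random.default_rng(0)
x, y = rng.standard_normal((128, 4)), rng.standard_normal(128)
w = rng.standard_normal(4)

full_grad = mse_grad(w, x, y)

accumulation_steps = 4
micro_grads = [
    mse_grad(w, xb, yb)
    for xb, yb in zip(np.split(x, accumulation_steps), np.split(y, accumulation_steps))
]
accumulated = sum(micro_grads) / accumulation_steps

print(np.allclose(full_grad, accumulated))  # True
```

Layers that depend on batch statistics, such as batch normalization, break this equivalence.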

Best Practices

Do's ✅

  1. Use EMA for sampling: Exponential moving average improves quality
  2. Start with DDPM: Master the basics before advanced techniques
  3. Try DDIM for speed: 10-20x faster with minimal quality loss
  4. Try the cosine schedule: Improved DDPM found it outperforms the linear schedule, especially at low resolutions such as 32×32 and 64×64
  5. Implement proper data preprocessing: Scale to [-1, 1] range
  6. Monitor sample quality: Generate samples during training
  7. Prefer classifier-free guidance: It usually works better than classifier guidance and needs no separately trained noisy-image classifier
  8. Save checkpoints frequently: Long training requires safety nets

Don'ts ❌

  1. Don't skip EMA: Samples will be lower quality
  2. Don't use too few steps: DDIM quality usually drops noticeably below about 20 steps
  3. Don't forget to clip outputs: Keeps samples in valid range
  4. Don't train without augmentation: Especially for small datasets
  5. Don't use batch size 1: Larger batches stabilize training
  6. Don't neglect how training timesteps are sampled: Uniform sampling is a solid default
  7. Don't reuse one RNG for everything: Use separate RNG streams for different operations
  8. Don't expect instant results: Diffusion training takes time

Hyperparameter Guidelines

| Parameter | Typical Range | Notes |
|---|---|---|
| Learning rate | 1e-5 to 1e-4 | Lower for large models |
| Batch size | 64-512 | Larger is better (if memory allows) |
| Noise steps | 1000-2000 | 1000 is standard |
| DDIM steps | 20-100 | 50 is a good balance |
| EMA decay | 0.999-0.9999 | Higher = longer averaging window (about 1/(1 - decay) steps) |
| Guidance scale | 1.0-15.0 | 7.5 is a common default |
| Beta start | 1e-5 to 1e-4 | 1e-4 is standard |
| Beta end | 0.02-0.05 | 0.02 is standard |

Summary

This guide covered practical usage of diffusion models:

Key Takeaways:

  1. DDPM: Foundation model, excellent quality, slow sampling
  2. DDIM: Fast sampling (50 steps), deterministic when η = 0, enables editing
  3. Score-Based: Continuous-time formulation, flexible schedules
  4. Latent Diffusion: Efficient high-resolution generation
  5. DiT: Transformer backbone, better scalability
  6. Guidance: Classifier-free guidance for conditional generation
  7. Training: Use EMA, proper preprocessing, and patience
  8. Sampling: DDIM for speed, temperature for diversity

Quick Reference:

# Standard training
model = DDPMModel(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(1e-4))
# ... train ...

# Fast inference
samples = model.sample(16, scheduler="ddim", steps=50, rngs=rngs)

# Conditional generation
samples = model.generate(16, guidance_scale=7.5, conditioning=labels, rngs=rngs)

Next Steps