
Diffusion Transformer (DiT) Demo¤

Level: Advanced | Runtime: ~5 minutes (CPU) / ~1-2 minutes (GPU) | Format: Python + Jupyter

Overview¤

This advanced example demonstrates the Diffusion Transformer (DiT), which combines the strengths of Vision Transformers with diffusion models. DiT is a significant advance in diffusion model architecture, replacing the traditional U-Net backbone with transformer blocks for the denoising network.

What You'll Learn¤

  • DiffusionTransformer backbone architecture
  • DiT model sizes and scaling (S, B, L, XL configurations)
  • Conditional generation with classifier-free guidance (CFG)
  • Patch-based image processing
  • Performance benchmarking across model sizes
  • Advanced sampling techniques

Files¤

Quick Start¤

Run the Python Script¤

# Activate environment
source activate.sh

# Run the example
python examples/generative_models/diffusion/dit_demo.py

Run the Jupyter Notebook¤

# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/diffusion/dit_demo.ipynb

Key Concepts¤

Diffusion Transformer Architecture¤

DiT replaces the traditional U-Net backbone with a Vision Transformer:

  • Patch Embedding: Images are divided into patches and linearly embedded
  • Positional Encoding: Added to maintain spatial information
  • Transformer Blocks: Self-attention and feed-forward layers
  • Adaptive Layer Normalization: Conditioned on timestep and class labels
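
The adaptive layer normalization is what distinguishes a DiT block from a plain ViT block. The sketch below is a minimal, hypothetical adaLN-style block written directly in flax.nnx (it is not the workshop's DiffusionTransformer implementation): the timestep/class embedding is projected to per-channel shift, scale, and gate vectors that modulate the attention and MLP sub-layers.

import jax
import jax.numpy as jnp
from flax import nnx

class AdaLNBlock(nnx.Module):
    """Minimal DiT-style block: adaLN modulation around attention and MLP (illustrative)."""

    def __init__(self, hidden_size: int, num_heads: int, rngs: nnx.Rngs):
        self.norm1 = nnx.LayerNorm(hidden_size, use_scale=False, use_bias=False, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads, in_features=hidden_size, decode=False, rngs=rngs)
        self.norm2 = nnx.LayerNorm(hidden_size, use_scale=False, use_bias=False, rngs=rngs)
        self.mlp_in = nnx.Linear(hidden_size, 4 * hidden_size, rngs=rngs)
        self.mlp_out = nnx.Linear(4 * hidden_size, hidden_size, rngs=rngs)
        # Conditioning embedding -> 6 modulation vectors (shift/scale/gate for each sub-layer)
        self.ada = nnx.Linear(hidden_size, 6 * hidden_size, rngs=rngs)

    def __call__(self, x, cond):
        # x: (batch, num_patches, hidden); cond: (batch, hidden) timestep + class embedding
        s1, sc1, g1, s2, sc2, g2 = jnp.split(self.ada(jax.nn.silu(cond)), 6, axis=-1)
        h = self.norm1(x) * (1 + sc1[:, None]) + s1[:, None]
        x = x + g1[:, None] * self.attn(h)                              # modulated self-attention
        h = self.norm2(x) * (1 + sc2[:, None]) + s2[:, None]
        x = x + g2[:, None] * self.mlp_out(jax.nn.gelu(self.mlp_in(h)))  # modulated MLP
        return x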

Classifier-Free Guidance (CFG)¤

CFG enables stronger conditional generation by training a single model on both the conditional and unconditional objectives (the conditioning is randomly dropped during training) and combining the two predictions at sampling time:

\[\tilde{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \emptyset) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset))\]

Where:

  • \(c\) is the conditioning (e.g., class label)
  • \(\emptyset\) is the unconditional case
  • \(s\) is the guidance scale (higher = stronger conditioning)
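
At sampling time, the guidance combination is a single line once the conditional and unconditional noise predictions are available (a minimal sketch; the predictions themselves come from the model as shown in the example code further below):

import jax.numpy as jnp

def apply_cfg(noise_uncond, noise_cond, guidance_scale):
    # epsilon_uncond + s * (epsilon_cond - epsilon_uncond)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

With guidance_scale = 1.0 this reduces to the purely conditional prediction; values above 1.0 push samples toward the conditioning signal at some cost in diversity.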

Model Scaling¤

DiT comes in different sizes, trading off quality and speed:

Model     Hidden Dim    Depth    Heads    Parameters
DiT-S     384           12       6        ~33M
DiT-B     768           12       12       ~130M
DiT-L     1024          24       16       ~458M
DiT-XL    1152          28       16       ~675M
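
A compact way to keep these presets is a dictionary keyed by size. The constant name below is illustrative, not part of the workshop API; DiTModel and rngs are used exactly as in the "Creating DiT Model" example that follows.

# Hypothetical size presets matching the table above
DIT_CONFIGS = {
    "DiT-S":  dict(hidden_size=384,  depth=12, num_heads=6),
    "DiT-B":  dict(hidden_size=768,  depth=12, num_heads=12),
    "DiT-L":  dict(hidden_size=1024, depth=24, num_heads=16),
    "DiT-XL": dict(hidden_size=1152, depth=28, num_heads=16),
}

# Build a DiT-B model with the shared image settings
dit_b = DiTModel(image_size=32, patch_size=4, in_channels=3,
                 num_classes=10, rngs=rngs, **DIT_CONFIGS["DiT-B"])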

Patch-Based Processing¤

Images are processed as sequences of patches:

  1. Divide image into non-overlapping patches (e.g., 16×16)
  2. Flatten each patch into a vector
  3. Apply linear projection
  4. Add positional embeddings
  5. Process through transformer blocks
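
A minimal sketch of steps 1-4 in plain JAX (purely illustrative; the workshop's patch embedding layer does this internally with learned parameters):

import jax
import jax.numpy as jnp

def patchify(images, patch_size):
    # images: (batch, height, width, channels); height and width divisible by patch_size
    b, h, w, c = images.shape
    ph, pw = h // patch_size, w // patch_size
    x = images.reshape(b, ph, patch_size, pw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)                          # (b, ph, pw, p, p, c)
    return x.reshape(b, ph * pw, patch_size * patch_size * c)  # (b, num_patches, patch_dim)

# 32x32 RGB images with 4x4 patches -> 64 patches of dimension 48 each
images = jax.random.normal(jax.random.PRNGKey(0), (4, 32, 32, 3))
patches = patchify(images, patch_size=4)                       # (4, 64, 48)

# Linear projection + positional embeddings (random arrays here, for illustration only)
hidden_size = 384
w_embed = 0.02 * jax.random.normal(jax.random.PRNGKey(1), (48, hidden_size))
pos_embed = 0.02 * jax.random.normal(jax.random.PRNGKey(2), (1, 64, hidden_size))
tokens = patches @ w_embed + pos_embed                         # (4, 64, 384)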

Code Structure¤

The example walks through 7 major sections:

  1. Import Dependencies: Setting up the environment
  2. Test DiT Components: Verifying backbone and full model
  3. Test Different Model Sizes: Comparing S, B, L configurations
  4. Conditional Generation: Using classifier-free guidance
  5. Visualization: Displaying generated samples
  6. Performance Benchmark: Measuring throughput across sizes
  7. Summary: Key takeaways and next steps

Example Code¤

Creating DiT Model¤

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.backbones import DiffusionTransformer
from workshop.generative_models.models.diffusion import DiTModel

# Initialize RNG
rngs = nnx.Rngs(42)

# Create DiT-S model
dit_model = DiTModel(
    image_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    hidden_size=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4.0,
    learn_sigma=False,
    rngs=rngs
)

# Test forward pass
batch_size = 4
images = jax.random.normal(rngs.params(), (batch_size, 32, 32, 3))
timesteps = jnp.array([100, 200, 300, 400])
labels = jnp.array([0, 1, 2, 3])

# Predict noise
noise_pred = dit_model(images, timesteps, labels, rngs=rngs)
print(f"Noise prediction shape: {noise_pred.shape}")  # (4, 32, 32, 3)

Testing Different Model Sizes¤

# DiT-S (Small) - Fast, good for prototyping
dit_s = DiTModel(
    image_size=32,
    patch_size=4,
    hidden_size=384,
    depth=12,
    num_heads=6,
    rngs=rngs
)

# DiT-B (Base) - Balanced quality/speed
dit_b = DiTModel(
    image_size=32,
    patch_size=4,
    hidden_size=768,
    depth=12,
    num_heads=12,
    rngs=rngs
)

# DiT-L (Large) - High quality, slower
dit_l = DiTModel(
    image_size=32,
    patch_size=4,
    hidden_size=1024,
    depth=24,
    num_heads=16,
    rngs=rngs
)

Classifier-Free Guidance¤

def generate_with_cfg(model, num_samples, num_classes,
                      guidance_scale=2.0, num_timesteps=1000):
    """Generate samples with classifier-free guidance.

    Assumes `rngs` (nnx.Rngs) and a `denoise_step` reverse-diffusion update
    are defined as in the rest of this example.
    """
    # Generate class labels
    labels = jnp.arange(num_samples) % num_classes

    # Start from noise
    x = jax.random.normal(rngs.sample(), (num_samples, 32, 32, 3))

    # Reverse diffusion with CFG
    for t in reversed(range(num_timesteps)):
        t_batch = jnp.full((num_samples,), t)

        # Conditional prediction
        noise_pred_cond = model(x, t_batch, labels, rngs=rngs)

        # Unconditional prediction (with null label)
        noise_pred_uncond = model(x, t_batch, None, rngs=rngs)

        # Apply classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_cond - noise_pred_uncond
        )

        # Denoise step
        x = denoise_step(x, noise_pred, t)

    return x

# Generate samples
samples = generate_with_cfg(
    dit_model,
    num_samples=16,
    num_classes=10,
    guidance_scale=2.0
)

Performance Benchmarking¤

import time

def benchmark_model(model, batch_size=32, num_iterations=100):
    """Benchmark model throughput in samples (forward passes) per second."""

    # Create dummy inputs that match the batch size being benchmarked
    images = jax.random.normal(rngs.params(), (batch_size, 32, 32, 3))
    timesteps = jnp.full((batch_size,), 100)
    labels = jnp.zeros((batch_size,), dtype=jnp.int32)

    # Warmup (triggers JIT compilation)
    for _ in range(10):
        _ = model(images, timesteps, labels, rngs=rngs)

    # Benchmark; block_until_ready accounts for JAX's asynchronous dispatch
    start = time.time()
    for _ in range(num_iterations):
        out = model(images, timesteps, labels, rngs=rngs)
    out.block_until_ready()
    elapsed = time.time() - start

    samples_per_sec = (batch_size * num_iterations) / elapsed
    return samples_per_sec

# Compare model sizes
results = {
    "DiT-S": benchmark_model(dit_s),
    "DiT-B": benchmark_model(dit_b),
    "DiT-L": benchmark_model(dit_l),
}

for model_name, throughput in results.items():
    print(f"{model_name}: {throughput:.2f} samples/sec")

Features Demonstrated¤

DiffusionTransformer Backbone¤

  • Vision Transformer architecture
  • Adaptive layer normalization (adaLN)
  • Position-wise feed-forward networks
  • Multi-head self-attention

DiT Model Sizes¤

  • Small (S): Fast prototyping and testing
  • Base (B): Production-ready performance
  • Large (L): High-quality generation
  • Configurable depth, hidden size, and heads

Conditional Generation¤

  • Class-conditional generation
  • Classifier-free guidance
  • Guidance scale tuning
  • Null conditioning for unconditional mode

Patch-Based Processing¤

  • Efficient patch embeddings
  • Positional encoding strategies
  • Sequence-to-image reconstruction
  • Variable patch sizes

Performance Analysis¤

  • Throughput benchmarking
  • Memory profiling
  • Quality vs. speed trade-offs
  • Scaling behavior analysis

Experiments to Try¤

  1. Vary patch size: Try 2×2, 4×4, 8×8 patches and observe quality/speed trade-offs
  2. Modify model size: Create custom configurations between S/B/L
  3. Tune guidance scale: Experiment with CFG scales from 1.0 to 5.0
  4. Custom conditioning: Add additional conditioning (text, attributes, etc.)
  5. Training from scratch: Implement full training loop on your dataset
  6. Distillation: Train a smaller model to match larger model's quality

Next Steps¤

After understanding this example:

  1. Full Training: Implement training loop with ImageNet or custom data
  2. Custom Conditioning: Add text or multi-modal conditioning
  3. Faster Sampling: Explore DDIM, DPM-Solver, or other fast samplers
  4. Latent DiT: Apply DiT in latent space (like Stable Diffusion)
  5. Model Compression: Distillation, pruning, quantization
  6. Evaluation: FID, Inception Score, and other metrics

Troubleshooting¤

Out of Memory¤

  • Reduce model size (use DiT-S instead of DiT-L)
  • Decrease batch size
  • Use smaller images or larger patch size
  • Enable gradient checkpointing

Slow Generation¤

  • Use GPU acceleration
  • Reduce number of denoising steps (try 50-100 instead of 1000)
  • Use smaller model (DiT-S)
  • Implement faster samplers (DDIM)
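
For the last point, a deterministic DDIM update only needs the cumulative noise schedule. Below is a minimal sketch, assuming a precomputed alphas_cumprod array and an epsilon-prediction model; it is a generic DDIM step, not the workshop's sampler API.

import jax.numpy as jnp

def ddim_step(x_t, noise_pred, t, t_prev, alphas_cumprod):
    # Deterministic DDIM update (eta = 0)
    a_t = alphas_cumprod[t]
    a_prev = alphas_cumprod[t_prev] if t_prev >= 0 else 1.0
    # Predict x_0 from the current sample and the model's noise estimate
    x0_pred = (x_t - jnp.sqrt(1.0 - a_t) * noise_pred) / jnp.sqrt(a_t)
    # Step to the previous (less noisy) timestep along the deterministic trajectory
    return jnp.sqrt(a_prev) * x0_pred + jnp.sqrt(1.0 - a_prev) * noise_pred

Sampling then iterates over a short, strided subset of timesteps (e.g. 50 evenly spaced steps) instead of all 1000.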

Poor Sample Quality¤

  • Increase model size (DiT-B or DiT-L)
  • Tune classifier-free guidance scale
  • Increase number of denoising steps
  • Check training convergence

Patch Size Issues¤

Ensure image size is divisible by patch size:

assert image_size % patch_size == 0, "Image size must be divisible by patch size"
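
The patch size also sets the sequence length the transformer has to process (illustrative):

num_patches = (image_size // patch_size) ** 2   # e.g. (32 // 4) ** 2 == 64 tokens per image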

Additional Resources¤

Performance Comparison¤

Expected performance on a modern GPU (A100):

Model     Samples/sec    Memory (GB)    FID (ImageNet)
DiT-S     ~120           ~4             ~9.5
DiT-B     ~50            ~12            ~5.3
DiT-L     ~25            ~24            ~3.4
DiT-XL    ~15            ~32            ~2.3

Note: Actual performance depends on hardware, image size, and implementation details.