# Diffusion Transformer (DiT) Demo
Level: Advanced | Runtime: ~5 minutes (CPU) / ~1-2 minutes (GPU) | Format: Python + Jupyter
## Overview

This advanced example demonstrates the Diffusion Transformer (DiT), which combines Vision Transformers with diffusion models. DiT is a significant advance in diffusion model architectures, replacing the traditional U-Net backbone with transformer blocks for the denoising network.
## What You'll Learn
- DiffusionTransformer backbone architecture
- DiT model sizes and scaling (S, B, L, XL configurations)
- Conditional generation with classifier-free guidance (CFG)
- Patch-based image processing
- Performance benchmarking across model sizes
- Advanced sampling techniques
## Files

- Python Script: `examples/generative_models/diffusion/dit_demo.py`
- Jupyter Notebook: `examples/generative_models/diffusion/dit_demo.ipynb`
## Quick Start

### Run the Python Script

```bash
# Activate environment
source activate.sh

# Run the example
python examples/generative_models/diffusion/dit_demo.py
```

### Run the Jupyter Notebook

```bash
# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/diffusion/dit_demo.ipynb
```
## Key Concepts

### Diffusion Transformer Architecture

DiT replaces the traditional U-Net backbone with a Vision Transformer (a minimal block sketch follows the list):
- Patch Embedding: Images are divided into patches and linearly embedded
- Positional Encoding: Added to maintain spatial information
- Transformer Blocks: Self-attention and feed-forward layers
- Adaptive Layer Normalization: Conditioned on timestep and class labels
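The adaptive layer normalization is the piece that makes a plain Vision Transformer suitable for diffusion: the timestep and class embeddings modulate every block. Below is a minimal sketch of one such block, written against the Flax NNX API used in this example (illustrative only; the workshop's actual `DiffusionTransformer` block may use different names and details):

```python
import jax
import jax.numpy as jnp
from flax import nnx


class DiTBlockSketch(nnx.Module):
    """One transformer block with adaLN-style conditioning (sketch)."""

    def __init__(self, hidden_size: int, num_heads: int, *, rngs: nnx.Rngs):
        self.norm1 = nnx.LayerNorm(hidden_size, use_scale=False, use_bias=False, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads, in_features=hidden_size, decode=False, rngs=rngs
        )
        self.norm2 = nnx.LayerNorm(hidden_size, use_scale=False, use_bias=False, rngs=rngs)
        self.mlp_in = nnx.Linear(hidden_size, 4 * hidden_size, rngs=rngs)
        self.mlp_out = nnx.Linear(4 * hidden_size, hidden_size, rngs=rngs)
        # adaLN: map the conditioning vector to per-block shift/scale/gate values
        # (DiT's adaLN-Zero additionally zero-initializes this projection)
        self.ada_ln = nnx.Linear(hidden_size, 6 * hidden_size, rngs=rngs)

    def __call__(self, x, cond):
        # x: (batch, num_patches, hidden); cond: (batch, hidden) timestep + class embedding
        shift1, scale1, gate1, shift2, scale2, gate2 = jnp.split(
            self.ada_ln(jax.nn.silu(cond)), 6, axis=-1
        )
        # Attention sub-block, modulated by the conditioning
        h = self.norm1(x) * (1 + scale1[:, None]) + shift1[:, None]
        x = x + gate1[:, None] * self.attn(h)
        # MLP sub-block, modulated by the conditioning
        h = self.norm2(x) * (1 + scale2[:, None]) + shift2[:, None]
        x = x + gate2[:, None] * self.mlp_out(jax.nn.gelu(self.mlp_in(h)))
        return x


block = DiTBlockSketch(hidden_size=384, num_heads=6, rngs=nnx.Rngs(0))
out = block(jnp.zeros((2, 64, 384)), jnp.zeros((2, 384)))  # (2, 64, 384)
```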
### Classifier-Free Guidance (CFG)

CFG enables stronger conditional generation by learning conditional and unconditional predictions with a single model and combining them at sampling time:

\[
\tilde{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \emptyset) + s \cdot \left( \epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset) \right)
\]

Where:

- \(\epsilon_\theta(x_t, \cdot)\) is the model's noise prediction for the noisy image \(x_t\)
- \(c\) is the conditioning (e.g., class label)
- \(\emptyset\) is the unconditional (null) case
- \(s\) is the guidance scale (higher = stronger conditioning)
### Model Scaling

DiT comes in several sizes that trade off quality against speed (a rough parameter estimate follows the table):
| Model | Hidden Dim | Depth | Heads | Parameters |
|---|---|---|---|---|
| DiT-S | 384 | 12 | 6 | ~33M |
| DiT-B | 768 | 12 | 12 | ~130M |
| DiT-L | 1024 | 24 | 16 | ~458M |
| DiT-XL | 1152 | 28 | 16 | ~675M |
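The parameter counts above can be sanity-checked with a back-of-the-envelope estimate: each block contributes roughly 18·hidden² parameters (attention ≈ 4·hidden², an MLP with ratio 4 ≈ 8·hidden², adaLN modulation ≈ 6·hidden²), ignoring patch/positional embeddings and the output head:

```python
def estimate_dit_params_millions(hidden_size: int, depth: int) -> float:
    """Rough DiT parameter count (transformer blocks only), in millions."""
    per_block = 18 * hidden_size**2  # attention + MLP (ratio 4) + adaLN modulation
    return depth * per_block / 1e6


for name, (hidden, depth) in {
    "DiT-S": (384, 12),
    "DiT-B": (768, 12),
    "DiT-L": (1024, 24),
    "DiT-XL": (1152, 28),
}.items():
    print(f"{name}: ~{estimate_dit_params_millions(hidden, depth):.0f}M")
```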
### Patch-Based Processing

Images are processed as sequences of patches (see the sketch after this list):
- Divide image into non-overlapping patches (e.g., 16×16)
- Flatten each patch into a vector
- Apply linear projection
- Add positional embeddings
- Process through transformer blocks
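A minimal patchify helper in plain JAX shows the first three steps (illustrative only; the workshop model performs this internally in its patch-embedding layer):

```python
import jax.numpy as jnp


def patchify(images, patch_size):
    """(B, H, W, C) -> (B, num_patches, patch_size * patch_size * C)."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, p, p, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


patches = patchify(jnp.zeros((4, 32, 32, 3)), patch_size=4)
print(patches.shape)  # (4, 64, 48): 64 patches of 4*4*3 = 48 values each
```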
## Code Structure

The example is organized into 7 major sections:
- Import Dependencies: Setting up the environment
- Test DiT Components: Verifying backbone and full model
- Test Different Model Sizes: Comparing S, B, L configurations
- Conditional Generation: Using classifier-free guidance
- Visualization: Displaying generated samples
- Performance Benchmark: Measuring throughput across sizes
- Summary: Key takeaways and next steps
## Example Code

### Creating a DiT Model

```python
import jax
import jax.numpy as jnp
from flax import nnx

from workshop.generative_models.models.backbones import DiffusionTransformer
from workshop.generative_models.models.diffusion import DiTModel

# Initialize RNG
rngs = nnx.Rngs(42)

# Create DiT-S model
dit_model = DiTModel(
    image_size=32,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    hidden_size=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4.0,
    learn_sigma=False,
    rngs=rngs,
)

# Test forward pass
batch_size = 4
images = jax.random.normal(rngs.params(), (batch_size, 32, 32, 3))
timesteps = jnp.array([100, 200, 300, 400])
labels = jnp.array([0, 1, 2, 3])

# Predict noise
noise_pred = dit_model(images, timesteps, labels, rngs=rngs)
print(f"Noise prediction shape: {noise_pred.shape}")  # (4, 32, 32, 3)
```
### Testing Different Model Sizes

```python
# DiT-S (Small) - Fast, good for prototyping
dit_s = DiTModel(
    image_size=32,
    patch_size=4,
    hidden_size=384,
    depth=12,
    num_heads=6,
    rngs=rngs,
)

# DiT-B (Base) - Balanced quality/speed
dit_b = DiTModel(
    image_size=32,
    patch_size=4,
    hidden_size=768,
    depth=12,
    num_heads=12,
    rngs=rngs,
)

# DiT-L (Large) - High quality, slower
dit_l = DiTModel(
    image_size=32,
    patch_size=4,
    hidden_size=1024,
    depth=24,
    num_heads=16,
    rngs=rngs,
)
```
### Classifier-Free Guidance

```python
def generate_with_cfg(model, num_samples, num_classes, guidance_scale=2.0):
    """Generate samples with classifier-free guidance.

    Assumes `num_timesteps`, `rngs`, and a `denoise_step` helper are defined
    earlier in the script (see the full example).
    """
    # Generate class labels (cycle through the classes)
    labels = jnp.arange(num_samples) % num_classes

    # Start from pure Gaussian noise
    x = jax.random.normal(rngs.sample(), (num_samples, 32, 32, 3))

    # Reverse diffusion with CFG
    for t in reversed(range(num_timesteps)):
        t_batch = jnp.full((num_samples,), t)

        # Conditional prediction
        noise_pred_cond = model(x, t_batch, labels, rngs=rngs)

        # Unconditional prediction (null label)
        noise_pred_uncond = model(x, t_batch, None, rngs=rngs)

        # Apply classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_cond - noise_pred_uncond
        )

        # Denoise one step
        x = denoise_step(x, noise_pred, t)

    return x


# Generate samples
samples = generate_with_cfg(
    dit_model,
    num_samples=16,
    num_classes=10,
    guidance_scale=2.0,
)
```
### Performance Benchmarking

```python
import time


def benchmark_model(model, batch_size=32, num_iterations=100):
    """Benchmark model throughput (forward passes per second)."""
    # Dummy inputs matching the model's expected shapes
    images = jax.random.normal(rngs.params(), (batch_size, 32, 32, 3))
    timesteps = jnp.zeros((batch_size,), dtype=jnp.int32)
    labels = jnp.zeros((batch_size,), dtype=jnp.int32)

    # Warmup (triggers compilation)
    for _ in range(10):
        _ = model(images, timesteps, labels, rngs=rngs)

    # Benchmark; block on the result so JAX's async dispatch is not mis-measured
    start = time.time()
    for _ in range(num_iterations):
        out = model(images, timesteps, labels, rngs=rngs)
    jax.block_until_ready(out)
    elapsed = time.time() - start

    samples_per_sec = (batch_size * num_iterations) / elapsed
    return samples_per_sec


# Compare model sizes
results = {
    "DiT-S": benchmark_model(dit_s),
    "DiT-B": benchmark_model(dit_b),
    "DiT-L": benchmark_model(dit_l),
}

for model_name, throughput in results.items():
    print(f"{model_name}: {throughput:.2f} samples/sec")
```
## Features Demonstrated

### DiffusionTransformer Backbone
- Vision Transformer architecture
- Adaptive layer normalization (adaLN)
- Position-wise feed-forward networks
- Multi-head self-attention
### DiT Model Sizes
- Small (S): Fast prototyping and testing
- Base (B): Production-ready performance
- Large (L): High-quality generation
- Configurable depth, hidden size, and heads
### Conditional Generation
- Class-conditional generation
- Classifier-free guidance
- Guidance scale tuning
- Null conditioning for unconditional mode (a label-dropout sketch follows this list)
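A common way to expose the model to the unconditional case during training is to randomly replace class labels with a reserved null index. A minimal sketch (the exact mechanism inside the workshop's `DiTModel` may differ):

```python
import jax
import jax.numpy as jnp


def drop_labels(key, labels, num_classes, drop_prob=0.1):
    """Randomly replace labels with the null index (== num_classes) for CFG training."""
    drop = jax.random.bernoulli(key, drop_prob, labels.shape)
    return jnp.where(drop, num_classes, labels)
```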
### Patch-Based Processing
- Efficient patch embeddings
- Positional encoding strategies
- Sequence-to-image reconstruction (an unpatchify sketch follows this list)
- Variable patch sizes
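Reconstruction is the inverse of patchifying: the sequence of patch vectors is folded back into an image grid. A minimal sketch matching the `patchify` helper shown earlier (illustrative only):

```python
def unpatchify(patches, patch_size, image_size, channels):
    """(B, num_patches, patch_size * patch_size * C) -> (B, H, W, C)."""
    b = patches.shape[0]
    g = image_size // patch_size
    x = patches.reshape(b, g, g, patch_size, patch_size, channels)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, g, p, g, p, C)
    return x.reshape(b, image_size, image_size, channels)
```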
### Performance Analysis
- Throughput benchmarking
- Memory profiling
- Quality vs. speed trade-offs
- Scaling behavior analysis
## Experiments to Try
- Vary patch size: Try 2×2, 4×4, 8×8 patches and observe quality/speed trade-offs
- Modify model size: Create custom configurations between S/B/L
- Tune guidance scale: Experiment with CFG scales from 1.0 to 5.0 (a sweep sketch follows this list)
- Custom conditioning: Add additional conditioning (text, attributes, etc.)
- Training from scratch: Implement full training loop on your dataset
- Distillation: Train a smaller model to match larger model's quality
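For the guidance-scale experiment, a minimal sweep that reuses `generate_with_cfg` from the example above; quality typically peaks at a moderate scale and degrades (over-saturated, low-diversity samples) at very high values:

```python
for scale in (1.0, 1.5, 2.0, 3.0, 5.0):
    samples = generate_with_cfg(
        dit_model, num_samples=8, num_classes=10, guidance_scale=scale
    )
    print(f"guidance_scale={scale}: samples shape {samples.shape}")
```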
## Next Steps
After understanding this example:
- Full Training: Implement training loop with ImageNet or custom data
- Custom Conditioning: Add text or multi-modal conditioning
- Faster Sampling: Explore DDIM, DPM-Solver, or other fast samplers (a single DDIM update is sketched after this list)
- Latent DiT: Apply DiT in latent space (like Stable Diffusion)
- Model Compression: Distillation, pruning, quantization
- Evaluation: FID, Inception Score, and other metrics
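As a pointer for the faster-sampling item, here is a single deterministic DDIM update (η = 0), assuming a standard variance-preserving schedule with cumulative products `alpha_bar` (a sketch, not the workshop's sampler):

```python
import jax.numpy as jnp


def ddim_step(x_t, noise_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update from timestep t to the previous timestep."""
    # Model's estimate of the clean image x0
    x0_pred = (x_t - jnp.sqrt(1.0 - alpha_bar_t) * noise_pred) / jnp.sqrt(alpha_bar_t)
    # Re-noise x0 to the previous (less noisy) level, with no added stochasticity
    return jnp.sqrt(alpha_bar_prev) * x0_pred + jnp.sqrt(1.0 - alpha_bar_prev) * noise_pred
```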
## Troubleshooting

### Out of Memory
- Reduce model size (use DiT-S instead of DiT-L)
- Decrease batch size
- Use smaller images or larger patch size
- Enable gradient checkpointing
### Slow Generation
- Use GPU acceleration
- Reduce number of denoising steps (try 50-100 instead of 1000)
- Use smaller model (DiT-S)
- Implement faster samplers (DDIM)
### Poor Sample Quality
- Increase model size (DiT-B or DiT-L)
- Tune classifier-free guidance scale
- Increase number of denoising steps
- Check training convergence
### Patch Size Issues
Ensure image size is divisible by patch size:
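A minimal check (sketch):

```python
image_size, patch_size = 32, 4

# The transformer operates on an (image_size // patch_size)^2 grid of patches,
# so the image size must be an exact multiple of the patch size.
assert image_size % patch_size == 0, (
    f"image_size={image_size} is not divisible by patch_size={patch_size}"
)
```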
## Additional Resources
- Paper: Scalable Diffusion Models with Transformers
- Paper: Classifier-Free Diffusion Guidance
- Workshop Diffusion Guide: Diffusion Models Guide
- API Reference: DiffusionTransformer API
- Vision Transformer: An Image is Worth 16x16 Words
## Related Examples
- Simple Diffusion - Diffusion basics
- Simple EBM - Energy-based models
- Advanced Diffusion - More diffusion techniques
## Performance Comparison
Expected performance on a modern GPU (A100):
| Model | Samples/sec | Memory (GB) | FID (ImageNet) |
|---|---|---|---|
| DiT-S | ~120 | ~4 | ~9.5 |
| DiT-B | ~50 | ~12 | ~5.3 |
| DiT-L | ~25 | ~24 | ~3.4 |
| DiT-XL | ~15 | ~32 | ~2.3 |
Note: Actual performance depends on hardware, image size, and implementation details.