Model Gallery
Workshop provides implementations of state-of-the-art generative models that track recent (2025) research. These include Diffusion Transformers (DiT), SE(3)-equivariant flows for molecular generation, and advanced MCMC sampling through BlackJAX integration.
- 7 model families: VAE, GAN, Diffusion, Flow, EBM, Autoregressive, and Geometric models, with 67+ implementations
- Recent research (2025): architectures including DiT, StyleGAN3, SE(3) molecular flows, and score-based diffusion
- Production ready: hardware-optimized, fully tested, type-safe implementations built on JAX/Flax NNX
- Multi-modal: native support for images, text, audio, proteins, molecules, and 3D geometric data
Overview
All models in Workshop follow a unified interface and are built on JAX/Flax NNX for:
- Hardware acceleration: Automatic GPU/TPU support with XLA optimization
- Type safety: Full type annotations and protocol-based interfaces
- Composability: Mix and match components across model types
- Reproducibility: Deterministic RNG handling and comprehensive testing
- Scalability: Distributed training with data, model, and pipeline parallelism
Model Families
Variational Autoencoders (VAE)
Latent variable models with probabilistic encoding for representation learning and generation.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| VAE | Standard Variational Autoencoder | Gaussian latents, KL regularization | Representation learning, compression |
| β-VAE | Disentangled VAE | Controllable β parameter, beta annealing | Disentangled representations |
| β-VAE with Capacity | β-VAE with capacity control | Gradual capacity increase, controlled disentanglement | Balance reconstruction and disentanglement |
| Conditional VAE | Class-conditional VAE | Label conditioning, controlled generation | Supervised generation |
| VQ-VAE | Vector Quantized VAE | Discrete latent codes, codebook learning | Discrete representations, compression |
Quick Start:
```python
from workshop.generative_models.factories import create_vae
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="vae",
    latent_dim=128,
    input_shape=(32, 32, 3),
    encoder_features=[64, 128, 256],
    decoder_features=[256, 128, 64],
    parameters={"beta": 1.0, "kl_weight": 1.0},
)
model = create_vae(config, rngs=nnx.Rngs(0))
```
Documentation:
- VAE User Guide - Complete guide with examples
- VAE API Reference - Detailed API documentation
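The VAE variants above differ mainly in how they weight the KL term. The sketch below computes the closed-form KL divergence between a diagonal Gaussian posterior N(mu, sigma^2) and a standard normal prior, then combines it with a reconstruction loss in two ways: the β-VAE form (recon + β·KL) and the capacity-controlled form (recon + γ·|KL - C|). It is plain Python for clarity. The function names and the default γ and C values are illustrative assumptions, not Workshop's API.

```python
import math

def gaussian_kl(mu, log_var):
    """KL( N(mu, exp(log_var)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))

def beta_vae_loss(recon_loss, mu, log_var, beta=4.0):
    """beta-VAE objective: reconstruction + beta * KL (beta = 1 gives the standard VAE)."""
    return recon_loss + beta * gaussian_kl(mu, log_var)

def capacity_vae_loss(recon_loss, mu, log_var, gamma=1000.0, capacity=5.0):
    """Capacity-controlled variant: penalise deviation of the KL from a target capacity C."""
    return recon_loss + gamma * abs(gaussian_kl(mu, log_var) - capacity)

# A posterior equal to the prior has zero KL.
print(gaussian_kl([0.0, 0.0], [0.0, 0.0]))  # 0.0
# KL = 0.5 here, so the beta-VAE loss is 10 + 4 * 0.5.
print(beta_vae_loss(10.0, [1.0, 0.0], [0.0, 0.0], beta=4.0))  # 12.0
```

In the capacity variant, C is typically increased gradually during training so the model is allowed to encode more information over time.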
Generative Adversarial Networks (GAN)
Adversarial training for high-quality image generation and image-to-image translation.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| DCGAN | Deep Convolutional GAN | Convolutional architecture, stable training | Image generation baseline |
| WGAN | Wasserstein GAN | Wasserstein distance, critic training | Stable training, mode coverage |
| LSGAN | Least Squares GAN | Least squares loss, improved stability | Image generation with stable training |
| StyleGAN | Style-based GAN | Style mixing, AdaIN layers | High-quality face generation |
| StyleGAN3 | Alias-free StyleGAN | Translation/rotation equivariance | Alias-free high-quality generation |
| CycleGAN | Cycle-consistent GAN | Unpaired translation, cycle loss | Image-to-image translation |
| PatchGAN | Patch-based discriminator | Local image patches, texture detail | Image-to-image tasks, super-resolution |
| Conditional GAN | Class-conditional GAN | Label conditioning | Controlled generation |
Quick Start:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_gan

config = ModelConfiguration(
    model_type="dcgan",
    input_shape=(64, 64, 3),
    latent_dim=100,
    generator_features=[512, 256, 128, 64],
    discriminator_features=[64, 128, 256, 512],
    parameters={"learning_rate_g": 2e-4, "learning_rate_d": 2e-4},
)
model = create_gan(config, rngs=nnx.Rngs(0))
```
Documentation:
- GAN User Guide - Complete guide with training tips
- GAN API Reference - Detailed API documentation
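DCGAN, LSGAN, and WGAN in the table differ chiefly in their adversarial loss. The following sketch writes the standard discriminator (or critic) and generator losses for a single real and fake score: non-saturating cross-entropy on logits for DCGAN, least squares with targets 1 and 0 for LSGAN, and the critic score difference for WGAN. The function names are hypothetical, and Workshop's own loss functions may be organised differently.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def dcgan_losses(real_logit, fake_logit):
    """Cross-entropy GAN losses on raw discriminator logits (non-saturating generator)."""
    # -log sigmoid(r) = softplus(-r);  -log(1 - sigmoid(f)) = softplus(f)
    d_loss = softplus(-real_logit) + softplus(fake_logit)
    g_loss = softplus(-fake_logit)  # generator maximises log D(G(z))
    return d_loss, g_loss

def lsgan_losses(real_score, fake_score):
    """Least-squares GAN losses with targets 1 (real) and 0 (fake)."""
    d_loss = 0.5 * ((real_score - 1.0) ** 2 + fake_score ** 2)
    g_loss = 0.5 * (fake_score - 1.0) ** 2
    return d_loss, g_loss

def wgan_losses(real_critic, fake_critic):
    """Wasserstein losses. The critic must be kept roughly 1-Lipschitz
    (weight clipping or a gradient penalty), which is not shown here."""
    d_loss = fake_critic - real_critic
    g_loss = -fake_critic
    return d_loss, g_loss

# An undecided discriminator (logit 0) gives d_loss = 2*log(2), g_loss = log(2).
print(dcgan_losses(0.0, 0.0))
print(lsgan_losses(1.0, 0.0))  # (0.0, 0.5)
```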
Diffusion Models
State-of-the-art denoising diffusion models for high-quality generation.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| DDPM | Denoising Diffusion Probabilistic Models | Gaussian diffusion, noise prediction | Image generation, baseline |
| DDIM | Denoising Diffusion Implicit Models | Deterministic sampling, faster inference | Fast high-quality generation |
| Score-based | Score-based generative models | Score matching, SDE/ODE solvers | Flexible sampling strategies |
| Latent Diffusion | Latent space diffusion | VAE encoder/decoder, efficient training | High-resolution generation |
| DiT | Diffusion Transformer | Transformer backbone, scalable | Large-scale image generation |
| Stable Diffusion | Text-to-image diffusion | CLIP conditioning, latent diffusion | Text-to-image generation |
Quick Start:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_diffusion

config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(32, 32, 3),
    num_timesteps=1000,
    backbone_type="unet",
    backbone_features=[128, 256, 512],
    parameters={
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "beta_schedule": "linear",
    },
)
model = create_diffusion(config, rngs=nnx.Rngs(0))
```
Documentation:
- Diffusion User Guide - Complete guide with sampling methods
- Diffusion API Reference - Detailed API documentation
- DiT Architecture - Diffusion Transformer details
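To make the Quick Start's `beta_start`, `beta_end`, and `beta_schedule="linear"` parameters concrete: a linear schedule spaces the noise variances beta_t evenly between the two values, and the DDPM forward process has the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, where alpha_bar_t is the running product of (1 - beta_s). The sketch computes alpha_bar_t and noises a scalar. It uses only the standard library and is an illustration, not Workshop's scheduler.

```python
import math
import random

def linear_beta_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
    """Evenly spaced betas from beta_start to beta_end (both inclusive)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s <= t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def q_sample(x0, t, abar, rng):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x0, 1 - abar_t) for a scalar x0."""
    eps = rng.gauss(0.0, 1.0)
    return math.sqrt(abar[t]) * x0 + math.sqrt(1.0 - abar[t]) * eps

betas = linear_beta_schedule()
abar = alpha_bars(betas)
print(abar[0], abar[-1])  # close to 1 at the first step, close to 0 at the last
print(q_sample(1.0, 500, abar, random.Random(0)))  # a partially noised sample
```

Because alpha_bar_T is close to zero, x_T is almost pure Gaussian noise, which is what lets sampling start from N(0, I).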
Normalizing Flows
Invertible transformations with tractable likelihoods for exact density estimation.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| RealNVP | Real-valued Non-Volume Preserving | Affine coupling layers, multi-scale | Density estimation baseline |
| Glow | Generative Flow | Invertible 1x1 convolutions, ActNorm | High-quality image generation |
| MAF | Masked Autoregressive Flow | Masked autoregressive layers (MADE), parallel density evaluation | Flexible density estimation |
| IAF | Inverse Autoregressive Flow | Fast sampling, parallel generation | Variational inference |
| Neural Spline Flow | Spline-based coupling | Smooth transformations, expressive | High-quality density estimation |
| SE(3) Molecular Flow | Equivariant molecular flows | SE(3) symmetry, molecular generation | Drug design, molecular modeling |
| Conditional Flow | Class-conditional flows | Label conditioning | Controlled generation |
Quick Start:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_flow

config = ModelConfiguration(
    model_type="realnvp",
    input_shape=(32, 32, 3),
    num_coupling_layers=8,
    coupling_features=[512, 512],
    parameters={"num_scales": 3, "scale_factor": 2},
)
model = create_flow(config, rngs=nnx.Rngs(0))
```
Documentation:
- Flow User Guide - Complete guide with coupling layers
- Flow API Reference - Detailed API documentation
- SE(3) Molecular Flows - Equivariant flows for molecules
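RealNVP's affine coupling layer is what makes the likelihood exact: half of the dimensions pass through unchanged and parameterize a scale and shift for the other half, so the inverse is closed-form and the log-determinant of the Jacobian is simply the sum of the log-scales. The sketch below replaces the neural network (the `coupling_features` MLP in the Quick Start) with a toy hand-written conditioner, so it illustrates the mechanism only and assumes an even number of dimensions.

```python
import math

def conditioner(x1):
    """Toy stand-in for the coupling network: returns (log_scale, shift) for x2."""
    log_s = [math.tanh(v) for v in x1]  # tanh keeps the scales bounded
    t = [0.5 * v for v in x1]
    return log_s, t

def coupling_forward(x):
    """y1 = x1, y2 = x2 * exp(log_s(x1)) + t(x1). Returns (y, log|det J|).
    Assumes len(x) is even so both halves have the same size."""
    d = len(x) // 2
    x1, x2 = x[:d], x[d:]
    log_s, t = conditioner(x1)
    y2 = [xi * math.exp(s) + ti for xi, s, ti in zip(x2, log_s, t)]
    return x1 + y2, sum(log_s)

def coupling_inverse(y):
    """Exact inverse: x2 = (y2 - t(y1)) * exp(-log_s(y1)), since y1 == x1."""
    d = len(y) // 2
    y1, y2 = y[:d], y[d:]
    log_s, t = conditioner(y1)
    x2 = [(yi - ti) * math.exp(-s) for yi, s, ti in zip(y2, log_s, t)]
    return y1 + x2

x = [0.3, -1.2, 0.7, 2.0]
y, log_det = coupling_forward(x)
print(coupling_inverse(y))  # recovers x up to floating-point error
print(log_det)              # tanh(0.3) + tanh(-1.2)
```

Stacking such layers while alternating which half is transformed lets every dimension be updated, and the total log-determinant is the sum over layers.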
Energy-Based Models (EBM)
Energy function learning with MCMC sampling for compositional generation.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| EBM | Energy-based model | Energy function learning, flexible | Compositional generation |
| EBM with MCMC | EBM with MCMC sampling | Langevin dynamics, HMC, NUTS | High-quality sampling |
| Persistent CD | Persistent Contrastive Divergence | Persistent chains, efficient training | Stable EBM training |
MCMC Samplers (via BlackJAX integration):
- HMC: Hamiltonian Monte Carlo
- NUTS: No-U-Turn Sampler (adaptive HMC)
- MALA: Metropolis-Adjusted Langevin Algorithm
- Langevin Dynamics: First-order gradient-based sampling
Quick Start:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_ebm

config = ModelConfiguration(
    model_type="ebm",
    input_shape=(28, 28, 1),
    energy_features=[512, 256, 128],
    parameters={
        "mcmc_steps": 100,
        "step_size": 0.01,
        "sampler": "langevin",
    },
)
model = create_ebm(config, rngs=nnx.Rngs(0))
```
Documentation:
- EBM Guide - Energy-based model details
- MCMC Sampling - Sampling algorithms
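The Langevin sampler listed above updates a sample by following the negative energy gradient plus Gaussian noise: x <- x - (eps/2) * dE/dx + sqrt(eps) * N(0, 1). The sketch runs unadjusted Langevin dynamics on a 1-D quadratic energy E(x) = (x - mu)^2 / (2 sigma^2), whose target density is N(mu, sigma^2). The step size and chain length are chosen for this toy problem, not taken from the Quick Start, and treating `step_size` as eps is an assumption. It is a standalone illustration, not the BlackJAX-backed sampler; MALA would add a Metropolis accept/reject step to correct the discretization bias.

```python
import math
import random

def grad_energy(x, mu=3.0, sigma=1.0):
    """Gradient of E(x) = (x - mu)^2 / (2 * sigma^2)."""
    return (x - mu) / sigma ** 2

def langevin_chain(x0, num_steps, step_size, rng):
    """Unadjusted Langevin dynamics: x <- x - (eps/2) dE/dx + sqrt(eps) * N(0, 1)."""
    x, samples = x0, []
    for _ in range(num_steps):
        x = x - 0.5 * step_size * grad_energy(x) + math.sqrt(step_size) * rng.gauss(0.0, 1.0)
        samples.append(x)
    return samples

rng = random.Random(0)
samples = langevin_chain(x0=0.0, num_steps=20_000, step_size=0.1, rng=rng)
kept = samples[1_000:]  # discard burn-in while the chain moves towards the mode
print(sum(kept) / len(kept))  # should be close to the target mean 3.0
```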
Autoregressive (AR) Models
Sequential generation with explicit likelihood for ordered data.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| PixelCNN | Autoregressive image model | Masked convolutions, pixel-by-pixel | Image generation with likelihood |
| WaveNet | Autoregressive audio model | Dilated causal convolutions, long receptive field | Audio generation, TTS |
| Transformer | Transformer-based AR | Self-attention, parallel training | Text, structured sequences |
Quick Start:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_autoregressive

config = ModelConfiguration(
    model_type="pixelcnn",
    input_shape=(32, 32, 3),
    num_layers=8,
    hidden_channels=128,
    parameters={"kernel_size": 3, "num_classes": 256},
)
model = create_autoregressive(config, rngs=nnx.Rngs(0))
```
Documentation:
- PixelCNN API - Image autoregressive models
- WaveNet API - Audio autoregressive models
- Transformer API - Transformer-based models
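PixelCNN's masked convolutions enforce the autoregressive ordering: in raster-scan order, each output pixel may only see pixels above it and to its left. The first layer uses a type-A mask, which also hides the centre pixel, and later layers use a type-B mask, which allows it. The sketch builds these spatial masks for the Quick Start's `kernel_size=3`; it ignores the per-colour-channel masking that a full PixelCNN also applies.

```python
def pixelcnn_mask(kernel_size, mask_type):
    """Spatial mask for a masked convolution in raster-scan order.

    1 = weight allowed, 0 = weight zeroed. Type 'A' hides the centre pixel
    (first layer); type 'B' allows it (subsequent layers).
    """
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    c = kernel_size // 2
    mask = [[1] * kernel_size for _ in range(kernel_size)]
    for i in range(kernel_size):
        for j in range(kernel_size):
            below_centre = i > c
            # In the centre row, pixels to the right (and the centre itself for 'A') are future pixels.
            centre_row_future = i == c and (j > c or (j == c and mask_type == "A"))
            if below_centre or centre_row_future:
                mask[i][j] = 0
    return mask

for row in pixelcnn_mask(3, "A"):
    print(row)
# [1, 1, 1]
# [1, 0, 0]
# [0, 0, 0]
```

Multiplying the convolution weights by this mask before each forward pass is what keeps training parallel while sampling must still proceed pixel by pixel.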
Geometric Models
3D structure generation with physical constraints and equivariance.
Available Models:
| Model | Description | Key Features | Use Cases |
|---|---|---|---|
| Point Cloud Generator | 3D point cloud generation | Permutation invariance | 3D object generation |
| Mesh Generator | 3D mesh generation | Vertex/face generation, deformation | 3D modeling |
| Protein Graph | Protein structure generation | Residue graphs, amino acid features | Protein design |
| Protein Point Cloud | Protein backbone generation | Cα coordinates, backbone geometry | Protein structure prediction |
| Voxel Generator | Voxel-based 3D generation | Regular 3D grid | 3D shape generation |
| Graph Generator | Graph-based generation | Node/edge features, message passing | Molecular graphs |
Quick Start:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_geometric

config = ModelConfiguration(
    model_type="protein_point_cloud",
    input_shape=(128, 3),  # 128 residues, 3D coordinates
    parameters={
        "num_residues": 128,
        "hidden_dim": 256,
        "num_layers": 6,
    },
)
model = create_geometric(config, rngs=nnx.Rngs(0))
```
Documentation:
- Protein Models - Protein structure generation
- Point Cloud Models - 3D point cloud generation
- Graph Models - Graph-based generation
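The equivariance mentioned for geometric models is often obtained by building features that do not change under rigid motions. Pairwise distances between Cα coordinates are one such feature: rotating and translating a structure leaves them unchanged. The sketch checks this numerically for a rotation about the z-axis followed by a translation. The coordinates are toy values and the helpers are illustrative, not Workshop's protein featurizer.

```python
import math

def pairwise_distances(coords):
    """Euclidean distance matrix for a list of (x, y, z) points."""
    return [[math.dist(p, q) for q in coords] for p in coords]

def rigid_transform(coords, angle, translation):
    """Rotate about the z-axis by `angle` radians, then translate."""
    c, s = math.cos(angle), math.sin(angle)
    tx, ty, tz = translation
    return [(c * x - s * y + tx, s * x + c * y + ty, z + tz) for x, y, z in coords]

# Three toy "residues" spaced roughly like consecutive C-alpha atoms (about 3.8 Å apart).
ca = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (5.0, 3.6, 0.5)]
moved = rigid_transform(ca, angle=0.7, translation=(1.0, -2.0, 4.0))

d_before, d_after = pairwise_distances(ca), pairwise_distances(moved)
max_diff = max(abs(a - b) for ra, rb in zip(d_before, d_after) for a, b in zip(ra, rb))
print(max_diff)  # effectively zero: distances are invariant to the rigid motion
```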
Model Comparison
Choose the right model for your task:
| Model Type | Sample Quality | Training Stability | Sampling Speed | Exact Likelihood | Best For |
|---|---|---|---|---|---|
| VAE | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ❌ (Lower bound) | Representation learning, fast sampling |
| GAN | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | ❌ | High-quality images, style transfer |
| Diffusion | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ❌ (Lower bound) | State-of-the-art generation, controllability |
| Flow | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ✅ | Density estimation, exact inference |
| EBM | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐ | ❌ (unnormalized) | Compositional generation, flexibility |
| AR | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | ✅ | Sequences, explicit probabilities |
| Geometric | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | Varies | 3D structures, physical constraints |
Common Backbones
Workshop provides reusable backbone architectures used across model types:
U-Net
Widely used in diffusion models and image-to-image tasks:
```python
from flax import nnx

from workshop.generative_models.models.common.unet import UNet

unet = UNet(
    in_channels=3,
    out_channels=3,
    channels=[128, 256, 512, 1024],
    num_res_blocks=2,
    attention_resolutions=[16, 8],
    rngs=nnx.Rngs(0),
)
```
Documentation: U-Net API
Diffusion Transformer (DiT)
Transformer-based backbone for diffusion models:
```python
from flax import nnx

from workshop.generative_models.models.diffusion.dit import DiT

dit = DiT(
    input_size=32,
    patch_size=2,
    in_channels=3,
    hidden_size=768,
    depth=12,
    num_heads=12,
    rngs=nnx.Rngs(0),
)
```
Documentation: DiT API
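DiT treats an image as a sequence of non-overlapping patches, so `input_size` and `patch_size` fix the sequence length the transformer sees: (input_size / patch_size)^2 tokens, each with patch_size^2 * in_channels values before the linear embedding to `hidden_size`. With the values in the DiT example that is (32 / 2)^2 = 256 tokens of 2 * 2 * 3 = 12 values. The sketch patchifies a nested-list image to show the token layout; real implementations do the same with array reshapes, and the function name here is illustrative.

```python
def patchify(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened non-overlapping patches.

    Returns tokens in raster order; each token holds patch_size**2 * C values.
    """
    h, w = len(image), len(image[0])
    if h % patch_size or w % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    tokens = []
    for pi in range(0, h, patch_size):
        for pj in range(0, w, patch_size):
            token = []
            for i in range(pi, pi + patch_size):
                for j in range(pj, pj + patch_size):
                    token.extend(image[i][j])  # append all C channel values of this pixel
            tokens.append(token)
    return tokens

input_size, patch_size, in_channels = 32, 2, 3
image = [[[0.0] * in_channels for _ in range(input_size)] for _ in range(input_size)]
tokens = patchify(image, patch_size)
print(len(tokens), len(tokens[0]))  # 256 12
```

Halving `patch_size` quadruples the token count, which is why smaller patches raise DiT's compute cost sharply.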
Encoders & Decoders
Modular encoder/decoder architectures:
- MLP Encoder/Decoder: Fully-connected networks
- CNN Encoder/Decoder: Convolutional networks for images
- Conditional Encoder/Decoder: Class-conditional variants
- ResNet Encoder/Decoder: Residual connections
Documentation: Encoders | Decoders
Conditioning Methods
Workshop supports multiple conditioning strategies across models:
| Method | Description | Supported Models | Use Cases |
|---|---|---|---|
| Class Conditioning | One-hot encoded labels | VAE, GAN, Diffusion, Flow | Supervised generation |
| Text Conditioning | CLIP embeddings | Diffusion, GAN | Text-to-image |
| Image Conditioning | Concatenation, cross-attention | GAN, Diffusion | Image-to-image, inpainting |
| Embedding Conditioning | Learned embeddings | All models | Flexible conditioning |
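Class conditioning in its simplest form appends a one-hot label vector to the model's input or latent vector so the network can condition on it; learned embeddings and cross-attention replace the one-hot vector with richer representations. A minimal sketch of the concatenation approach follows. The helper names are illustrative, not Workshop functions.

```python
def one_hot(label, num_classes):
    """Return a one-hot list with 1.0 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} out of range for {num_classes} classes")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def condition_by_concat(z, label, num_classes):
    """Concatenate a latent/feature vector with the one-hot class label."""
    return list(z) + one_hot(label, num_classes)

z = [0.12, -0.5, 0.9]
print(condition_by_concat(z, label=2, num_classes=4))
# [0.12, -0.5, 0.9, 0.0, 0.0, 1.0, 0.0]
```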
Documentation: Conditioning Guide
Model Registry
All models are registered in a global registry for easy instantiation:
```python
from workshop.generative_models.models.registry import (
    list_models,
    get_model_class,
    register_model,
)

# List all available models
available = list_models()
print(f"Available models: {len(available)}")

# Get model class by name
vae_class = get_model_class("vae")

# Register a custom model (my_models is your own module)
from my_models import CustomVAE
register_model("custom_vae", CustomVAE)
```
Documentation: Model Registry
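Conceptually, a model registry is a name-to-class mapping with a guard against accidental overwrites. The sketch below shows that idea in plain Python, mirroring the three registry functions used above. It is a hypothetical illustration; Workshop's actual registry may store extra metadata or handle duplicate names differently, and `ToyVAE` is a placeholder class.

```python
_REGISTRY = {}

def register_model(name, cls):
    """Associate `name` with a model class; refuse to silently overwrite."""
    if name in _REGISTRY:
        raise KeyError(f"model '{name}' is already registered")
    _REGISTRY[name] = cls

def get_model_class(name):
    """Look up a registered class by name, listing the options on failure."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise KeyError(f"unknown model '{name}'; available: {sorted(_REGISTRY)}") from None

def list_models():
    """Names of all registered models, sorted for stable output."""
    return sorted(_REGISTRY)

class ToyVAE:  # hypothetical placeholder model class
    pass

register_model("toy_vae", ToyVAE)
print(list_models())                         # ['toy_vae']
print(get_model_class("toy_vae") is ToyVAE)  # True
```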
Factory Functions
Create models easily with factory functions:
```python
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import (
    create_vae,
    create_gan,
    create_diffusion,
    create_flow,
    create_ebm,
    create_autoregressive,
    create_geometric,
    create_model,  # Generic factory
)

# Generic factory: selects the model family from config.model_type
config = ModelConfiguration(
    model_type="ddpm",
    # ... remaining fields as in the Diffusion Quick Start
)
model = create_model(config, rngs=nnx.Rngs(0))
```
Documentation: Factory API
Training
All models follow a unified training interface:
```python
from workshop.generative_models.training import Trainer
from workshop.generative_models.core.configuration import TrainingConfiguration

training_config = TrainingConfiguration(
    batch_size=128,
    num_epochs=100,
    optimizer={"type": "adam", "learning_rate": 1e-3},
    scheduler={"type": "cosine", "warmup_steps": 1000},
)

# `model` comes from any factory above; `train_data` and `val_data` are your datasets.
trainer = Trainer(
    model=model,
    training_config=training_config,
    train_dataset=train_data,
    val_dataset=val_data,
)
trainer.train()
```
Documentation:
- Training Guide - Complete training guide
- Distributed Training - Multi-GPU/TPU training
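The `scheduler={"type": "cosine", "warmup_steps": 1000}` entry in the training example commonly denotes a linear warmup followed by cosine decay. The sketch below computes such a learning rate per step under that assumption. Whether Workshop decays to zero or to a floor, and how it derives the total step count from `num_epochs`, is not specified here, so `total_steps` and `min_lr` are hypothetical parameters.

```python
import math

def warmup_cosine_lr(step, base_lr=1e-3, warmup_steps=1000, total_steps=100_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr at total_steps (held afterwards)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for s in (0, 999, 1000, 50_500, 100_000):
    print(s, warmup_cosine_lr(s))
```

The warmup avoids large, poorly conditioned updates from Adam's early moment estimates, and the cosine tail gives a smooth final annealing phase.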
Evaluation
Evaluate models with modality-specific metrics:
```python
from workshop.benchmarks import EvaluationFramework

framework = EvaluationFramework(
    model=model,
    modality="image",
    metrics=["fid", "inception_score", "lpips"],
)
results = framework.evaluate(test_dataset)
print(results)
```
Documentation: Benchmarks
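FID compares Gaussians fitted to Inception features of real and generated samples: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). In one dimension the trace term becomes sigma_r^2 + sigma_g^2 - 2 sigma_r sigma_g, so the distance reduces to (mu_r - mu_g)^2 + (sigma_r - sigma_g)^2. The sketch computes this 1-D case from two lists of scalar "features". It only illustrates the formula; real FID uses 2048-dimensional Inception-v3 features and a matrix square root.

```python
import math

def mean_std(xs):
    """Population mean and standard deviation (maximum-likelihood Gaussian fit)."""
    m = sum(xs) / len(xs)
    return m, math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

def frechet_distance_1d(real, fake):
    """Fréchet distance between 1-D Gaussians fitted to two samples."""
    mu_r, sd_r = mean_std(real)
    mu_g, sd_g = mean_std(fake)
    return (mu_r - mu_g) ** 2 + (sd_r - sd_g) ** 2

real = [0.0, 1.0, 2.0, 3.0]
fake = [1.0, 2.0, 3.0, 4.0]  # same spread, mean shifted by 1
print(frechet_distance_1d(real, fake))  # 1.0
```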
Examples
Hands-on examples for each model family:
- VAE on MNIST - Basic VAE training
- GAN on CelebA - Face generation
- Diffusion on CIFAR-10 - Image generation
- Flow on 2D data - Density estimation
- Protein generation - Geometric models
Contributing
Add new models to Workshop:
- Implement the model following the protocols in core/protocols.py
- Add it to the appropriate directory (vae/, gan/, diffusion/, etc.)
- Register it in the model registry
- Add comprehensive tests
- Document its API and usage
Documentation: Contributing Guide
API Statistics
Current model coverage:
- Total modules: 67
- Total classes: 135
- Total functions: 482
- Model families: 7
- Conditioning methods: 4
- Sampling methods: 15+
- User Guides: deep dive into each model family with examples and best practices
- API Reference: complete API documentation for all models and components
- Tutorials: step-by-step tutorials for common tasks and workflows
- Examples: working code examples for all model types