Workshop: Generative Modeling Research Library
A research-focused, modular generative modeling library built on JAX/Flax NNX. It provides implementations of state-of-the-art generative models with multi-modal support and a scientific-computing focus. Workshop aims to reach production-ready status through rigorous testing and continuous development.
Overview
Workshop is a research-focused library for generative modeling, working towards production deployment capabilities. Built on JAX and Flax NNX, it provides modular, composable implementations of various state-of-the-art generative model architectures including VAEs, GANs, Diffusion Models, Flow Models, Energy-Based Models, and Autoregressive Models.
Why Workshop?
- State-of-the-Art Models: Implementations of recent generative models, tracking research as of 2025, including DiT, SE(3) flows, and advanced MCMC sampling
- Research-Focused with Production Goals: Hardware-aware optimization, distributed training, mixed precision, and deployment pipelines, validated through 2150+ tests
- Multi-Modal Support: Native support for images, text, audio, proteins, and multi-modal data with specialized evaluation metrics
- Scalable Architecture: From single GPU to multi-node distributed training with FSDP, tensor parallelism, and pipeline parallelism
Key Features
Generative Models
Variational Autoencoders (VAE)
- VAE - Base variational autoencoder with latent variable modeling
- β-VAE - Disentanglement learning with controlled KL weighting
- β-VAE with Capacity - Capacity control mechanism for improved training
- VQ-VAE - Vector quantized VAE with discrete latent representations
- Conditional VAE - Class-conditional generation
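The "controlled KL weighting" of β-VAE can be sketched in a few lines of JAX. This is a minimal illustration, not Workshop's own loss: the function name `beta_vae_loss` and the default `beta=4.0` are assumptions made for the example, and the reconstruction term is a plain sum of squared errors.

```python
import jax.numpy as jnp


def beta_vae_loss(x, reconstructed, mean, logvar, beta=4.0):
    """Reconstruction error plus a beta-weighted KL term (illustrative only)."""
    # Sum squared error over non-batch axes, then average over the batch.
    axes = tuple(range(1, x.ndim))
    recon = jnp.mean(jnp.sum((x - reconstructed) ** 2, axis=axes))
    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)) per example.
    kl = -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
    kl = jnp.mean(kl)
    return {
        "total_loss": recon + beta * kl,
        "reconstruction_loss": recon,
        "kl_loss": kl,
    }
```

Setting `beta=1.0` recovers the standard VAE objective; larger values trade reconstruction quality for a latent code closer to the prior.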
Generative Adversarial Networks (GAN)
- GAN - Base generative adversarial network
- DCGAN - Deep convolutional architecture for stable training
- WGAN - Wasserstein distance with gradient penalty
- LSGAN - Least squares loss for stable training
- Conditional GAN - Class-conditional generation
- CycleGAN - Unpaired image-to-image translation
- PatchGAN - Patch-based discrimination
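The gradient penalty mentioned for WGAN can be sketched as follows. The `critic` here is a toy linear function standing in for a real discriminator network, and `gradient_penalty` is a hypothetical helper name, not part of Workshop's API.

```python
import jax
import jax.numpy as jnp


def critic(params, x):
    """Toy linear critic, a stand-in for a real discriminator network."""
    return jnp.dot(x.ravel(), params["w"]) + params["b"]


def gradient_penalty(params, real, fake, key, target_norm=1.0):
    """WGAN-GP penalty: mean of (||grad_x D(x_hat)|| - 1)^2 on interpolates."""
    batch = real.shape[0]
    # One interpolation coefficient per example, broadcast over other axes.
    eps = jax.random.uniform(key, (batch,) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # Gradient of the scalar critic output with respect to each input example.
    grad_fn = jax.vmap(jax.grad(critic, argnums=1), in_axes=(None, 0))
    grads = grad_fn(params, x_hat)
    norms = jnp.sqrt(jnp.sum(grads.reshape(batch, -1) ** 2, axis=1) + 1e-12)
    return jnp.mean((norms - target_norm) ** 2)
```

In training, this term is typically added to the critic loss with a weighting coefficient, which is a hyperparameter.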
Diffusion Models
- DDPM - Denoising diffusion probabilistic models
- DDIM - Fast sampling with deterministic generation
- Score-based - Score matching and SDE-based generation
- Latent Diffusion - Efficient diffusion in compressed latent space
- DiT - Diffusion Transformer for scalable generation
- Classifier-Free Guidance - Strong conditioning without separate classifier
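Classifier-free guidance combines a conditional and an unconditional noise prediction from the same network. A minimal sketch of the blending step, assuming both predictions are already available as arrays (the function name is illustrative, not Workshop's API):

```python
import jax.numpy as jnp


def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale):
    """Blend unconditional and conditional noise predictions."""
    # scale = 0 gives the unconditional prediction, 1 the conditional one,
    # and values above 1 extrapolate past the conditional prediction.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


eps_uncond = jnp.zeros((2, 4))
eps_cond = jnp.ones((2, 4))
guided = classifier_free_guidance(eps_uncond, eps_cond, guidance_scale=2.0)
```

Larger scales generally strengthen adherence to the condition at the cost of sample diversity.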
Normalizing Flows
- RealNVP - Real-valued Non-Volume Preserving flow with coupling layers
- Glow - Multi-scale flow with ActNorm and invertible convolutions
- MAF - Masked Autoregressive Flow for fast density estimation
- IAF - Inverse Autoregressive Flow for fast sampling
- Neural Spline - Rational quadratic spline transformations
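The coupling layers used by RealNVP can be illustrated with a small affine coupling transform. This is a sketch under stated assumptions: `scale_and_shift` is a toy conditioner invented for the example (a real flow would use a neural network), and none of these names come from Workshop.

```python
import jax
import jax.numpy as jnp


def scale_and_shift(x_a, params):
    """Toy conditioner producing log-scale s and shift t from x_a."""
    h = jnp.tanh(x_a @ params["w"])
    return h @ params["w_s"], h @ params["w_t"]


def coupling_forward(x, params):
    """y_a = x_a, y_b = x_b * exp(s(x_a)) + t(x_a); returns y and log|det J|."""
    x_a, x_b = jnp.split(x, 2, axis=-1)
    s, t = scale_and_shift(x_a, params)
    y_b = x_b * jnp.exp(s) + t
    return jnp.concatenate([x_a, y_b], axis=-1), jnp.sum(s, axis=-1)


def coupling_inverse(y, params):
    """Exact inverse: x_b = (y_b - t(y_a)) * exp(-s(y_a))."""
    y_a, y_b = jnp.split(y, 2, axis=-1)
    s, t = scale_and_shift(y_a, params)
    return jnp.concatenate([y_a, (y_b - t) * jnp.exp(-s)], axis=-1)


key_w, key_s, key_t, key_x = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w": 0.1 * jax.random.normal(key_w, (2, 8)),
    "w_s": 0.1 * jax.random.normal(key_s, (8, 2)),
    "w_t": 0.1 * jax.random.normal(key_t, (8, 2)),
}
x = jax.random.normal(key_x, (3, 4))
y, log_det = coupling_forward(x, params)
x_rec = coupling_inverse(y, params)  # matches x up to float error
```

Because half of the input passes through unchanged, the Jacobian is triangular and its log-determinant is just the sum of the log-scales, which is what makes the likelihood tractable.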
Energy-Based Models (EBM)
- EBM with MCMC - Energy-based model with MCMC sampling
- Langevin Dynamics - Sampling via Langevin MCMC
- Contrastive Divergence - Persistent contrastive divergence training
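Langevin sampling from an energy function can be sketched with plain JAX. The quadratic `energy` below is a toy stand-in for a learned energy network, and the step size and step count are illustrative values, not Workshop defaults.

```python
import jax
import jax.numpy as jnp


def energy(x):
    """Toy quadratic energy; exp(-E) is proportional to a standard normal."""
    return 0.5 * jnp.sum(x**2)


def langevin_sample(key, x0, step_size=0.01, num_steps=1000):
    """Unadjusted Langevin: x <- x - step * grad E(x) + sqrt(2 * step) * noise."""
    grad_energy = jax.grad(energy)

    def step(x, step_key):
        noise = jax.random.normal(step_key, x.shape)
        x = x - step_size * grad_energy(x) + jnp.sqrt(2.0 * step_size) * noise
        return x, None

    keys = jax.random.split(key, num_steps)
    x_final, _ = jax.lax.scan(step, x0, keys)
    return x_final


# Run 2000 independent 2-D chains, all started far from the mode.
samples = langevin_sample(jax.random.PRNGKey(0), 5.0 * jnp.ones((2000, 2)))
```

Since the energy sums over all entries, each row evolves as an independent chain. For this toy energy the samples should end up approximately standard normal.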
Autoregressive Models
- PixelCNN - Sequential image generation with masked convolutions
- WaveNet - Audio generation with dilated causal convolutions
- Transformer - Transformer-based autoregressive generation
Geometric Models
- Point Cloud Generation - 3D point cloud modeling
- Mesh Generation - Mesh generation with deformation
- Protein Structure - Protein structure generation with physical constraints
- SE(3) Molecular Flows - Equivariant flows for molecular generation
Multi-Modal Support
Image Generation
- Datasets: CIFAR-10/100, ImageNet, FFHQ, CelebA
- Metrics: FID, Inception Score, LPIPS, Precision/Recall
- Preprocessing: Normalization, augmentation, multi-crop
Text Generation
- Datasets: BookCorpus, Wikipedia, custom corpora
- Metrics: Perplexity, BLEU, ROUGE, diversity scores
- Tokenization: BPE, SentencePiece, character-level
Audio Generation
- Datasets: LibriSpeech, VCTK, music datasets
- Metrics: MCD, F0 RMSE, spectral analysis
- Processing: STFT, mel-spectrograms, raw waveform
Protein Modeling
- Structure generation with physical constraints
- Bond length and angle constraints
- Amino acid features and properties
- SE(3) equivariant architectures
Multi-Modal
- Image-text paired data (COCO)
- Cross-modal generation and retrieval
- Unified representations
- Contrastive learning
Training & Optimization
- Model-Specific Trainers: Optimized training loops for VAE, GAN, Diffusion, Flow, EBM
- Distributed Training: Data parallel, model parallel, FSDP for large-scale training
- Advanced Optimizers: Adam, AdamW, Lion, Adafactor with learning rate schedules
- Mixed Precision: BF16/FP16 training, which can give up to roughly 2x speedup on hardware with native low-precision support
- Gradient Accumulation: Train large models with limited memory
- Checkpointing: Save/resume training with Orbax integration
- Callbacks: Logging, early stopping, profiling, visualization
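Gradient accumulation can be illustrated without any trainer machinery. The sketch below splits a batch into micro-batches, averages their gradients, and returns the result. `loss_fn` is a toy least-squares loss and `accumulated_grads` is a hypothetical helper, not Workshop's trainer API.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """Toy least-squares loss standing in for a model loss."""
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


def accumulated_grads(params, batch, num_micro_batches):
    """Average gradients over equally sized micro-batches."""
    # Reshape (B, ...) into (num_micro_batches, B // num_micro_batches, ...).
    micro = jax.tree_util.tree_map(
        lambda a: a.reshape((num_micro_batches, -1) + a.shape[1:]), batch
    )
    grad_fn = jax.grad(loss_fn)

    def body(acc, micro_batch):
        grads = grad_fn(params, micro_batch)
        return jax.tree_util.tree_map(jnp.add, acc, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, micro)
    return jax.tree_util.tree_map(lambda g: g / num_micro_batches, total)
```

For a mean-reduced loss and equal micro-batch sizes, the averaged gradient matches the full-batch gradient while only one micro-batch of activations is held in memory at a time.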
Evaluation & Benchmarking
Image Metrics:
- FID (Fréchet Inception Distance)
- Inception Score
- LPIPS (Perceptual distance)
- Precision and Recall
- SSIM (Structural similarity)
Text Metrics:
- Perplexity
- BLEU (Translation quality)
- ROUGE (Summarization)
- Diversity metrics
- Semantic coherence
Audio Metrics:
- MCD (Mel-Cepstral Distortion)
- F0 RMSE (Pitch error)
- Spectral convergence
- STFT distance
- Perceptual quality
Disentanglement:
- MIG (Mutual Information Gap)
- SAP (Separated Attribute Predictability)
- DCI (Disentanglement, Completeness, Informativeness)
- FactorVAE score
- Benchmark Suites: Standardized evaluation protocols for reproducible research
- Automatic Metric Selection: Choose appropriate metrics based on modality
- Visualization: Sample grids, latent space plots, protein structure rendering
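As one concrete example from the text metrics listed above, perplexity is the exponential of the mean negative log-likelihood per token. A minimal sketch, assuming logits of shape (batch, length, vocab) and a 0/1 padding mask (the function name is illustrative, not Workshop's metric API):

```python
import jax
import jax.numpy as jnp


def perplexity(logits, targets, mask):
    """exp of the mean negative log-likelihood over non-padding tokens."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability assigned to each target token.
    token_ll = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mean_nll = -jnp.sum(token_ll * mask) / jnp.sum(mask)
    return jnp.exp(mean_nll)
```

A model that spreads probability uniformly over a vocabulary of size V has perplexity V, which gives a useful sanity check.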
Inference & Sampling
- Sampling Methods: Ancestral, temperature scaling, top-k, nucleus sampling
- Diffusion Sampling: DDPM, DDIM, ODE solvers, SDE solvers
- MCMC Sampling: HMC, NUTS, MALA, Langevin dynamics with BlackJAX
- Guidance: Classifier-free guidance, classifier guidance
- Fast Inference: Caching, JIT compilation, quantization
- Production Serving: REST API, gRPC, batching, streaming
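Temperature scaling and top-k sampling from the list above can be combined in a few lines. This sketch operates on a single 1-D logit vector; `sample_top_k` is a hypothetical helper name, not Workshop's sampling API.

```python
import jax
import jax.numpy as jnp


def sample_top_k(key, logits, k=5, temperature=1.0):
    """Sample a token id from the k highest-scoring logits."""
    scaled = logits / temperature
    top_values, top_indices = jax.lax.top_k(scaled, k)
    # Categorical sampling over the k survivors, then map back to vocab ids.
    choice = jax.random.categorical(key, top_values)
    return top_indices[choice]
```

Lower temperatures concentrate probability on the highest logits, and `k=1` reduces to greedy decoding.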
Requirements
- Python 3.10+
- JAX 0.4.35+ with GPU/TPU support
- Flax NNX (latest version)
- CUDA 12.0+ (for GPU acceleration)
- 8GB+ RAM (16GB+ recommended)
- NVIDIA GPU with compute capability 7.0+ (optional but recommended)
Quick Start
Installation
See the Installation Guide for detailed setup instructions.
Your First Model: Train a VAE
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.factories import create_vae
from workshop.generative_models.core.configuration import ModelConfiguration
# Create configuration
config = ModelConfiguration(
    model_type="vae",
    latent_dim=128,
    input_shape=(32, 32, 3),
    encoder_features=[64, 128, 256],
    decoder_features=[256, 128, 64],
    parameters={
        "beta": 1.0,  # β-VAE parameter
        "kl_weight": 1.0,
        "reconstruction_loss": "mse",
    },
)
# Initialize model
rngs = nnx.Rngs(0)
model = create_vae(config, rngs=rngs)
# Prepare data (example with random data)
batch_size = 16
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 32, 32, 3))
# Forward pass: encode, reparameterize, decode
outputs = model(x, rngs=rngs)
reconstructed = outputs["reconstructed"]
mean, logvar = outputs["mean"], outputs["logvar"]
z = outputs["z"]
# Compute loss
loss_dict = model.loss_fn(x=x, outputs=outputs)
print(f"Total loss: {loss_dict['total_loss']:.4f}")
print(f"Reconstruction loss: {loss_dict['reconstruction_loss']:.4f}")
print(f"KL loss: {loss_dict['kl_loss']:.4f}")
# Generate new samples from prior
samples = model.sample(num_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}")
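The "reparameterize" step in the forward pass above can be sketched in isolation. This is a generic illustration of the reparameterization trick, not Workshop's internal implementation. The function name is an assumption.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    """z = mean + sigma * eps with eps ~ N(0, I)."""
    # Sampling eps separately keeps z differentiable in mean and logvar.
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


key = jax.random.PRNGKey(0)
mean = jnp.array([[1.0, -2.0]])
logvar = jnp.zeros((1, 2))
z = reparameterize(key, mean, logvar)
# Gradients flow through the sample back to the encoder outputs.
grad_mean = jax.grad(lambda m: reparameterize(key, m, logvar).sum())(mean)
```

Because the randomness lives in `eps`, the gradient of `z` with respect to `mean` is exactly one per coordinate.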
Train a Diffusion Model
# Reuses jax, jnp, ModelConfiguration, rngs, x, and batch_size from the VAE example above
from workshop.generative_models.factories import create_diffusion
# Create DDPM configuration
config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(32, 32, 3),
    num_timesteps=1000,
    backbone_type="unet",
    backbone_features=[128, 256, 512],
    parameters={
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "beta_schedule": "linear",
    },
)
# Create model
model = create_diffusion(config, rngs=rngs)
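# Training: add noise and predict it
timesteps = jax.random.randint(jax.random.PRNGKey(1), (batch_size,), 0, 1000)
noise = jax.random.normal(jax.random.PRNGKey(2), x.shape)
# NOTE: `noise` is not passed to forward_diffusion here, so the loss below is only
# meaningful if the model applied this same noise; check the forward_diffusion API.
noisy_x = model.forward_diffusion(x, timesteps, rngs=rngs)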
# Predict noise
predicted_noise_dict = model(noisy_x, timesteps, rngs=rngs)
predicted_noise = predicted_noise_dict["predicted_noise"]
# Compute loss
loss = jnp.mean((noise - predicted_noise) ** 2)
print(f"Diffusion loss: {loss:.4f}")
# Sampling: generate new images
samples = model.sample(num_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}")
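To make the "add noise and predict it" step concrete, here is a self-contained sketch of the DDPM forward process with a linear beta schedule, using the same schedule values as the configuration above. `linear_beta_schedule` and `q_sample` are illustrative names and are not claimed to match Workshop's `forward_diffusion`.

```python
import jax
import jax.numpy as jnp


def linear_beta_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
    """Linearly spaced noise variances beta_1 ... beta_T."""
    return jnp.linspace(beta_start, beta_end, num_timesteps)


def q_sample(x0, t, noise, alphas_cumprod):
    """x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    # Gather abar_t per example and broadcast over the image axes.
    abar = alphas_cumprod[t].reshape((-1,) + (1,) * (x0.ndim - 1))
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * noise


betas = linear_beta_schedule()
alphas_cumprod = jnp.cumprod(1.0 - betas)

key_x, key_t, key_n = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(key_x, (16, 32, 32, 3))
t = jax.random.randint(key_t, (16,), 0, 1000)
noise = jax.random.normal(key_n, x0.shape)
x_t = q_sample(x0, t, noise, alphas_cumprod)
# A noise-prediction network would be trained to recover `noise` from (x_t, t).
```

The key point is that the regression target is the same `noise` array that produced `x_t`; a loss computed against an unrelated noise draw would not train the model.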
Next Steps
- Learn the Basics: Understand generative modeling concepts and the Workshop architecture
- Complete Tutorial: Build your first generative model from scratch with step-by-step guidance
- Model Guides: Deep dive into VAEs, GANs, Diffusion, Flows, and more
- Advanced Topics: Distributed training, fine-tuning, protein modeling, and deployment strategies
Architecture
System Overview
graph TB
subgraph Input["Data Pipeline"]
A1[Image Data]
A2[Text Data]
A3[Audio Data]
A4[Protein Data]
end
subgraph Config["Configuration System"]
B1[ModelConfiguration]
B2[TrainingConfig]
B3[DataConfig]
end
subgraph Factory["Model Factory"]
C1[create_vae]
C2[create_gan]
C3[create_diffusion]
C4[create_flow]
C5[create_ebm]
end
subgraph Models["Generative Models"]
D1[VAE Models]
D2[GAN Models]
D3[Diffusion Models]
D4[Flow Models]
D5[EBM Models]
end
subgraph Training["Training System"]
E1[VAE Trainer]
E2[GAN Trainer]
E3[Diffusion Trainer]
E4[Flow Trainer]
E5[EBM Trainer]
end
subgraph Eval["Evaluation"]
F1[Metrics]
F2[Benchmarks]
F3[Visualization]
end
subgraph Deploy["Inference & Serving"]
G1[Sampling]
G2[Generation]
G3[API Serving]
end
A1 --> B1
A2 --> B1
A3 --> B1
A4 --> B1
B1 --> C1
B1 --> C2
B1 --> C3
B1 --> C4
B1 --> C5
C1 --> D1
C2 --> D2
C3 --> D3
C4 --> D4
C5 --> D5
D1 --> E1
D2 --> E2
D3 --> E3
D4 --> E4
D5 --> E5
E1 --> F1
E2 --> F1
E3 --> F1
E4 --> F1
E5 --> F1
F1 --> F2
F2 --> F3
D1 --> G1
D2 --> G1
D3 --> G1
D4 --> G1
D5 --> G1
G1 --> G2
G2 --> G3
Component Structure
Workshop Architecture
├── Core Components
│   ├── Protocols & Interfaces → Type-safe abstractions (GenerativeModel, TrainerProtocol)
│   ├── Configuration System → Pydantic-based unified configs with validation
│   ├── Device Manager → Hardware-aware GPU/CPU/TPU handling
│   ├── Loss Functions → Modular composable losses (reconstruction, adversarial, divergence)
│   ├── Sampling Methods → Ancestral, diffusion, MCMC, ODE/SDE solvers
│   └── Evaluation Framework → Metrics, benchmarks, protocols
├── Generative Models
│   ├── VAE Family → VAE, β-VAE, VQ-VAE, Conditional VAE
│   ├── GAN Family → DCGAN, WGAN, StyleGAN, CycleGAN, PatchGAN
│   ├── Diffusion Models → DDPM, DDIM, Score-based, DiT, Latent Diffusion
│   ├── Normalizing Flows → RealNVP, Glow, MAF, IAF, Neural Spline
│   ├── Energy-Based Models → Contrastive divergence, Langevin, MCMC
│   ├── Autoregressive Models → PixelCNN, WaveNet, Transformer
│   └── Geometric Models → Protein, 3D shapes, SE(3) flows
├── Training System
│   ├── Model-Specific Trainers → Optimized for each model type
│   ├── Distributed Training → Data/model parallel, FSDP, sharding
│   ├── Optimization → Optimizers, schedules, gradient accumulation
│   ├── Callbacks → Logging, checkpointing, early stopping, profiling
│   └── Fine-Tuning → LoRA, prefix tuning, RLHF
├── Data Pipeline
│   ├── Multi-Modal Loaders → Image, text, audio, protein, tabular, timeseries
│   ├── Preprocessing → Augmentation, normalization, tokenization
│   ├── Streaming Support → Large-scale dataset handling
│   └── Custom Datasets → Extensible dataset framework
├── Inference & Serving
│   ├── Sampling Strategies → Multiple sampling methods per model type
│   ├── Generation Pipeline → End-to-end inference with batching
│   ├── Optimization → Quantization, pruning, compilation
│   └── Serving Infrastructure → REST/gRPC APIs, monitoring
└── Evaluation & Benchmarking
    ├── Image Metrics → FID, IS, LPIPS, Precision/Recall
    ├── Text Metrics → Perplexity, BLEU, ROUGE, diversity
    ├── Audio Metrics → MCD, spectral analysis, perceptual quality
    ├── Disentanglement → MIG, SAP, DCI, FactorVAE
    └── Benchmark Suites → Standardized protocols across modalities
Model Types Overview
| Model Type | Use Case | Pros | Cons | Best For |
|---|---|---|---|---|
| VAE | Latent representation learning | Stable training, interpretable latents, fast sampling | Lower sample quality | Data compression, latent space exploration, representation learning |
| GAN | High-quality image generation | Sharp samples, fast single-pass sampling | Training instability, mode collapse | Image synthesis, style transfer, super-resolution |
| Diffusion | State-of-the-art generation | Excellent quality, stable training, controllable | Slow sampling, memory intensive | Image/audio generation, inpainting, super-resolution |
| Flow | Exact likelihood modeling | Tractable likelihood, invertible | Architecture constraints, less flexible | Density estimation, exact inference, anomaly detection |
| EBM | Energy-based modeling | Flexible, composable | Expensive sampling, training complexity | Compositional generation, constraint satisfaction |
| AR | Sequential generation | Explicit likelihood, interpretable | Sequential generation (slow) | Text, ordered sequences, explicit probability |
| Geometric | 3D structure generation | Physical constraints, equivariance | Domain-specific, complex | Protein design, molecular generation, 3D modeling |
Performance & Benchmarks
TODO: Benchmark comparisons are planned and will be added once systematic evaluation experiments are completed. The benchmark framework itself is implemented and ready for evaluation.
Workshop includes a benchmarking framework for evaluating generative models. See Benchmarks for:
- Evaluation protocols and metrics
- Benchmark suite implementations
- Guidelines for running benchmarks
- Framework for comparing models
Planned Benchmark Coverage
- Image Generation: FID, IS, LPIPS, Precision/Recall on CIFAR-10, CelebA, ImageNet
- Text Generation: Perplexity, BLEU, ROUGE on WikiText, BookCorpus
- Audio Generation: MCD, spectral metrics on VCTK, LibriSpeech
- Geometric Models: Structure quality metrics for proteins and molecules
Documentation
Getting Started
- Installation Guide - Complete setup instructions
- Quickstart Tutorial - 5-minute introduction
- Core Concepts - Understanding the architecture
- Your First Model - Build a VAE from scratch
User Guides
- VAE Guide - Complete VAE tutorial
- GAN Guide - GAN training and tips
- Diffusion Guide - Diffusion models deep dive
- Flow Guide - Normalizing flows
- EBM Guide - Energy-Based Models practical guide
- Autoregressive Guide - Autoregressive models practical guide
- Data Pipeline - Loading and preprocessing
- Training Guide - Training workflows
- Evaluation Guide - Metrics and benchmarks
API Reference
- Core API - Base classes and protocols
- Models API - Model implementations
- Training API - Training systems
- Data API - Data loading
- Inference API - Generation and sampling
Advanced Topics
- Distributed Training - Multi-GPU/TPU
- Fine-Tuning - LoRA, RLHF, transfer learning
- Protein Modeling - Specialized extensions
- Deployment Strategies - Serving approaches
Testing & Quality
Workshop maintains high code quality standards:
- Test Coverage: Growing coverage across all modules
- Test Suite: 2150+ tests passing (unit, integration, end-to-end)
- Type Safety: Full type annotations with Pyright
- Code Quality: Ruff for linting and formatting
- CI/CD: Automated testing on CPU and GPU
# Run complete test suite
pytest tests/ -v
# Run with coverage
pytest --cov=src/workshop --cov-report=html
# Type checking
pyright src/
# Code quality
ruff check src/
ruff format src/
See Testing Guide for details.
Contributing
We welcome contributions! Workshop is an open-source project that benefits from community involvement.
How to Contribute
- Fork the repository on GitHub
- Create a feature branch: git checkout -b feature/amazing-feature
- Make your changes with clear, documented code
- Add tests for new functionality
- Run quality checks: pre-commit run --all-files
- Commit changes: git commit -m 'Add amazing feature'
- Push to branch: git push origin feature/amazing-feature
- Open a Pull Request with description
Contribution Areas
- Bug fixes and issue reports
- New features and model implementations
- Documentation improvements
- Tests and benchmarks
- Examples and tutorials
- Performance optimizations
See Contributing Guide for detailed guidelines.
Citation
If you use Workshop in your research, please cite:
@software{workshop_2025,
  title = {Workshop: Generative Modeling Research Library},
  author = {Shafiei, Mahdi and contributors},
  year = {2025},
  url = {https://github.com/mahdi-shafiei/workshop},
  version = {0.1.0}
}
Acknowledgments
Workshop is built on top of excellent open-source projects:
- JAX - High-performance numerical computing
- Flax - Neural network library with NNX API
- Optax - Gradient processing and optimization
- Orbax - Checkpointing and serialization
- BlackJAX - MCMC sampling algorithms
Inspired by research from:
- VAE: Kingma & Welling (2013) - Auto-Encoding Variational Bayes
- GAN: Goodfellow et al. (2014) - Generative Adversarial Networks
- Diffusion: Ho et al. (2020) - Denoising Diffusion Probabilistic Models
- DiT: Peebles & Xie (2023) - Scalable Diffusion Models with Transformers
- Flows: Dinh et al. (2016), Kingma & Dhariwal (2018) - Normalizing Flows
- SE(3) Flows: Köhler et al. (2020) - Equivariant Flows: Exact Likelihood Generative Learning for Symmetric Densities
License
This project is licensed under the MIT License - see the LICENSE file for details.
Documentation | GitHub | Issues | Discussions
Made with ❤️ by the Workshop community