Workshop: Generative Modeling Research Library

License: MIT Python 3.10+ JAX Flax Code style: ruff

Workshop is a modular, research-focused generative modeling library built on JAX and Flax NNX. It provides implementations of state-of-the-art generative models with multi-modal support and a focus on scientific computing. The project aims to reach production-ready status through rigorous testing and continuous development.

🎯 Overview

Workshop is a research-focused library for generative modeling, working towards production deployment capabilities. Built on JAX and Flax NNX, it provides modular, composable implementations of various state-of-the-art generative model architectures including VAEs, GANs, Diffusion Models, Flow Models, Energy-Based Models, and Autoregressive Models.

Why Workshop?

  • State-of-the-Art Models


    Implementations of recent generative models that track current research, including DiT, SE(3) flows, and advanced MCMC sampling

    Model Gallery

  • Research-Focused with Production Goals


    Hardware-aware optimization, distributed training, mixed precision, and deployment pipelines validated through 2150+ comprehensive tests

    Getting Started

  • Multi-Modal Support


    Native support for images, text, audio, proteins, and multi-modal data with specialized evaluation metrics

    Modalities Guide

  • Scalable Architecture


    From single GPU to multi-node distributed training with FSDP, tensor parallelism, and pipeline parallelism

    Scaling Guide

✨ Key Features

Generative Models

Variational Autoencoders (VAE)

  • VAE - Base variational autoencoder with latent variable modeling
  • β-VAE - Disentanglement learning with controlled KL weighting
  • β-VAE with Capacity - Capacity control mechanism for improved training
  • VQ-VAE - Vector quantized VAE with discrete latent representations
  • Conditional VAE - Class-conditional generation
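
The discrete latents mentioned for VQ-VAE come from a nearest-neighbour codebook lookup. The sketch below illustrates that step in plain JAX with a straight-through gradient; the function name `vector_quantize` and the toy codebook are illustrative assumptions, not Workshop's API.

```python
import jax
import jax.numpy as jnp


def vector_quantize(z, codebook):
    """Map each row of z to its nearest codebook vector (squared L2 distance)."""
    dists = jnp.sum((z[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    indices = jnp.argmin(dists, axis=-1)
    quantized = codebook[indices]
    # Straight-through estimator: the forward pass returns `quantized`,
    # while gradients flow to z as if quantization were the identity.
    quantized_st = z + jax.lax.stop_gradient(quantized - z)
    return quantized_st, indices


# Toy codebook with three 2-D entries (hypothetical values).
codebook = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z = jnp.array([[0.1, -0.1], [0.9, 1.2], [-0.8, 1.7]])
quantized, indices = vector_quantize(z, codebook)
print(indices)
```

A full VQ-VAE also adds codebook and commitment loss terms, which are omitted here.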

Generative Adversarial Networks (GAN)

  • GAN - Base generative adversarial network
  • DCGAN - Deep convolutional architecture for stable training
  • WGAN - Wasserstein distance with gradient penalty
  • LSGAN - Least squares loss for stable training
  • Conditional GAN - Class-conditional generation
  • CycleGAN - Unpaired image-to-image translation
  • PatchGAN - Patch-based discrimination

Diffusion Models

  • DDPM - Denoising diffusion probabilistic models
  • DDIM - Fast sampling with deterministic generation
  • Score-based - Score matching and SDE-based generation
  • Latent Diffusion - Efficient diffusion in compressed latent space
  • DiT - Diffusion Transformer for scalable generation
  • Classifier-Free Guidance - Strong conditioning without separate classifier
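
As a rough illustration of classifier-free guidance, the sketch below blends a conditional and an unconditional noise prediction. It assumes a noise-prediction model that has already been evaluated twice; `cfg_combine` and `guidance_scale` are illustrative names, not part of Workshop's API.

```python
import jax.numpy as jnp


def cfg_combine(eps_uncond, eps_cond, guidance_scale):
    """Blend unconditional and conditional noise predictions.

    A scale of 0 returns the unconditional prediction, 1 returns the
    conditional one, and values above 1 extrapolate past it.
    """
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


# Stand-in predictions; a real model would produce these from noisy inputs.
eps_uncond = jnp.zeros((2, 4))
eps_cond = jnp.ones((2, 4))
guided = cfg_combine(eps_uncond, eps_cond, guidance_scale=3.0)
print(guided[0])
```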

Normalizing Flows

  • RealNVP - Real-valued Non-Volume Preserving flow with coupling layers
  • Glow - Multi-scale flow with ActNorm and invertible convolutions
  • MAF - Masked Autoregressive Flow for fast density estimation
  • IAF - Inverse Autoregressive Flow for fast sampling
  • Neural Spline - Rational quadratic spline transformations
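
To make the coupling-layer idea behind RealNVP concrete, here is a minimal affine coupling layer in plain JAX, with its inverse and log-determinant. The single-layer `tanh` conditioner and the parameter names `w` and `b` are simplifying assumptions for illustration; this is not Workshop's implementation.

```python
import jax
import jax.numpy as jnp


def coupling_forward(x, w, b):
    """Affine coupling: transform the second half of x conditioned on the first."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    h = jnp.tanh(x1 @ w + b)  # toy conditioner; real flows use a deeper network
    log_scale, shift = jnp.split(h, 2, axis=-1)
    y2 = x2 * jnp.exp(log_scale) + shift
    log_det = jnp.sum(log_scale, axis=-1)  # log|det J| of the triangular Jacobian
    return jnp.concatenate([x1, y2], axis=-1), log_det


def coupling_inverse(y, w, b):
    """Exact inverse: x1 passes through unchanged, so the conditioner can be recomputed."""
    y1, y2 = jnp.split(y, 2, axis=-1)
    h = jnp.tanh(y1 @ w + b)
    log_scale, shift = jnp.split(h, 2, axis=-1)
    x2 = (y2 - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([y1, x2], axis=-1)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 4))
w = 0.1 * jax.random.normal(key_w, (2, 4))  # 2 inputs -> 2 log-scales + 2 shifts
b = jnp.zeros((4,))

y, log_det = coupling_forward(x, w, b)
x_rec = coupling_inverse(y, w, b)
print(jnp.max(jnp.abs(x - x_rec)))
```

Because the Jacobian is triangular, the log-determinant is just the sum of the predicted log-scales, which is what keeps the likelihood tractable.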

Energy-Based Models (EBM)

Autoregressive Models

  • PixelCNN - Sequential image generation with masked convolutions
  • WaveNet - Audio generation with dilated causal convolutions
  • Transformer - Transformer-based autoregressive generation

Geometric Models

Multi-Modal Support

Image Generation

  • Datasets: CIFAR-10/100, ImageNet, FFHQ, CelebA
  • Metrics: FID, Inception Score, LPIPS, Precision/Recall
  • Preprocessing: Normalization, augmentation, multi-crop

Text Generation

  • Datasets: BookCorpus, Wikipedia, custom corpora
  • Metrics: Perplexity, BLEU, ROUGE, diversity scores
  • Tokenization: BPE, SentencePiece, character-level

Audio Generation

  • Datasets: LibriSpeech, VCTK, music datasets
  • Metrics: MCD, F0 RMSE, spectral analysis
  • Processing: STFT, mel-spectrograms, raw waveform

Protein Modeling

  • Structure generation with physical constraints
  • Bond length and angle constraints
  • Amino acid features and properties
  • SE(3) equivariant architectures

Multi-Modal

  • Image-text paired data (COCO)
  • Cross-modal generation and retrieval
  • Unified representations
  • Contrastive learning

Training & Optimization

JAX Flax NNX Optax Orbax
  • Model-Specific Trainers: Optimized training loops for VAE, GAN, Diffusion, Flow, EBM
  • Distributed Training: Data parallel, model parallel, FSDP for large-scale training
  • Advanced Optimizers: Adam, AdamW, Lion, Adafactor with learning rate schedules
  • Mixed Precision: BF16/FP16 training, which can give up to about 2x speedup on hardware with native low-precision support
  • Gradient Accumulation: Train large models with limited memory
  • Checkpointing: Save/resume training with Orbax integration
  • Callbacks: Logging, early stopping, profiling, visualization
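
Gradient accumulation can be sketched in plain JAX as follows. The example assumes a mean-reduced loss and equally sized micro-batches, in which case the averaged micro-batch gradients match the full-batch gradient. The helper `accumulated_grads` and the toy linear model are hypothetical and do not reflect Workshop's trainer API.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a toy linear model."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, num_micro_batches):
    """Average gradients over equally sized micro-batches."""
    xs = jnp.split(x, num_micro_batches)
    ys = jnp.split(y, num_micro_batches)
    total = jax.tree_util.tree_map(jnp.zeros_like, params)
    for xb, yb in zip(xs, ys):
        grads = jax.grad(loss_fn)(params, xb, yb)
        total = jax.tree_util.tree_map(jnp.add, total, grads)
    return jax.tree_util.tree_map(lambda g: g / num_micro_batches, total)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (32, 3))
y = jax.random.normal(key_y, (32,))
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

full = jax.grad(loss_fn)(params, x, y)
accum = accumulated_grads(params, x, y, num_micro_batches=4)
print(jnp.max(jnp.abs(full["w"] - accum["w"])))
```

Only one micro-batch of activations is held in memory at a time, which is what makes the technique useful when the full batch does not fit.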

Evaluation & Benchmarking

Image Metrics:

  • FID (Fréchet Inception Distance)
  • Inception Score
  • LPIPS (Perceptual distance)
  • Precision and Recall
  • SSIM (Structural similarity)
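
For reference, the Fréchet distance underlying FID compares two Gaussians fitted to feature sets. The sketch below computes it in plain JAX on random stand-in features; a real FID run would use Inception-v3 activations instead. The function names are illustrative and not Workshop's metric API.

```python
import jax
import jax.numpy as jnp


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = jnp.linalg.eigh(mat)
    vals = jnp.clip(vals, 0.0, None)
    return (vecs * jnp.sqrt(vals)) @ vecs.T


def frechet_distance(feats_a, feats_b):
    """||mu_a - mu_b||^2 + Tr(C_a + C_b - 2 (C_a C_b)^(1/2))."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = jnp.cov(feats_a, rowvar=False)
    cov_b = jnp.cov(feats_b, rowvar=False)
    root_a = _sqrtm_psd(cov_a)
    # Tr((C_a C_b)^(1/2)) equals the sum of square roots of the eigenvalues
    # of the symmetric matrix root_a @ C_b @ root_a.
    eigs = jnp.linalg.eigvalsh(root_a @ cov_b @ root_a)
    tr_sqrt = jnp.sum(jnp.sqrt(jnp.clip(eigs, 0.0, None)))
    mean_term = jnp.sum((mu_a - mu_b) ** 2)
    return mean_term + jnp.trace(cov_a) + jnp.trace(cov_b) - 2.0 * tr_sqrt


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
real_feats = jax.random.normal(key_a, (512, 4))        # stand-in for real features
fake_feats = jax.random.normal(key_b, (512, 4)) + 0.5  # stand-in for generated features
print(frechet_distance(real_feats, fake_feats))
```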

Text Metrics:

  • Perplexity
  • BLEU (Translation quality)
  • ROUGE (Summarization)
  • Diversity metrics
  • Semantic coherence

Audio Metrics:

  • MCD (Mel-Cepstral Distortion)
  • F0 RMSE (Pitch error)
  • Spectral convergence
  • STFT distance
  • Perceptual quality

Disentanglement:

  • MIG (Mutual Information Gap)
  • SAP (Separated Attribute Predictability)
  • DCI (Disentanglement, Completeness, Informativeness)
  • FactorVAE score

  • Benchmark Suites: Standardized evaluation protocols for reproducible research
  • Automatic Metric Selection: Choose appropriate metrics based on modality
  • Visualization: Sample grids, latent space plots, protein structure rendering

Inference & Sampling

Ancestral Diffusion MCMC ODE/SDE
  • Sampling Methods: Ancestral, temperature scaling, top-k, nucleus sampling
  • Diffusion Sampling: DDPM, DDIM, ODE solvers, SDE solvers
  • MCMC Sampling: HMC, NUTS, MALA, Langevin dynamics with BlackJAX
  • Guidance: Classifier-free guidance, classifier guidance
  • Fast Inference: Caching, JIT compilation, quantization
  • Production Serving: REST API, gRPC, batching, streaming
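
Temperature scaling and top-k filtering from the list above can be sketched for a single categorical distribution as follows. This is a minimal illustration in plain JAX; `top_k_logits` and `sample_token` are hypothetical helper names, not Workshop's sampling API.

```python
import jax
import jax.numpy as jnp


def top_k_logits(logits, k):
    """Keep the k largest logits and set the rest to -inf."""
    kth_value = jax.lax.top_k(logits, k)[0][..., -1:]
    return jnp.where(logits < kth_value, -jnp.inf, logits)


def sample_token(key, logits, temperature=1.0, k=None):
    """Sample an index after temperature scaling and optional top-k filtering."""
    logits = logits / temperature
    if k is not None:
        logits = top_k_logits(logits, k)
    return jax.random.categorical(key, logits, axis=-1)


logits = jnp.array([2.0, 1.0, 0.5, -1.0, -3.0])
keys = jax.random.split(jax.random.PRNGKey(0), 200)
tokens = jax.vmap(lambda kk: sample_token(kk, logits, temperature=0.8, k=2))(keys)
print(jnp.unique(tokens))
```

Lower temperatures sharpen the distribution before filtering, while `k` bounds how many candidates can ever be drawn.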

📋 Requirements

  • Python 3.10+
  • JAX 0.4.35+ with GPU/TPU support
  • Flax NNX (latest version)
  • CUDA 12.0+ (for GPU acceleration)
  • 8GB+ RAM (16GB+ recommended)
  • NVIDIA GPU with compute capability 7.0+ (optional but recommended)

🚀 Quick Start

Installation

# CPU-only version
pip install workshop-generative

# With GPU support (CUDA 12.0+); the quotes stop shells such as zsh from expanding the brackets
pip install 'workshop-generative[cuda]'

# Clone repository
git clone https://github.com/mahdi-shafiei/workshop.git
cd workshop

# Install with uv (recommended)
uv sync --all-extras

# Or with pip
pip install -e '.[dev]'

# Pull latest image
docker pull ghcr.io/mahdi-shafiei/workshop:latest

# Run with GPU support
docker run --gpus all -it ghcr.io/mahdi-shafiei/workshop:latest

See the Installation Guide for detailed setup instructions.

Your First Model: Train a VAE

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.factories import create_vae
from workshop.generative_models.core.configuration import ModelConfiguration

# Create configuration
config = ModelConfiguration(
    model_type="vae",
    latent_dim=128,
    input_shape=(32, 32, 3),
    encoder_features=[64, 128, 256],
    decoder_features=[256, 128, 64],
    parameters={
        "beta": 1.0,  # β-VAE parameter
        "kl_weight": 1.0,
        "reconstruction_loss": "mse"
    }
)

# Initialize model
rngs = nnx.Rngs(0)
model = create_vae(config, rngs=rngs)

# Prepare data (example with random data)
batch_size = 16
x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 32, 32, 3))

# Forward pass: encode, reparameterize, decode
outputs = model(x, rngs=rngs)
reconstructed = outputs["reconstructed"]
mean, logvar = outputs["mean"], outputs["logvar"]
z = outputs["z"]

# Compute loss
loss_dict = model.loss_fn(x=x, outputs=outputs)
print(f"Total loss: {loss_dict['total_loss']:.4f}")
print(f"Reconstruction loss: {loss_dict['reconstruction_loss']:.4f}")
print(f"KL loss: {loss_dict['kl_loss']:.4f}")

# Generate new samples from prior
samples = model.sample(num_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}")
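
For readers who want to see what the reparameterization step and the loss terms compute, here is a standalone sketch in plain JAX. It assumes a diagonal Gaussian posterior, a sum-over-pixels MSE reconstruction term, and a `beta`-weighted KL term; Workshop's own `loss_fn` may use different reductions, and the function names below are illustrative.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    """Sample z = mean + sigma * eps with eps ~ N(0, I)."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


def kl_to_standard_normal(mean, logvar):
    """KL(N(mean, diag(exp(logvar))) || N(0, I)), summed over latent dims."""
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


def beta_vae_loss(x, reconstructed, mean, logvar, beta=1.0):
    """Reconstruction term plus beta-weighted KL, averaged over the batch."""
    per_example = jnp.sum((x - reconstructed) ** 2, axis=tuple(range(1, x.ndim)))
    recon = jnp.mean(per_example)
    kl = jnp.mean(kl_to_standard_normal(mean, logvar))
    return {"total_loss": recon + beta * kl, "reconstruction_loss": recon, "kl_loss": kl}


key = jax.random.PRNGKey(0)
mean = jnp.zeros((16, 128))
logvar = jnp.zeros((16, 128))
z = reparameterize(key, mean, logvar)
x = jnp.ones((16, 32, 32, 3))
losses = beta_vae_loss(x, jnp.zeros_like(x), mean, logvar, beta=1.0)
print(z.shape, losses["kl_loss"])
```

Setting `beta` above 1 penalizes the KL term more strongly, which is the mechanism β-VAE uses to encourage disentangled latents.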

Train a Diffusion Model

# Continues from the VAE example above: reuses jax, jnp, ModelConfiguration, rngs, x, and batch_size
from workshop.generative_models.factories import create_diffusion

# Create DDPM configuration
config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(32, 32, 3),
    num_timesteps=1000,
    backbone_type="unet",
    backbone_features=[128, 256, 512],
    parameters={
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "beta_schedule": "linear"
    }
)

# Create model
model = create_diffusion(config, rngs=rngs)

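# Training: add noise and predict it
timesteps = jax.random.randint(jax.random.PRNGKey(1), (batch_size,), 0, 1000)
noise = jax.random.normal(jax.random.PRNGKey(2), x.shape)
# Caution: the MSE below is only a valid training loss if `noise` is the same
# noise that forward_diffusion added to x. As written, `noise` is never passed in,
# so check the method's signature for how to supply or retrieve it.
noisy_x = model.forward_diffusion(x, timesteps, rngs=rngs)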

# Predict noise
predicted_noise_dict = model(noisy_x, timesteps, rngs=rngs)
predicted_noise = predicted_noise_dict["predicted_noise"]

# Compute loss
loss = jnp.mean((noise - predicted_noise) ** 2)
print(f"Diffusion loss: {loss:.4f}")

# Sampling: generate new images
samples = model.sample(num_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}")
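
The forward (noising) step can also be written out by hand, which makes explicit that the regression target must be the same noise that was mixed into the input. The sketch below uses the linear schedule from the configuration above (`beta_start=1e-4`, `beta_end=2e-2`, 1000 steps) and the standard DDPM closed form for q(x_t | x_0). The helper `q_sample` is a hypothetical name and is not claimed to match Workshop's `forward_diffusion`.

```python
import jax
import jax.numpy as jnp

num_timesteps = 1000
betas = jnp.linspace(1e-4, 2e-2, num_timesteps)  # linear schedule
alphas_cumprod = jnp.cumprod(1.0 - betas)        # cumulative product of (1 - beta_t)


def q_sample(x0, t, noise):
    """Closed-form sample from q(x_t | x_0) for a batch of integer timesteps t."""
    a_bar = alphas_cumprod[t].reshape((-1,) + (1,) * (x0.ndim - 1))
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise


key_x, key_t, key_n = jax.random.split(jax.random.PRNGKey(0), 3)
x0 = jax.random.normal(key_x, (16, 32, 32, 3))
t = jax.random.randint(key_t, (16,), 0, num_timesteps)
noise = jax.random.normal(key_n, x0.shape)
x_t = q_sample(x0, t, noise)

# A noise-prediction network would be trained to map (x_t, t) back to `noise`;
# a zero prediction is used here only as a placeholder.
placeholder_prediction = jnp.zeros_like(noise)
loss = jnp.mean((noise - placeholder_prediction) ** 2)
print(x_t.shape, loss)
```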

Next Steps

  • Learn the Basics


    Understand generative modeling concepts and Workshop architecture

    Core Concepts

  • Complete Tutorial


    Build your first generative model from scratch with step-by-step guidance

    First Model Tutorial

  • Model Guides


    Deep dive into VAEs, GANs, Diffusion, Flows, and more

    Model Guides

  • Advanced Topics


    Distributed training, fine-tuning, protein modeling, and deployment strategies

    Advanced Guides

🏗️ Architecture

System Overview

graph TB
    subgraph Input["Data Pipeline"]
        A1[Image Data]
        A2[Text Data]
        A3[Audio Data]
        A4[Protein Data]
    end

    subgraph Config["Configuration System"]
        B1[ModelConfiguration]
        B2[TrainingConfig]
        B3[DataConfig]
    end

    subgraph Factory["Model Factory"]
        C1[create_vae]
        C2[create_gan]
        C3[create_diffusion]
        C4[create_flow]
        C5[create_ebm]
    end

    subgraph Models["Generative Models"]
        D1[VAE Models]
        D2[GAN Models]
        D3[Diffusion Models]
        D4[Flow Models]
        D5[EBM Models]
    end

    subgraph Training["Training System"]
        E1[VAE Trainer]
        E2[GAN Trainer]
        E3[Diffusion Trainer]
        E4[Flow Trainer]
        E5[EBM Trainer]
    end

    subgraph Eval["Evaluation"]
        F1[Metrics]
        F2[Benchmarks]
        F3[Visualization]
    end

    subgraph Deploy["Inference & Serving"]
        G1[Sampling]
        G2[Generation]
        G3[API Serving]
    end

    A1 --> B1
    A2 --> B1
    A3 --> B1
    A4 --> B1

    B1 --> C1
    B1 --> C2
    B1 --> C3
    B1 --> C4
    B1 --> C5

    C1 --> D1
    C2 --> D2
    C3 --> D3
    C4 --> D4
    C5 --> D5

    D1 --> E1
    D2 --> E2
    D3 --> E3
    D4 --> E4
    D5 --> E5

    E1 --> F1
    E2 --> F1
    E3 --> F1
    E4 --> F1
    E5 --> F1

    F1 --> F2
    F2 --> F3

    D1 --> G1
    D2 --> G1
    D3 --> G1
    D4 --> G1
    D5 --> G1

    G1 --> G2
    G2 --> G3

Component Structure

Workshop Architecture
├── Core Components
│   ├── Protocols & Interfaces    → Type-safe abstractions (GenerativeModel, TrainerProtocol)
│   ├── Configuration System      → Pydantic-based unified configs with validation
│   ├── Device Manager            → Hardware-aware GPU/CPU/TPU handling
│   ├── Loss Functions            → Modular composable losses (reconstruction, adversarial, divergence)
│   ├── Sampling Methods          → Ancestral, diffusion, MCMC, ODE/SDE solvers
│   └── Evaluation Framework      → Metrics, benchmarks, protocols
├── Generative Models
│   ├── VAE Family               → VAE, β-VAE, VQ-VAE, Conditional VAE
│   ├── GAN Family               → DCGAN, WGAN, StyleGAN, CycleGAN, PatchGAN
│   ├── Diffusion Models         → DDPM, DDIM, Score-based, DiT, Latent Diffusion
│   ├── Normalizing Flows        → RealNVP, Glow, MAF, IAF, Neural Spline
│   ├── Energy-Based Models      → Contrastive divergence, Langevin, MCMC
│   ├── Autoregressive Models    → PixelCNN, WaveNet, Transformer
│   └── Geometric Models         → Protein, 3D shapes, SE(3) flows
├── Training System
│   ├── Model-Specific Trainers  → Optimized for each model type
│   ├── Distributed Training     → Data/model parallel, FSDP, sharding
│   ├── Optimization             → Optimizers, schedules, gradient accumulation
│   ├── Callbacks                → Logging, checkpointing, early stopping, profiling
│   └── Fine-Tuning              → LoRA, prefix tuning, RLHF
├── Data Pipeline
│   ├── Multi-Modal Loaders      → Image, text, audio, protein, tabular, timeseries
│   ├── Preprocessing            → Augmentation, normalization, tokenization
│   ├── Streaming Support        → Large-scale dataset handling
│   └── Custom Datasets          → Extensible dataset framework
├── Inference & Serving
│   ├── Sampling Strategies      → Multiple sampling methods per model type
│   ├── Generation Pipeline      → End-to-end inference with batching
│   ├── Optimization             → Quantization, pruning, compilation
│   └── Serving Infrastructure   → REST/gRPC APIs, monitoring
└── Evaluation & Benchmarking
    ├── Image Metrics            → FID, IS, LPIPS, Precision/Recall
    ├── Text Metrics             → Perplexity, BLEU, ROUGE, diversity
    ├── Audio Metrics            → MCD, spectral analysis, perceptual quality
    ├── Disentanglement          → MIG, SAP, DCI, FactorVAE
    └── Benchmark Suites         → Standardized protocols across modalities

🎓 Model Types Overview

| Model Type | Use Case | Pros | Cons | Best For |
|---|---|---|---|---|
| VAE | Latent representation learning | Stable training, interpretable latents, fast sampling | Lower sample quality | Data compression, latent space exploration, representation learning |
| GAN | High-quality image generation | Sharp, high-quality samples; fast single-pass sampling | Training instability, mode collapse | Image synthesis, style transfer, super-resolution |
| Diffusion | State-of-the-art generation | Excellent quality, stable training, controllable | Slow sampling, memory intensive | Image/audio generation, inpainting, super-resolution |
| Flow | Exact likelihood modeling | Tractable likelihood, invertible | Architecture constraints, less flexible | Density estimation, exact inference, anomaly detection |
| EBM | Energy-based modeling | Flexible, composable | Expensive sampling, training complexity | Compositional generation, constraint satisfaction |
| AR | Sequential generation | Explicit likelihood, interpretable | Slow token-by-token sampling | Text, ordered sequences, explicit probability |
| Geometric | 3D structure generation | Physical constraints, equivariance | Domain-specific, complex | Protein design, molecular generation, 3D modeling |

📊 Performance & Benchmarks

TODO: Comprehensive benchmark comparisons are planned and will be added once systematic evaluation experiments are completed. The benchmark framework is implemented and ready for evaluation.

Workshop includes a comprehensive benchmarking framework for evaluating generative models. See Benchmarks for:

  • Evaluation protocols and metrics
  • Benchmark suite implementations
  • Guidelines for running benchmarks
  • Framework for comparing models

Planned Benchmark Coverage

  • Image Generation: FID, IS, LPIPS, Precision/Recall on CIFAR-10, CelebA, ImageNet
  • Text Generation: Perplexity, BLEU, ROUGE on WikiText, BookCorpus
  • Audio Generation: MCD, spectral metrics on VCTK, LibriSpeech
  • Geometric Models: Structure quality metrics for proteins and molecules

📚 Documentation

Getting Started

User Guides

API Reference

Advanced Topics

🧪 Testing & Quality

Workshop maintains high code quality standards:

  • Test Coverage: Growing coverage across all modules
  • Test Suite: 2150+ tests passing (unit, integration, end-to-end)
  • Type Safety: Full type annotations with Pyright
  • Code Quality: Ruff for linting and formatting
  • CI/CD: Automated testing on CPU and GPU

# Run complete test suite
pytest tests/ -v

# Run with coverage
pytest --cov=src/workshop --cov-report=html

# Type checking
pyright src/

# Code quality
ruff check src/
ruff format src/

See Testing Guide for details.

🤝 Contributing

We welcome contributions! Workshop is an open-source project that benefits from community involvement.

How to Contribute

  1. Fork the repository on GitHub
  2. Create a feature branch: git checkout -b feature/amazing-feature
  3. Make your changes with clear, documented code
  4. Add tests for new functionality
  5. Run quality checks: pre-commit run --all-files
  6. Commit changes: git commit -m 'Add amazing feature'
  7. Push to branch: git push origin feature/amazing-feature
  8. Open a Pull Request with a description

Contribution Areas

  • 🐛 Bug fixes and issue reports
  • ✨ New features and model implementations
  • 📖 Documentation improvements
  • 🧪 Tests and benchmarks
  • 🎨 Examples and tutorials
  • 🔧 Performance optimizations

See Contributing Guide for detailed guidelines.

📖 Citation

If you use Workshop in your research, please cite:

@software{workshop_2025,
  title = {Workshop: Generative Modeling Research Library},
  author = {Shafiei, Mahdi and contributors},
  year = {2025},
  url = {https://github.com/mahdi-shafiei/workshop},
  version = {0.1.0}
}

🙏 Acknowledgments

Workshop is built on top of excellent open-source projects:

  • JAX - High-performance numerical computing
  • Flax - Neural network library with NNX API
  • Optax - Gradient processing and optimization
  • Orbax - Checkpointing and serialization
  • BlackJAX - MCMC sampling algorithms

Inspired by research from:

  • VAE: Kingma & Welling (2013) - Auto-Encoding Variational Bayes
  • GAN: Goodfellow et al. (2014) - Generative Adversarial Networks
  • Diffusion: Ho et al. (2020) - Denoising Diffusion Probabilistic Models
  • DiT: Peebles & Xie (2023) - Scalable Diffusion Models with Transformers
  • Flows: Dinh et al. (2016), Kingma & Dhariwal (2018) - Normalizing Flows
  • SE(3) Flows: Köhler et al. (2020) - Equivariant Flows for Molecular Graphs

📜 License

This project is licensed under the MIT License - see the LICENSE file for details.


Documentation | GitHub | Issues | Discussions

Made with ❤️ by the Workshop community