Model Gallery¤

Workshop provides implementations of state-of-the-art generative models with 2025 research compliance, including Diffusion Transformers (DiT), SE(3) equivariant flows for molecular generation, and advanced MCMC sampling with BlackJAX integration.

  • 7 Model Families: VAE, GAN, Diffusion, Flow, EBM, Autoregressive, and Geometric models with 67+ implementations
  • 2025 Research Compliance: Latest architectures including DiT, StyleGAN3, SE(3) molecular flows, and score-based diffusion
  • Production Ready: Hardware-optimized, fully tested, type-safe implementations built on JAX/Flax NNX
  • Multi-Modal: Native support for images, text, audio, proteins, molecules, and 3D geometric data

Overview¤

All models in Workshop follow a unified interface and are built on JAX/Flax NNX for:

  • Hardware acceleration: Automatic GPU/TPU support with XLA optimization
  • Type safety: Full type annotations and protocol-based interfaces
  • Composability: Mix and match components across model types
  • Reproducibility: Deterministic RNG handling and comprehensive testing (see the sketch after this list)
  • Scalability: Distributed training with data, model, and pipeline parallelism
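
As a concrete note on the reproducibility point above, JAX threads randomness through explicit PRNG keys, so seeding a run fully determines its random state. A minimal sketch using plain JAX keys, not a Workshop-specific API:

import jax

# The same seed always yields the same key and therefore the same random draws,
# which is what makes parameter initialization and sampling reproducible.
key = jax.random.PRNGKey(0)
init_key, sample_key = jax.random.split(key)

noise = jax.random.normal(sample_key, (4, 128))  # identical on every run seeded with 0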

Model Families¤

Variational Autoencoders (VAE)¤

Latent variable models with probabilistic encoding for representation learning and generation.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| VAE | Standard Variational Autoencoder | Gaussian latents, KL regularization | Representation learning, compression |
| β-VAE | Disentangled VAE | Controllable β parameter, beta annealing | Disentangled representations |
| β-VAE with Capacity | β-VAE with capacity control | Gradual capacity increase, controlled disentanglement | Balance reconstruction and disentanglement |
| Conditional VAE | Class-conditional VAE | Label conditioning, controlled generation | Supervised generation |
| VQ-VAE | Vector Quantized VAE | Discrete latent codes, codebook learning | Discrete representations, compression |

Quick Start:

from workshop.generative_models.factories import create_vae
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="vae",
    latent_dim=128,
    input_shape=(32, 32, 3),
    encoder_features=[64, 128, 256],
    decoder_features=[256, 128, 64],
    parameters={"beta": 1.0, "kl_weight": 1.0}
)

model = create_vae(config, rngs=nnx.Rngs(0))
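
At the core of the standard VAE objective are the reparameterized Gaussian latent and a closed-form KL term. The sketch below illustrates both in plain JAX; the mu/log_var names are placeholders for the encoder outputs, not Workshop's internal API:

import jax
import jax.numpy as jnp

def reparameterize(key, mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and log_var.
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.
    return -0.5 * jnp.sum(1.0 + log_var - mu**2 - jnp.exp(log_var), axis=-1)

key = jax.random.PRNGKey(0)
mu = jnp.zeros((8, 128))
log_var = jnp.zeros((8, 128))
z = reparameterize(key, mu, log_var)     # (8, 128) latent samples
kl = kl_to_standard_normal(mu, log_var)  # (8,) per-example KL; scaled by beta in β-VAE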


Generative Adversarial Networks (GAN)¤

Adversarial training for high-quality image generation and image-to-image translation.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| DCGAN | Deep Convolutional GAN | Convolutional architecture, stable training | Image generation baseline |
| WGAN | Wasserstein GAN | Wasserstein distance, critic training | Stable training, mode coverage |
| LSGAN | Least Squares GAN | Least squares loss, improved stability | Image generation with stable training |
| StyleGAN | Style-based GAN | Style mixing, AdaIN layers | High-quality face generation |
| StyleGAN3 | Alias-free StyleGAN | Translation/rotation equivariance | Alias-free high-quality generation |
| CycleGAN | Cycle-consistent GAN | Unpaired translation, cycle loss | Image-to-image translation |
| PatchGAN | Patch-based discriminator | Local image patches, texture detail | Image-to-image tasks, super-resolution |
| Conditional GAN | Class-conditional GAN | Label conditioning | Controlled generation |

Quick Start:

from workshop.generative_models.factories import create_gan
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="dcgan",
    input_shape=(64, 64, 3),
    latent_dim=100,
    generator_features=[512, 256, 128, 64],
    discriminator_features=[64, 128, 256, 512],
    parameters={"learning_rate_g": 2e-4, "learning_rate_d": 2e-4}
)

model = create_gan(config, rngs=nnx.Rngs(0))
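
For orientation, DCGAN-style training typically minimizes the non-saturating adversarial losses sketched below in plain JAX on discriminator logits. This is the textbook formulation; variants such as WGAN and LSGAN replace it with the losses noted in the table:

import jax.numpy as jnp
from jax.nn import log_sigmoid

def discriminator_loss(real_logits, fake_logits):
    # Push D(real) toward 1 and D(fake) toward 0.
    return -jnp.mean(log_sigmoid(real_logits) + log_sigmoid(-fake_logits))

def generator_loss(fake_logits):
    # Non-saturating form: maximize log D(G(z)) rather than minimizing log(1 - D(G(z))).
    return -jnp.mean(log_sigmoid(fake_logits))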


Diffusion Models¤

State-of-the-art denoising diffusion models for high-quality generation.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| DDPM | Denoising Diffusion Probabilistic Models | Gaussian diffusion, noise prediction | Image generation, baseline |
| DDIM | Denoising Diffusion Implicit Models | Deterministic sampling, faster inference | Fast high-quality generation |
| Score-based | Score-based generative models | Score matching, SDE/ODE solvers | Flexible sampling strategies |
| Latent Diffusion | Latent space diffusion | VAE encoder/decoder, efficient training | High-resolution generation |
| DiT | Diffusion Transformer | Transformer backbone, scalable | Large-scale image generation |
| Stable Diffusion | Text-to-image diffusion | CLIP conditioning, latent diffusion | Text-to-image generation |

Quick Start:

from workshop.generative_models.factories import create_diffusion
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(32, 32, 3),
    num_timesteps=1000,
    backbone_type="unet",
    backbone_features=[128, 256, 512],
    parameters={
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "beta_schedule": "linear"
    }
)

model = create_diffusion(config, rngs=nnx.Rngs(0))
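
The beta_start, beta_end, and beta_schedule parameters define the forward noising process. The sketch below shows the standard DDPM linear schedule and the closed-form forward sample x_t; it is written independently of Workshop's internals:

import jax
import jax.numpy as jnp

num_timesteps = 1000
betas = jnp.linspace(1e-4, 2e-2, num_timesteps)  # the "linear" beta schedule
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)                 # cumulative product alpha_bar_t

def q_sample(key, x0, t):
    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, with eps ~ N(0, I).
    eps = jax.random.normal(key, x0.shape)
    a_bar = alpha_bars[t]
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * eps, eps

key = jax.random.PRNGKey(0)
x0 = jnp.zeros((1, 32, 32, 3))
x_t, eps = q_sample(key, x0, t=500)  # the network is trained to predict eps from (x_t, t)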


Normalizing Flows¤

Invertible transformations with tractable likelihoods for exact density estimation.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| RealNVP | Real-valued Non-Volume Preserving | Affine coupling layers, multi-scale | Density estimation baseline |
| Glow | Generative Flow | Invertible 1x1 convolutions, ActNorm | High-quality image generation |
| MAF | Masked Autoregressive Flow | Autoregressive coupling, parallel training | Flexible density estimation |
| IAF | Inverse Autoregressive Flow | Fast sampling, parallel generation | Variational inference |
| Neural Spline Flow | Spline-based coupling | Smooth transformations, expressive | High-quality density estimation |
| SE(3) Molecular Flow | Equivariant molecular flows | SE(3) symmetry, molecular generation | Drug design, molecular modeling |
| Conditional Flow | Class-conditional flows | Label conditioning | Controlled generation |

Quick Start:

from workshop.generative_models.factories import create_flow
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="realnvp",
    input_shape=(32, 32, 3),
    num_coupling_layers=8,
    coupling_features=[512, 512],
    parameters={"num_scales": 3, "scale_factor": 2}
)

model = create_flow(config, rngs=nnx.Rngs(0))
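
RealNVP's tractable likelihood comes from affine coupling: one half of the dimensions passes through unchanged and parameterizes an elementwise affine map of the other half, so the Jacobian log-determinant is just a sum of log-scales. A minimal sketch; the tiny scale_and_shift function stands in for a neural network and is not Workshop's coupling module:

import jax.numpy as jnp

def scale_and_shift(x1):
    # Stand-in for a small network conditioned on the untouched half.
    return jnp.tanh(x1), 0.1 * x1            # (log-scale s, shift t)

def coupling_forward(x):
    x1, x2 = jnp.split(x, 2, axis=-1)
    s, t = scale_and_shift(x1)
    y2 = x2 * jnp.exp(s) + t                 # invertible given x1
    log_det = jnp.sum(s, axis=-1)            # log|det J| is the sum of log-scales
    return jnp.concatenate([x1, y2], axis=-1), log_det

def coupling_inverse(y):
    y1, y2 = jnp.split(y, 2, axis=-1)
    s, t = scale_and_shift(y1)
    x2 = (y2 - t) * jnp.exp(-s)
    return jnp.concatenate([y1, x2], axis=-1)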


Energy-Based Models (EBM)¤

Energy function learning with MCMC sampling for compositional generation.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| EBM | Energy-based model | Energy function learning, flexible | Compositional generation |
| EBM with MCMC | EBM with MCMC sampling | Langevin dynamics, HMC, NUTS | High-quality sampling |
| Persistent CD | Persistent Contrastive Divergence | Persistent chains, efficient training | Stable EBM training |

MCMC Samplers (via BlackJAX integration):

  • HMC: Hamiltonian Monte Carlo
  • NUTS: No-U-Turn Sampler (adaptive HMC)
  • MALA: Metropolis-Adjusted Langevin Algorithm
  • Langevin Dynamics: First-order gradient-based sampling

Quick Start:

from workshop.generative_models.factories import create_ebm
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="ebm",
    input_shape=(28, 28, 1),
    energy_features=[512, 256, 128],
    parameters={
        "mcmc_steps": 100,
        "step_size": 0.01,
        "sampler": "langevin"
    }
)

model = create_ebm(config, rngs=nnx.Rngs(0))
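
The "sampler": "langevin" option refers to unadjusted Langevin dynamics, which is conceptually gradient descent on the energy plus Gaussian noise. A self-contained sketch with a toy quadratic energy; the BlackJAX-backed samplers listed above replace this hand-rolled loop in practice:

import jax
import jax.numpy as jnp

def energy(x):
    # Toy quadratic energy; a trained EBM would use its learned energy network here.
    return 0.5 * jnp.sum(x**2)

def langevin_step(key, x, step_size=0.01):
    # x' = x - step_size * grad E(x) + sqrt(2 * step_size) * noise
    noise = jax.random.normal(key, x.shape)
    return x - step_size * jax.grad(energy)(x) + jnp.sqrt(2.0 * step_size) * noise

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (28 * 28,))
for _ in range(100):  # mcmc_steps
    key, subkey = jax.random.split(key)
    x = langevin_step(subkey, x)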


Autoregressive Models (AR)¤

Sequential generation with explicit likelihood for ordered data.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| PixelCNN | Autoregressive image model | Masked convolutions, pixel-by-pixel | Image generation with likelihood |
| WaveNet | Autoregressive audio model | Dilated convolutions, long context | Audio generation, TTS |
| Transformer | Transformer-based AR | Self-attention, parallel training | Text, structured sequences |

Quick Start:

from workshop.generative_models.factories import create_autoregressive
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="pixelcnn",
    input_shape=(32, 32, 3),
    num_layers=8,
    hidden_channels=128,
    parameters={"kernel_size": 3, "num_classes": 256}
)

model = create_autoregressive(config, rngs=nnx.Rngs(0))
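
PixelCNN enforces the autoregressive ordering by masking its convolution kernels so each output pixel depends only on pixels above and to the left. A minimal sketch of building such a mask (type "A", used in the first layer, also hides the center pixel); the mask is multiplied into the kernel weights before each convolution:

import jax.numpy as jnp

def causal_mask(kernel_size, mask_type="B"):
    # Returns a (k, k) mask: 1 where the kernel may look, 0 elsewhere.
    k = kernel_size
    mask = jnp.ones((k, k))
    mask = mask.at[k // 2, k // 2 + (mask_type == "B"):].set(0.0)  # center row, right of (or at) the center
    mask = mask.at[k // 2 + 1:, :].set(0.0)                        # all rows below the center
    return mask

print(causal_mask(3, "A"))
# [[1. 1. 1.]
#  [1. 0. 0.]
#  [0. 0. 0.]]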


Geometric Models¤

3D structure generation with physical constraints and equivariance.

Available Models:

| Model | Description | Key Features | Use Cases |
| --- | --- | --- | --- |
| Point Cloud Generator | 3D point cloud generation | Permutation invariance | 3D object generation |
| Mesh Generator | 3D mesh generation | Vertex/face generation, deformation | 3D modeling |
| Protein Graph | Protein structure generation | Residue graphs, amino acid features | Protein design |
| Protein Point Cloud | Protein backbone generation | Cα coordinates, backbone geometry | Protein structure prediction |
| Voxel Generator | Voxel-based 3D generation | Regular 3D grid | 3D shape generation |
| Graph Generator | Graph-based generation | Node/edge features, message passing | Molecular graphs |

Quick Start:

from workshop.generative_models.factories import create_geometric
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    model_type="protein_point_cloud",
    input_shape=(128, 3),  # 128 residues, 3D coordinates
    parameters={
        "num_residues": 128,
        "hidden_dim": 256,
        "num_layers": 6
    }
)

model = create_geometric(config, rngs=nnx.Rngs(0))
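
Permutation invariance for point clouds is commonly obtained PointNet-style: a shared per-point transform followed by a symmetric pooling operation, so reordering the points cannot change the result. A minimal sketch in plain JAX, illustrative only and not Workshop's encoder:

import jax.numpy as jnp

def point_cloud_embedding(points, w):
    # points: (num_points, 3); w: (3, hidden_dim) weights shared across points.
    per_point = jnp.tanh(points @ w)  # the same transform is applied to every point
    return per_point.max(axis=0)      # max-pooling is invariant to point order

w = jnp.full((3, 256), 0.01)
cloud = jnp.arange(128 * 3, dtype=jnp.float32).reshape(128, 3)
shuffled = cloud[::-1]                # reorder the points
assert jnp.allclose(point_cloud_embedding(cloud, w), point_cloud_embedding(shuffled, w))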


Model Comparison¤

Choose the right model for your task:

| Model Type | Sample Quality | Training Stability | Speed | Exact Likelihood | Best For |
| --- | --- | --- | --- | --- | --- |
| VAE | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ❌ (lower bound) | Representation learning, fast sampling |
| GAN | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | ❌ | High-quality images, style transfer |
| Diffusion | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ❌ (lower bound) | State-of-the-art generation, controllability |
| Flow | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ✅ | Density estimation, exact inference |
| EBM | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐ | ❌ (unnormalized) | Compositional generation, flexibility |
| AR | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | ✅ | Sequences, explicit probabilities |
| Geometric | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | Varies | 3D structures, physical constraints |

Common Backbones¤

Workshop provides reusable backbone architectures used across model types:

U-Net¤

Widely used in diffusion models and image-to-image tasks:

from workshop.generative_models.models.common.unet import UNet
from flax import nnx

unet = UNet(
    in_channels=3,
    out_channels=3,
    channels=[128, 256, 512, 1024],
    num_res_blocks=2,
    attention_resolutions=[16, 8],
    rngs=nnx.Rngs(0)
)

Documentation: U-Net API

Diffusion Transformer (DiT)¤

Transformer-based backbone for diffusion models:

from workshop.generative_models.models.diffusion.dit import DiT
from flax import nnx

dit = DiT(
    input_size=32,
    patch_size=2,
    in_channels=3,
    hidden_size=768,
    depth=12,
    num_heads=12,
    rngs=nnx.Rngs(0)
)

Documentation: DiT API

Encoders & Decoders¤

Modular encoder/decoder architectures:

  • MLP Encoder/Decoder: Fully-connected networks
  • CNN Encoder/Decoder: Convolutional networks for images
  • Conditional Encoder/Decoder: Class-conditional variants
  • ResNet Encoder/Decoder: Residual connections

Documentation: Encoders | Decoders

Conditioning Methods¤

Workshop supports multiple conditioning strategies across models:

| Method | Description | Supported Models | Use Cases |
| --- | --- | --- | --- |
| Class Conditioning | One-hot encoded labels | VAE, GAN, Diffusion, Flow | Supervised generation |
| Text Conditioning | CLIP embeddings | Diffusion, GAN | Text-to-image |
| Image Conditioning | Concatenation, cross-attention | GAN, Diffusion | Image-to-image, inpainting |
| Embedding Conditioning | Learned embeddings | All models | Flexible conditioning |
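
In its simplest form, class conditioning appends a one-hot label to the model's input or latent code. The sketch below shows that pattern in plain JAX; how each Workshop model consumes the label internally may differ:

import jax
import jax.numpy as jnp

num_classes, latent_dim = 10, 128
key = jax.random.PRNGKey(0)

z = jax.random.normal(key, (16, latent_dim))             # latent codes
labels = jnp.arange(16) % num_classes                    # integer class labels
one_hot = jax.nn.one_hot(labels, num_classes)            # (16, 10)

conditioned_z = jnp.concatenate([z, one_hot], axis=-1)   # (16, 138), fed to the decoder/generator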

Documentation: Conditioning Guide

Model Registry¤

All models are registered in a global registry for easy instantiation:

from workshop.generative_models.models.registry import (
    list_models,
    get_model_class,
    register_model
)

# List all available models
available = list_models()
print(f"Available models: {len(available)}")

# Get model class by name
vae_class = get_model_class("vae")

# Register custom model
from my_models import CustomVAE
register_model("custom_vae", CustomVAE)

Documentation: Model Registry

Factory Functions¤

Create models easily with factory functions:

from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx
from workshop.generative_models.factories import (
    create_vae,
    create_gan,
    create_diffusion,
    create_flow,
    create_ebm,
    create_autoregressive,
    create_geometric,
    create_model  # Generic factory
)

# Generic factory (auto-detects model type)
config = ModelConfiguration(model_type="ddpm", ...)
model = create_model(config, rngs=nnx.Rngs(0))

Documentation: Factory API

Training¤

All models follow a unified training interface:

from workshop.generative_models.training import Trainer
from workshop.generative_models.core.configuration import TrainingConfiguration

training_config = TrainingConfiguration(
    batch_size=128,
    num_epochs=100,
    optimizer={"type": "adam", "learning_rate": 1e-3},
    scheduler={"type": "cosine", "warmup_steps": 1000}
)

trainer = Trainer(
    model=model,
    training_config=training_config,
    train_dataset=train_data,
    val_dataset=val_data
)

trainer.train()

Evaluation¤

Evaluate models with modality-specific metrics:

from workshop.benchmarks import EvaluationFramework

framework = EvaluationFramework(
    model=model,
    modality="image",
    metrics=["fid", "inception_score", "lpips"]
)

results = framework.evaluate(test_dataset)
print(results)

Documentation: Benchmarks

Examples¤

Hands-on examples are available for each model family; see Examples under Further reading below.

Contributing¤

Add new models to Workshop:

  1. Implement model following protocols in core/protocols.py
  2. Add to appropriate directory (vae/, gan/, diffusion/, etc.)
  3. Register in model registry
  4. Add comprehensive tests
  5. Document API and usage

Documentation: Contributing Guide

API Statistics¤

Current model coverage:

  • Total modules: 67
  • Total classes: 135
  • Total functions: 482
  • Model families: 7
  • Conditioning methods: 4
  • Sampling methods: 15+

Further reading:

  • User Guides: Deep dive into each model family with examples and best practices
  • API Reference: Complete API documentation for all models and components
  • Tutorials: Step-by-step tutorials for common tasks and workflows
  • Examples: Working code examples for all model types