# Core Concepts
This guide introduces the fundamental concepts behind Workshop and generative modeling. Understanding these concepts will help you make the most of the library.
## What is Generative Modeling?
Generative modeling is about learning probability distributions from data and generating new samples from those distributions.
### The Core Problem
Given a dataset \(\mathcal{D} = \{x_1, x_2, ..., x_N\}\), we want to:
- Learn the underlying data distribution \(p(x)\)
- Generate new samples \(\tilde{x} \sim p(x)\) that look like the training data
- Evaluate the quality of generated samples
### Why Generative Models?

- **Image Generation**: Create realistic images, art, faces, or any visual content
- **Data Augmentation**: Generate synthetic training data to improve discriminative models
- **Representation Learning**: Learn meaningful latent representations of data
- **Anomaly Detection**: Identify outliers by measuring likelihood under the learned distribution
- **Content Creation**: Generate text, audio, video, 3D shapes, and more
- **Scientific Discovery**: Generate molecules, proteins, and materials for drug design and research
## Key Concepts

### 1. Probability Distribution
A probability distribution \(p(x)\) assigns probabilities to different outcomes:
- Discrete: \(p(x) \in [0, 1]\) for each \(x\), \(\sum_x p(x) = 1\)
- Continuous: \(p(x) \geq 0\), \(\int p(x) dx = 1\)
Goal: Learn \(p(x)\) from data so we can sample new \(x \sim p(x)\).
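As a concrete illustration, here is a minimal sketch in plain JAX (not a Workshop API) of a discrete distribution that satisfies the normalization constraint, together with sampling from it:

```python
# Minimal sketch: a discrete p(x) that sums to 1, and samples drawn from it.
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 0.5, -1.0])            # unnormalized scores
probs = jax.nn.softmax(logits)                  # p(x) in [0, 1], sums to 1
assert jnp.isclose(probs.sum(), 1.0)

key = jax.random.PRNGKey(0)
samples = jax.random.categorical(key, logits, shape=(1000,))  # x ~ p(x)
```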
### 2. Likelihood
The likelihood \(p(x|\theta)\) measures how probable data \(x\) is under model parameters \(\theta\).
Maximum Likelihood Estimation (MLE): choose the parameters that make the observed data most probable,

\[
\theta^* = \arg\max_\theta \sum_{i=1}^{N} \log p(x_i \mid \theta)
\]

Models with tractable likelihoods (e.g., flows, autoregressive models) optimize this objective directly.
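The following is a minimal MLE sketch in plain JAX (not a Workshop API): it fits the mean of a unit-variance Gaussian by gradient descent on the negative log-likelihood.

```python
import jax
import jax.numpy as jnp

def neg_log_likelihood(mu, data):
    # Negative log p(data | mu) for a N(mu, 1) model, up to an additive constant
    return 0.5 * jnp.mean((data - mu) ** 2)

data = jax.random.normal(jax.random.PRNGKey(0), (1000,)) + 3.0
grad_fn = jax.grad(neg_log_likelihood)

mu = 0.0
for _ in range(100):
    mu = mu - 0.1 * grad_fn(mu, data)  # gradient step toward the MLE
# mu now approaches the sample mean (~3.0), the closed-form MLE for this model
```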
### 3. Latent Variables
Latent variables \(z\) are unobserved variables that capture underlying structure:
- \(p(z)\): Prior distribution (usually standard normal)
- \(p(x|z)\): Likelihood (decoder/generator)
- \(p(z|x)\): Posterior (encoder/inference network)
Examples:
- VAE: Continuous latent space
- VQ-VAE: Discrete latent codebook
- Diffusion: Noisy latent trajectory
### 4. Sampling
Generating new data from the learned distribution:
- Ancestral Sampling: Sample from the prior \(z \sim p(z)\), then generate \(x \sim p(x|z)\)
- MCMC Sampling: Use Markov chains to sample from complex distributions
- Diffusion Sampling: Iteratively denoise from pure noise to data
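As an illustration of the first approach, here is a minimal ancestral-sampling sketch in plain JAX; `decoder` is a hypothetical stand-in for any trained \(p(x|z)\) network, not a Workshop function.

```python
import jax
import jax.numpy as jnp

def decoder(z):
    # Placeholder deterministic decoder; a real model would be a neural network.
    return jnp.tanh(z @ jnp.ones((16, 784)) * 0.1)

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (8, 16))   # step 1: z ~ p(z) = N(0, I)
x = decoder(z)                        # step 2: x ~ p(x|z) (here, its mean)
```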
### 5. Encoder-Decoder Architecture

```mermaid
graph LR
    X[Data x] -->|Encode| Z[Latent z]
    Z -->|Decode| XR[Reconstructed x']
    style X fill:#e1f5ff
    style Z fill:#fff4e1
    style XR fill:#e8f5e9
```
- Encoder: Maps data to latent space \(q(z|x)\)
- Decoder: Maps latent to data space \(p(x|z)\)
- Latent Space: Compressed, structured representation
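A minimal encoder-decoder sketch in plain Flax NNX (not Workshop's own model classes), compressing 784-dimensional inputs to a 32-dimensional latent and reconstructing them:

```python
import jax
from flax import nnx

class TinyAutoencoder(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.encoder = nnx.Linear(784, 32, rngs=rngs)   # data -> latent
        self.decoder = nnx.Linear(32, 784, rngs=rngs)   # latent -> data

    def __call__(self, x: jax.Array) -> jax.Array:
        z = nnx.relu(self.encoder(x))   # compressed representation
        return self.decoder(z)          # reconstruction

model = TinyAutoencoder(rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.PRNGKey(0), (4, 784))
x_recon = model(x)
```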
## Generative Model Types
Workshop supports six main types of generative models, each with different trade-offs:
### 1. Variational Autoencoders (VAE)
Idea: Learn latent representations with probabilistic encoding/decoding
Key Equation: the Evidence Lower Bound (ELBO),

\[
\log p(x) \geq \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - D_{\mathrm{KL}}\left(q(z|x) \,\|\, p(z)\right)
\]
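A minimal sketch of the negative ELBO in plain JAX (not Workshop's loss API), using an MSE reconstruction term and the closed-form KL divergence for a diagonal-Gaussian \(q(z|x)\):

```python
import jax.numpy as jnp

def negative_elbo(x, x_recon, mu, logvar):
    # Reconstruction term: approximates -E_q[log p(x|z)] with a squared error
    recon = jnp.sum((x - x_recon) ** 2, axis=-1)
    # KL( N(mu, sigma^2) || N(0, I) ), closed form for diagonal Gaussians
    kl = -0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(recon + kl)
```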
Architecture:
```mermaid
graph TD
    X[Input x] --> E[Encoder q]
    E --> M[Mean μ]
    E --> S[Std σ]
    M --> R[Reparameterize]
    S --> R
    R --> Z[Latent z]
    Z --> D[Decoder p]
    D --> XR[Output x']
```
Pros:
- Stable training
- Interpretable latent space
- Fast sampling
Cons:
- Lower sample quality compared to GANs/Diffusion
- Posterior approximation may be limited
Use Cases: Data compression, latent space exploration, representation learning
Workshop Example:

```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_vae

rngs = nnx.Rngs(0)
config = ModelConfiguration(
    model_type="vae",
    input_shape=(28, 28, 1),    # e.g. MNIST-sized images
    latent_dim=128,
    parameters={"beta": 1.0}    # β-VAE weight for disentanglement
)
model = create_vae(config, rngs=rngs)
```
### 2. Generative Adversarial Networks (GAN)
Idea: Train generator and discriminator in adversarial game
Key Equation: the minimax objective

\[
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log\left(1 - D(G(z))\right)\right]
\]
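A minimal sketch of this objective in plain JAX (not Workshop's GAN losses), written in the numerically stable logits form; the generator term uses the common non-saturating variant:

```python
import jax
import jax.numpy as jnp

def discriminator_loss(d_real_logits, d_fake_logits):
    # -E[log D(x)] - E[log(1 - D(G(z)))], expressed with logits via softplus
    real_term = jnp.mean(jax.nn.softplus(-d_real_logits))
    fake_term = jnp.mean(jax.nn.softplus(d_fake_logits))
    return real_term + fake_term

def generator_loss(d_fake_logits):
    # Non-saturating generator objective: -E[log D(G(z))]
    return jnp.mean(jax.nn.softplus(-d_fake_logits))
```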
Architecture:
```mermaid
graph LR
    Z[Noise z] --> G[Generator G]
    G --> XF[Fake x']
    XR[Real x] --> D[Discriminator D]
    XF --> D
    D --> R[Real/Fake]
    style G fill:#fff4e1
    style D fill:#ffe1f5
```
Pros:
- Highest sample quality
- No explicit likelihood needed
- Mode coverage (with proper training)
Cons:
- Training instability (mode collapse, vanishing gradients)
- Hyperparameter sensitive
- No direct likelihood evaluation
Use Cases: High-quality image generation, style transfer, image-to-image translation
Workshop Example:

```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_gan

rngs = nnx.Rngs(0)
config = ModelConfiguration(
    model_type="wgan-gp",           # WGAN with gradient penalty
    input_shape=(32, 32, 3),        # e.g. CIFAR-10-sized images
    latent_dim=100,
    parameters={"gp_weight": 10.0}
)
model = create_gan(config, rngs=rngs)
```
### 3. Diffusion Models
Idea: Learn to denoise data through iterative refinement
Forward Process: gradually add Gaussian noise,
\(q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)\)

Reverse Process: learn to denoise,
\(p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)\)
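A minimal sketch of the closed-form forward (noising) process in plain JAX, not Workshop's diffusion API; it uses the standard identity \(x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon\):

```python
import jax
import jax.numpy as jnp

T = 1000
betas = jnp.linspace(1e-4, 0.02, T)        # linear noise schedule
alphas_bar = jnp.cumprod(1.0 - betas)      # cumulative signal retention

def q_sample(x0, t, key):
    # Jump directly from clean x0 to noisy x_t in a single step
    eps = jax.random.normal(key, x0.shape)
    a = alphas_bar[t]
    return jnp.sqrt(a) * x0 + jnp.sqrt(1.0 - a) * eps, eps

x0 = jax.random.uniform(jax.random.PRNGKey(0), (4, 32, 32, 3))
x_t, eps = q_sample(x0, t=500, key=jax.random.PRNGKey(1))
```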
Architecture:
```mermaid
graph LR
    X0[Clean x₀] -->|Add noise| X1[Noisy x₁]
    X1 -->|Add noise| X2[Noisy x₂]
    X2 -->|...| XT[Pure noise xₜ]
    XT -->|Denoise| X2R[x₂]
    X2R -->|Denoise| X1R[x₁]
    X1R -->|Denoise| X0R[Clean x₀]
    style X0 fill:#e8f5e9
    style XT fill:#ffebee
    style X0R fill:#e8f5e9
```
Pros:
- State-of-the-art sample quality
- Stable training
- Flexible conditioning and guidance
Cons:
- Slow sampling (many steps)
- Memory intensive
- Long training times
Use Cases: Image/audio generation, inpainting, super-resolution, conditional generation
Workshop Example:

```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_diffusion

rngs = nnx.Rngs(0)
config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(32, 32, 3),               # e.g. CIFAR-10-sized images
    num_timesteps=1000,
    backbone_type="dit",                   # Diffusion Transformer backbone
    parameters={"beta_schedule": "cosine"}
)
model = create_diffusion(config, rngs=rngs)
```
### 4. Normalizing Flows
Idea: Use invertible transformations with tractable Jacobians
Key Property: the change-of-variables formula gives an exact density,

\[
\log p_X(x) = \log p_Z\left(f^{-1}(x)\right) + \log \left| \det \frac{\partial f^{-1}(x)}{\partial x} \right|
\]
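A minimal change-of-variables sketch in plain JAX (not Workshop's flow API): an elementwise affine bijection \(x = f(z) = z \odot e^{s} + b\) with an exact log-density.

```python
import jax.numpy as jnp
import jax.scipy.stats as jstats

s = jnp.array([0.5, -0.3])   # log-scale parameters of the bijection
b = jnp.array([1.0, 2.0])    # shift parameters

def log_prob_x(x):
    z = (x - b) * jnp.exp(-s)                  # inverse transform f^{-1}(x)
    log_det = -jnp.sum(s)                      # log|det d f^{-1} / dx|
    return jnp.sum(jstats.norm.logpdf(z), axis=-1) + log_det
```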
Architecture:
```mermaid
graph LR
    Z[Simple z] -->|f| X[Complex x]
    X -->|f⁻¹| Z2[Simple z]
    style Z fill:#e8f5e9
    style X fill:#e1f5ff
```
Pros:
- Exact likelihood computation
- Invertible (can go both ways)
- Stable training
Cons:
- Architecture constraints (must be invertible)
- May require many layers for expressiveness
Use Cases: Density estimation, exact inference, anomaly detection
Workshop Example:

```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_flow

rngs = nnx.Rngs(0)
config = ModelConfiguration(
    model_type="glow",
    input_shape=(32, 32, 3),               # e.g. CIFAR-10-sized images
    num_flows=32,
    parameters={"coupling_type": "affine"}
)
model = create_flow(config, rngs=rngs)
```
### 5. Energy-Based Models (EBM)
Idea: Learn energy function, sample with MCMC
Key Equation: the Gibbs distribution

\[
p_\theta(x) = \frac{\exp\left(-E_\theta(x)\right)}{Z(\theta)},
\qquad
Z(\theta) = \int \exp\left(-E_\theta(x)\right) dx,
\]

where \(Z(\theta)\) is the (generally intractable) partition function.
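A minimal Langevin-dynamics sampling sketch in plain JAX (not Workshop's EBM sampler): samples follow the energy gradient downhill while injecting Gaussian noise at every step. The quadratic `energy` function is a placeholder for a learned \(E_\theta\).

```python
import jax
import jax.numpy as jnp

def energy(x):
    # Placeholder energy; a real EBM would use a neural network E_theta(x)
    return jnp.sum(x ** 2, axis=-1)

def langevin_sample(key, shape, steps=60, step_size=0.01):
    key, init_key = jax.random.split(key)
    x = jax.random.normal(init_key, shape)             # start from noise
    grad_e = jax.grad(lambda x: jnp.sum(energy(x)))    # per-sample energy gradients
    for _ in range(steps):
        key, noise_key = jax.random.split(key)
        noise = jax.random.normal(noise_key, shape)
        x = x - step_size * grad_e(x) + jnp.sqrt(2.0 * step_size) * noise
    return x

samples = langevin_sample(jax.random.PRNGKey(0), (16, 2))
```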
Pros:
- Flexible, can model any distribution
- Composable (combine multiple EBMs)
Cons:
- Expensive sampling (MCMC)
- Training complexity (contrastive divergence)
Use Cases: Compositional generation, constraint satisfaction, hybrid models
Workshop Example:

```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_ebm

rngs = nnx.Rngs(0)
config = ModelConfiguration(
    model_type="ebm",
    input_shape=(32, 32, 3),    # e.g. CIFAR-10-sized images
    parameters={
        "mcmc_steps": 60,
        "step_size": 0.01
    }
)
model = create_ebm(config, rngs=rngs)
```
### 6. Autoregressive Models

Idea: Model the joint distribution as a product of conditionals,

\[
p(x) = \prod_{i=1}^{n} p(x_i \mid x_1, \ldots, x_{i-1})
\]
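A minimal sketch of this factorization in plain JAX (not Workshop's models): the sequence log-likelihood is simply the sum of per-step conditional log-probabilities.

```python
import jax
import jax.numpy as jnp

def sequence_log_prob(logits, tokens):
    # logits: (T, vocab) predictions for p(x_t | x_<t); tokens: (T,) observed x
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_step = jnp.take_along_axis(log_probs, tokens[:, None], axis=-1)
    return jnp.sum(per_step)   # log p(x) = sum_t log p(x_t | x_<t)
```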
Architecture:
```mermaid
graph LR
    X1[x₁] -->|p| X2[x₂]
    X2 -->|p| X3[x₃]
    X3 -->|...| XN[xₙ]
```
Pros:
- Explicit likelihood
- Flexible architectures (Transformers, CNNs)
- Strong theoretical foundation
Cons:
- Sequential generation (slow)
- Fixed ordering required
Use Cases: Text generation, ordered sequences, explicit probability modeling
Workshop Example:

```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_autoregressive

rngs = nnx.Rngs(0)
config = ModelConfiguration(
    model_type="pixelcnn",
    input_shape=(28, 28, 1),        # e.g. MNIST-sized images
    parameters={"num_layers": 12}
)
model = create_autoregressive(config, rngs=rngs)
```
## Model Comparison Matrix
| Feature | VAE | GAN | Diffusion | Flow | EBM | Autoregressive |
|---|---|---|---|---|---|---|
| Sample Quality | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| Training Stability | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| Sampling Speed | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐ |
| Exact Likelihood | ❌ | ❌ | ❌ | ✅ | ❌* | ✅ |
| Latent Space | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| Mode Coverage | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
*EBM has exact likelihood but intractable partition function
## Workshop Architecture

### High-Level Design
Workshop follows a modular, protocol-based design:
```mermaid
graph TB
    subgraph User["User Interface"]
        Config[Configuration]
        Factory[Model Factories]
    end
    subgraph Core["Core Components"]
        Protocols[Protocols & Interfaces]
        Device[Device Manager]
        Loss[Loss Functions]
    end
    subgraph Models["Generative Models"]
        VAE[VAE]
        GAN[GAN]
        Diff[Diffusion]
        Flow[Flow]
        EBM[EBM]
        AR[Autoregressive]
    end
    subgraph Training["Training System"]
        Trainer[Trainers]
        Opt[Optimizers]
        Callbacks[Callbacks]
    end
    Config --> Factory
    Factory --> Models
    Models --> Training
    Core --> Models
    Core --> Training
```
### Key Design Principles
- Protocol-Based: Type-safe interfaces using Python Protocols
- Configuration-Driven: Pydantic-based unified configuration system
- Factory Pattern: Centralized model creation
- Hardware-Aware: Automatic GPU/CPU/TPU detection and optimization
- Modular: Composable components for flexibility
### Configuration System
All models use a unified configuration class:
```python
from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    # Required fields
    model_type="vae",                 # Type of model
    input_shape=(28, 28, 1),          # Input dimensions

    # Architecture
    latent_dim=128,                   # Latent space size
    encoder_features=[64, 128],       # Encoder layer sizes
    decoder_features=[128, 64],       # Decoder layer sizes

    # Model-specific parameters
    parameters={
        "beta": 1.0,                  # VAE-specific: β-VAE weight
        "kl_weight": 1.0,             # KL divergence weight
        "reconstruction_loss": "mse"  # Reconstruction loss type
    },

    # Optional metadata
    metadata={
        "experiment_id": "vae_001",
        "dataset": "mnist"
    }
)
```
Benefits:
- Type-safe with Pydantic validation
- Serializable (save/load configurations)
- Versioned for reproducibility
- Extensible for custom models
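Since the configuration system is described as Pydantic-based, a serialization round-trip along the following lines should work; the `model_dump_json` / `model_validate_json` method names come from Pydantic v2 itself, not from Workshop's documented API, so treat this as a hedged sketch.

```python
# Hedged sketch: assumes ModelConfiguration is a standard Pydantic v2 model.
from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    model_type="vae",
    input_shape=(28, 28, 1),
    latent_dim=128,
)

config_json = config.model_dump_json()                          # serialize to JSON
restored = ModelConfiguration.model_validate_json(config_json)  # load it back
assert restored.latent_dim == config.latent_dim
```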
### Device Management
Workshop automatically handles GPU/CPU/TPU:
```python
from workshop.generative_models.core.device_manager import get_device_manager

# Automatic device detection
manager = get_device_manager()
print(f"Using: {manager.get_device()}")   # gpu, cpu, or tpu
print(f"Device count: {manager.device_count}")

# Explicit configuration
from workshop.generative_models.core.device_manager import (
    MemoryStrategy,
    configure_for_generative_models,
)

manager = configure_for_generative_models(
    memory_strategy=MemoryStrategy.BALANCED,  # 75% GPU memory
    enable_mixed_precision=True               # BF16/FP16
)
```
### Protocol System
Models implement the GenerativeModel protocol:
```python
from typing import Protocol

import jax
from flax import nnx

class GenerativeModel(Protocol):
    """Base protocol for all generative models."""

    def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None) -> dict:
        """Forward pass."""
        ...

    def generate(self, *, n_samples: int, rngs: nnx.Rngs | None = None) -> jax.Array:
        """Generate samples from the model."""
        ...

    def loss_fn(self, batch: dict, outputs: dict) -> dict:
        """Compute loss."""
        ...
```
This enables:
- Type checking at development time
- Generic training loops that work with any model
- Consistent interfaces across model types
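Because every model exposes the same `__call__` / `loss_fn` surface, a training step can be written once against the protocol. The sketch below is illustrative rather than Workshop's trainer: the batch key `"x"` and the loss key `"loss"` are assumptions, and the gradients are returned for whatever optimizer you use.

```python
from flax import nnx

def train_step(model: GenerativeModel, batch: dict, rngs: nnx.Rngs):
    def loss(model):
        outputs = model(batch["x"], rngs=rngs)          # forward pass (assumed batch key "x")
        return model.loss_fn(batch, outputs)["loss"]    # scalar training loss (assumed key "loss")

    # Differentiate through the model's state with Flax NNX
    loss_value, grads = nnx.value_and_grad(loss)(model)
    return loss_value, grads   # apply `grads` with your optimizer of choice
```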
## JAX and Flax NNX Basics
Workshop is built on JAX and Flax NNX. Here are the key concepts:
### JAX: Functional Programming
JAX provides functional transformations:
```python
import jax
import jax.numpy as jnp

# JIT compilation for speed
@jax.jit
def fast_function(x):
    return jnp.sum(x ** 2)

# Automatic differentiation
def loss_fn(params, x):
    return jnp.sum((params['w'] * x) ** 2)

params = {'w': jnp.array(2.0)}
x = jnp.arange(3.0)

grad_fn = jax.grad(loss_fn)      # Get gradient function
gradients = grad_fn(params, x)   # Compute gradients w.r.t. params

# Vectorization
batch_fn = jax.vmap(fast_function)  # Apply to batches
```
### Flax NNX: Object-Oriented Neural Networks
Flax NNX provides a Pythonic API for neural networks:
```python
import jax
from flax import nnx

class MyModel(nnx.Module):
    def __init__(self, features: int, *, rngs: nnx.Rngs):
        self.dense1 = nnx.Linear(784, features, rngs=rngs)
        self.dense2 = nnx.Linear(features, 10, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.dense1(x)
        x = nnx.relu(x)
        x = self.dense2(x)
        return x

# Create model
rngs = nnx.Rngs(0)
model = MyModel(features=128, rngs=rngs)

# Use model
x = jax.random.normal(jax.random.PRNGKey(0), (32, 784))
y = model(x)
```
### Random Number Generation
JAX requires explicit RNG management:
```python
import jax
from flax import nnx

# Create an RNG container with named streams
rngs = nnx.Rngs(params=42, sample=0)

# Use the "params" stream for model initialization
model = MyModel(features=128, rngs=rngs)

# Draw a fresh key from the "sample" stream for sampling
samples = jax.random.normal(rngs.sample(), (10, 784))
```
## Multi-Modal Support
Workshop natively supports multiple data modalities:
### Image

```python
from workshop.data.datasets.image import CIFAR10Dataset

dataset = CIFAR10Dataset(root='./data', train=True)
```
### Text

```python
from workshop.data.datasets.text import WikipediaDataset

dataset = WikipediaDataset(
    tokenizer='bpe',
    max_length=512
)
```
### Audio

```python
from workshop.data.datasets.audio import LibriSpeechDataset

dataset = LibriSpeechDataset(
    root='./data',
    sample_rate=16000
)
```
### Protein

```python
from workshop.data.protein_dataset import ProteinDataset

dataset = ProteinDataset(
    pdb_dir='./data/pdb',
    with_constraints=True
)
```
## Next Steps
Now that you understand the core concepts:
- **Build Your First Model**: Hands-on tutorial to build a complete VAE from scratch
- **Explore Model Guides**: Deep dives into each model type with examples
- **Check API Reference**: Complete API documentation for all components
- **Learn Training**: Training workflows, optimization, and distributed training
## Further Reading

### Generative Models
- VAE: Kingma & Welling (2013) - Auto-Encoding Variational Bayes
- GAN: Goodfellow et al. (2014) - Generative Adversarial Networks
- Diffusion: Ho et al. (2020) - Denoising Diffusion Probabilistic Models
- Flow: Dinh et al. (2016) - Density Estimation using Real NVP
### JAX and Flax

- JAX documentation: https://jax.readthedocs.io
- Flax documentation: https://flax.readthedocs.io
Last Updated: 2025-10-13