
Core Concepts¤

This guide introduces the fundamental concepts behind Workshop and generative modeling. Understanding these concepts will help you make the most of the library.

What is Generative Modeling?¤

Generative modeling is about learning probability distributions from data and generating new samples from those distributions.

The Core Problem¤

Given a dataset \(\mathcal{D} = \{x_1, x_2, ..., x_N\}\), we want to:

  1. Learn the underlying data distribution \(p(x)\)
  2. Generate new samples \(\tilde{x} \sim p(x)\) that look like the training data
  3. Evaluate the quality of generated samples

Why Generative Models?¤

  • Image Generation


    Create realistic images, art, faces, or any visual content

  • Data Augmentation


    Generate synthetic training data to improve discriminative models

  • Representation Learning


    Learn meaningful latent representations of data

  • Anomaly Detection


    Identify outliers by measuring likelihood under the learned distribution

  • Content Creation


    Generate text, audio, video, 3D shapes, and more

  • Scientific Discovery


    Generate molecules, proteins, materials for drug design and research

Key Concepts¤

1. Probability Distribution¤

A probability distribution \(p(x)\) assigns probabilities to different outcomes:

  • Discrete: \(p(x) \in [0, 1]\) for each \(x\), \(\sum_x p(x) = 1\)
  • Continuous: \(p(x) \geq 0\), \(\int p(x) dx = 1\)

Goal: Learn \(p(x)\) from data so we can sample new \(x \sim p(x)\).
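
As a concrete illustration, the snippet below (a minimal sketch in plain JAX, not Workshop API) defines a discrete distribution over three outcomes, checks that it is normalized, and draws new samples from it:

import jax
import jax.numpy as jnp

# A discrete distribution over three outcomes
probs = jnp.array([0.2, 0.5, 0.3])
assert jnp.isclose(probs.sum(), 1.0)  # normalization: sum_x p(x) = 1

# Draw new samples x ~ p(x)
key = jax.random.PRNGKey(0)
samples = jax.random.categorical(key, jnp.log(probs), shape=(5,))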

2. Likelihood¤

The likelihood \(p(x|\theta)\) measures how probable data \(x\) is under model parameters \(\theta\).

Maximum Likelihood Estimation (MLE):

\[ \theta^* = \arg\max_\theta \sum_{i=1}^N \log p(x_i | \theta) \]

Models with tractable likelihoods (e.g., Flows, Autoregressive) directly optimize this.
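
As a minimal, model-agnostic sketch (plain JAX, toy data, not Workshop API), the following recovers the mean of a Gaussian by MLE, i.e. by gradient descent on the negative log-likelihood:

import jax
import jax.numpy as jnp

data = jnp.array([1.8, 2.1, 2.4, 1.9])  # toy dataset

def neg_log_likelihood(mu, x, sigma=1.0):
    # -sum_i log N(x_i; mu, sigma^2), dropping the additive constant
    return jnp.sum(0.5 * ((x - mu) / sigma) ** 2)

grad_fn = jax.grad(neg_log_likelihood)
mu = 0.0
for _ in range(100):
    mu -= 0.1 * grad_fn(mu, data)  # gradient descent on the NLL
# mu converges to the sample mean, which is the MLE for a Gaussian mean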

3. Latent Variables¤

Latent variables \(z\) are unobserved variables that capture underlying structure:

\[ p(x) = \int p(x|z) p(z) dz \]
  • \(p(z)\): Prior distribution (usually standard normal)
  • \(p(x|z)\): Likelihood (decoder/generator)
  • \(p(z|x)\): Posterior (encoder/inference network)

Examples:

  • VAE: Continuous latent space
  • VQ-VAE: Discrete latent codebook
  • Diffusion: Noisy latent trajectory
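
To make the integral above concrete, here is a minimal sketch (plain JAX with a toy Gaussian observation model and a stand-in decoder, not Workshop API) that approximates \(p(x) = \int p(x|z) p(z) dz\) by Monte Carlo, averaging the likelihood over samples from the prior:

import jax
import jax.numpy as jnp

def log_gaussian(x, mean, std):
    # log N(x; mean, std^2 I), summed over dimensions
    return -0.5 * jnp.sum(((x - mean) / std) ** 2 + jnp.log(2 * jnp.pi * std**2))

def estimate_log_px(x, decoder, key, n_samples=1000, latent_dim=2):
    z = jax.random.normal(key, (n_samples, latent_dim))   # z ~ p(z) = N(0, I)
    log_px_given_z = jax.vmap(lambda zi: log_gaussian(x, decoder(zi), 1.0))(z)
    # log p(x) ≈ logsumexp_k log p(x|z_k) - log K
    return jax.nn.logsumexp(log_px_given_z) - jnp.log(n_samples)

decoder = lambda z: jnp.tanh(z).sum() * jnp.ones(4)  # stand-in decoder
x = jnp.ones(4)
print(estimate_log_px(x, decoder, jax.random.PRNGKey(0)))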

4. Sampling¤

Generating new data from the learned distribution:

Ancestral Sampling: Sample from the prior \(z \sim p(z)\), then generate \(x \sim p(x|z)\)

MCMC Sampling: Use Markov chains to sample from complex distributions

Diffusion Sampling: Iteratively denoise from pure noise to data
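
For example, ancestral sampling with a toy linear decoder looks like this (a sketch in plain JAX; Workshop models expose sampling through their generate method instead):

import jax
import jax.numpy as jnp

key_z, key_w = jax.random.split(jax.random.PRNGKey(42))
latent_dim, data_dim = 8, 32

# Toy "decoder": a fixed linear map from latent space to data space
W = jax.random.normal(key_w, (latent_dim, data_dim))

z = jax.random.normal(key_z, (16, latent_dim))  # 1. sample z ~ p(z) = N(0, I)
x = z @ W                                       # 2. generate x = decoder(z)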

5. Encoder-Decoder Architecture¤

graph LR
    X[Data x] -->|Encode| Z[Latent z]
    Z -->|Decode| XR[Reconstructed x']

    style X fill:#e1f5ff
    style Z fill:#fff4e1
    style XR fill:#e8f5e9
  • Encoder: Maps data to latent space \(q(z|x)\)
  • Decoder: Maps latent to data space \(p(x|z)\)
  • Latent Space: Compressed, structured representation
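
A minimal (deterministic) encoder-decoder in Flax NNX might look like the sketch below; it is for illustration only, not a Workshop class, and omits the probabilistic heads a VAE would add:

from flax import nnx
import jax

class AutoEncoder(nnx.Module):
    def __init__(self, data_dim: int, latent_dim: int, *, rngs: nnx.Rngs):
        self.encoder = nnx.Linear(data_dim, latent_dim, rngs=rngs)
        self.decoder = nnx.Linear(latent_dim, data_dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        z = self.encoder(x)      # data -> latent
        return self.decoder(z)   # latent -> reconstruction

model = AutoEncoder(data_dim=784, latent_dim=32, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.PRNGKey(0), (8, 784))
x_rec = model(x)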

Generative Model Types¤

Workshop supports six main types of generative models, each with different trade-offs:

1. Variational Autoencoders (VAE)¤

Idea: Learn latent representations with probabilistic encoding/decoding

Key Equation: Evidence Lower Bound (ELBO)

\[ \mathcal{L}_{ELBO} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}(q(z|x) || p(z)) \]
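
The ELBO is maximized (equivalently, its negative is minimized). A minimal sketch of the two terms for a diagonal-Gaussian encoder and an MSE (Gaussian) reconstruction term, in plain JAX with illustrative names; Workshop's VAE computes its loss internally via loss_fn:

import jax.numpy as jnp

def elbo(x, x_recon, mu, logvar):
    # E_q[log p(x|z)], single-sample estimate with a Gaussian likelihood (up to constants)
    recon_term = -jnp.sum((x - x_recon) ** 2)
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian posterior
    kl_term = -0.5 * jnp.sum(1 + logvar - mu**2 - jnp.exp(logvar))
    return recon_term - kl_term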

Architecture:

graph TD
    X[Input x] --> E[Encoder q]
    E --> M[Mean μ]
    E --> S[Std σ]
    M --> R[Reparameterize]
    S --> R
    R --> Z[Latent z]
    Z --> D[Decoder p]
    D --> XR[Output x']

Pros:

  • Stable training
  • Interpretable latent space
  • Fast sampling

Cons:

  • Lower sample quality compared to GANs/Diffusion
  • Posterior approximation may be limited

Use Cases: Data compression, latent space exploration, representation learning

Workshop Example:

from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_vae

config = ModelConfiguration(
    model_type="vae",
    latent_dim=128,
    parameters={"beta": 1.0}  # β-VAE weight for disentanglement
)
rngs = nnx.Rngs(0)
model = create_vae(config, rngs=rngs)

2. Generative Adversarial Networks (GAN)¤

Idea: Train a generator and a discriminator in an adversarial game

Key Equation: Minimax objective

\[ \min_G \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))] \]
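
In practice this objective is split into two losses that are optimized alternately. A minimal sketch in plain JAX of the standard (non-saturating) formulation, written on discriminator logits for numerical stability; note that Workshop's wgan-gp example below uses a different, Wasserstein-based objective:

import jax
import jax.numpy as jnp

def discriminator_loss(real_logits, fake_logits):
    # -E[log D(x)] - E[log(1 - D(G(z)))]
    real_term = jnp.mean(jax.nn.softplus(-real_logits))  # -log sigmoid(real)
    fake_term = jnp.mean(jax.nn.softplus(fake_logits))   # -log(1 - sigmoid(fake))
    return real_term + fake_term

def generator_loss(fake_logits):
    # Non-saturating generator loss: -E[log D(G(z))]
    return jnp.mean(jax.nn.softplus(-fake_logits))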

Architecture:

graph LR
    Z[Noise z] --> G[Generator G]
    G --> XF[Fake x']
    XR[Real x] --> D[Discriminator D]
    XF --> D
    D --> R[Real/Fake]

    style G fill:#fff4e1
    style D fill:#ffe1f5

Pros:

  • Highest sample quality
  • No explicit likelihood needed
  • Mode coverage (with proper training)

Cons:

  • Training instability (mode collapse, vanishing gradients)
  • Hyperparameter sensitive
  • No direct likelihood evaluation

Use Cases: High-quality image generation, style transfer, image-to-image translation

Workshop Example:

from workshop.generative_models.factories import create_gan

config = ModelConfiguration(
    model_type="wgan-gp",  # WGAN with gradient penalty
    latent_dim=100,
    parameters={"gp_weight": 10.0}
)
model = create_gan(config, rngs=rngs)

3. Diffusion Models¤

Idea: Learn to denoise data through iterative refinement

Forward Process: Add noise gradually

\[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) \]

Reverse Process: Learn to denoise

\[ p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) \]
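
The forward process also has a convenient closed form, \(q(x_t|x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) I)\) with \(\bar\alpha_t = \prod_{s \le t}(1-\beta_s)\), which is what training uses to noise samples in a single step. A minimal sketch in plain JAX (schedule and names are illustrative):

import jax
import jax.numpy as jnp

T = 1000
betas = jnp.linspace(1e-4, 0.02, T)    # linear noise schedule
alphas_bar = jnp.cumprod(1.0 - betas)  # \bar{alpha}_t = prod_{s<=t} (1 - beta_s)

def q_sample(x0, t, key):
    """Sample x_t ~ q(x_t | x_0) in one step."""
    eps = jax.random.normal(key, x0.shape)
    return jnp.sqrt(alphas_bar[t]) * x0 + jnp.sqrt(1.0 - alphas_bar[t]) * eps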

Architecture:

graph LR
    X0[Clean x₀] -->|Add noise| X1[Noisy x₁]
    X1 -->|Add noise| X2[Noisy x₂]
    X2 -->|...| XT[Pure noise xₜ]
    XT -->|Denoise| X2R[x₂]
    X2R -->|Denoise| X1R[x₁]
    X1R -->|Denoise| X0R[Clean x₀]

    style X0 fill:#e8f5e9
    style XT fill:#ffebee
    style X0R fill:#e8f5e9

Pros:

  • State-of-the-art sample quality
  • Stable training
  • Flexible conditioning and guidance

Cons:

  • Slow sampling (many steps)
  • Memory intensive
  • Long training times

Use Cases: Image/audio generation, inpainting, super-resolution, conditional generation

Workshop Example:

from workshop.generative_models.factories import create_diffusion

config = ModelConfiguration(
    model_type="ddpm",
    num_timesteps=1000,
    backbone_type="dit",  # Diffusion Transformer
    parameters={"beta_schedule": "cosine"}
)
model = create_diffusion(config, rngs=rngs)

4. Normalizing Flows¤

Idea: Use invertible transformations with tractable Jacobians

Key Property: Change of variables

\[ p_X(x) = p_Z(f^{-1}(x)) \left| \det \frac{\partial f^{-1}}{\partial x} \right| \]
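
For a simple invertible transform the change of variables can be written out directly. Here is a sketch with an elementwise affine flow \(x = f(z) = a \odot z + b\) and a standard-normal base distribution, in plain JAX (names are illustrative):

import jax.numpy as jnp

log_a = jnp.array([0.3, -0.2, 0.1])  # log-scale, so a = exp(log_a) > 0 and f is invertible
b = jnp.array([1.0, 0.0, -1.0])

def log_prob(x):
    # Invert: z = f^{-1}(x) = (x - b) / a
    z = (x - b) * jnp.exp(-log_a)
    # log p_Z(z) for a standard normal base distribution
    log_pz = -0.5 * jnp.sum(z**2 + jnp.log(2 * jnp.pi))
    # log |det ∂f^{-1}/∂x| = -sum(log a) for a diagonal Jacobian
    log_det = -jnp.sum(log_a)
    return log_pz + log_det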

Architecture:

graph LR
    Z[Simple z] -->|f| X[Complex x]
    X -->|f⁻¹| Z2[Simple z]

    style Z fill:#e8f5e9
    style X fill:#e1f5ff

Pros:

  • Exact likelihood computation
  • Invertible (can go both ways)
  • Stable training

Cons:

  • Architecture constraints (must be invertible)
  • May require many layers for expressiveness

Use Cases: Density estimation, exact inference, anomaly detection

Workshop Example:

from workshop.generative_models.factories import create_flow

config = ModelConfiguration(
    model_type="glow",
    num_flows=32,
    parameters={"coupling_type": "affine"}
)
model = create_flow(config, rngs=rngs)

5. Energy-Based Models (EBM)¤

Idea: Learn energy function, sample with MCMC

Key Equation: Gibbs distribution

\[ p(x) = \frac{1}{Z} \exp(-E(x)) \]

where

\[Z = \int \exp(-E(x)) dx\]

is the partition function.
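
Sampling typically uses Langevin dynamics MCMC, which follows the negative energy gradient with added noise. A minimal sketch in plain JAX (the energy function and hyperparameters are illustrative; they mirror the mcmc_steps and step_size used in the Workshop example below):

import jax
import jax.numpy as jnp

def langevin_sample(energy_fn, x_init, key, n_steps=60, step_size=0.01):
    """Draw an approximate sample from p(x) ∝ exp(-E(x)) via Langevin dynamics."""
    grad_E = jax.grad(energy_fn)
    x = x_init
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, x.shape)
        # x <- x - (eta/2) * ∇E(x) + sqrt(eta) * noise
        x = x - 0.5 * step_size * grad_E(x) + jnp.sqrt(step_size) * noise
    return x

# Example: a quadratic energy, so p(x) is a standard normal
sample = langevin_sample(lambda x: 0.5 * jnp.sum(x**2), jnp.zeros(2), jax.random.PRNGKey(0))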

Pros:

  • Flexible, can model any distribution
  • Composable (combine multiple EBMs)

Cons:

  • Expensive sampling (MCMC)
  • Training complexity (contrastive divergence)

Use Cases: Compositional generation, constraint satisfaction, hybrid models

Workshop Example:

from workshop.generative_models.factories import create_ebm

config = ModelConfiguration(
    model_type="ebm",
    parameters={
        "mcmc_steps": 60,
        "step_size": 0.01
    }
)
model = create_ebm(config, rngs=rngs)

6. Autoregressive Models¤

Idea: Model distribution as product of conditionals

\[ p(x) = \prod_{i=1}^n p(x_i | x_{<i}) \]

Architecture:

graph LR
    X1[x₁] -->|p| X2[x₂]
    X2 -->|p| X3[x₃]
    X3 -->|...| XN[xₙ]
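
Generation is inherently sequential: each element is sampled conditioned on everything generated so far. A minimal sketch with a toy conditional distribution in plain JAX (names and vocabulary size are illustrative):

import jax
import jax.numpy as jnp

def sample_sequence(cond_logits_fn, length, key):
    """Sample x_1, ..., x_n one element at a time, each conditioned on the prefix."""
    seq = []
    for _ in range(length):
        key, subkey = jax.random.split(key)
        logits = cond_logits_fn(jnp.array(seq, dtype=jnp.int32))  # p(x_i | x_<i)
        seq.append(int(jax.random.categorical(subkey, logits)))
    return seq

# Toy conditional: uniform over a 256-symbol vocabulary, ignoring the prefix
uniform = lambda prefix: jnp.zeros(256)
print(sample_sequence(uniform, length=8, key=jax.random.PRNGKey(0)))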

Pros:

  • Explicit likelihood
  • Flexible architectures (Transformers, CNNs)
  • Strong theoretical foundation

Cons:

  • Sequential generation (slow)
  • Fixed ordering required

Use Cases: Text generation, ordered sequences, explicit probability modeling

Workshop Example:

from workshop.generative_models.factories import create_autoregressive

config = ModelConfiguration(
    model_type="pixelcnn",
    parameters={"num_layers": 12}
)
model = create_autoregressive(config, rngs=rngs)

Model Comparison Matrix¤

| Feature | VAE | GAN | Diffusion | Flow | EBM | Autoregressive |
|---|---|---|---|---|---|---|
| Sample Quality | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| Training Stability | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
| Sampling Speed | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐ |
| Exact Likelihood | ❌ (ELBO bound) | ❌ | ❌ (bound) | ✅ | ❌* | ✅ |
| Latent Space | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| Mode Coverage | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ |

*EBM has exact likelihood but intractable partition function

Workshop Architecture¤

High-Level Design¤

Workshop follows a modular, protocol-based design:

graph TB
    subgraph User["User Interface"]
        Config[Configuration]
        Factory[Model Factories]
    end

    subgraph Core["Core Components"]
        Protocols[Protocols & Interfaces]
        Device[Device Manager]
        Loss[Loss Functions]
    end

    subgraph Models["Generative Models"]
        VAE[VAE]
        GAN[GAN]
        Diff[Diffusion]
        Flow[Flow]
        EBM[EBM]
        AR[Autoregressive]
    end

    subgraph Training["Training System"]
        Trainer[Trainers]
        Opt[Optimizers]
        Callbacks[Callbacks]
    end

    Config --> Factory
    Factory --> Models
    Models --> Training
    Core --> Models
    Core --> Training

Key Design Principles¤

  1. Protocol-Based: Type-safe interfaces using Python Protocols
  2. Configuration-Driven: Pydantic-based unified configuration system
  3. Factory Pattern: Centralized model creation
  4. Hardware-Aware: Automatic GPU/CPU/TPU detection and optimization
  5. Modular: Composable components for flexibility

Configuration System¤

All models use a unified configuration class:

from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    # Required fields
    model_type="vae",              # Type of model
    input_shape=(28, 28, 1),       # Input dimensions

    # Architecture
    latent_dim=128,                # Latent space size
    encoder_features=[64, 128],    # Encoder layer sizes
    decoder_features=[128, 64],    # Decoder layer sizes

    # Model-specific parameters
    parameters={
        "beta": 1.0,                # VAE-specific: β-VAE weight
        "kl_weight": 1.0,           # KL divergence weight
        "reconstruction_loss": "mse" # Reconstruction loss type
    },

    # Optional metadata
    metadata={
        "experiment_id": "vae_001",
        "dataset": "mnist"
    }
)

Benefits:

  • Type-safe with Pydantic validation
  • Serializable (save/load configurations)
  • Versioned for reproducibility
  • Extensible for custom models
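
Because the configuration is Pydantic-based, saving and reloading it should follow the usual Pydantic JSON round-trip pattern. The sketch below assumes ModelConfiguration behaves like a standard Pydantic v2 model; consult the configuration reference for any serialization helpers Workshop provides directly:

import json
from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(model_type="vae", input_shape=(28, 28, 1), latent_dim=128)

# Save the configuration alongside an experiment
with open("vae_001.json", "w") as f:
    f.write(config.model_dump_json(indent=2))

# Reload it later for exact reproducibility
with open("vae_001.json") as f:
    config_restored = ModelConfiguration(**json.load(f))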

Device Management¤

Workshop automatically handles GPU/CPU/TPU:

from workshop.generative_models.core.device_manager import get_device_manager

# Automatic device detection
manager = get_device_manager()
print(f"Using: {manager.get_device()}")  # gpu, cpu, or tpu
print(f"Device count: {manager.device_count}")

# Explicit configuration
from workshop.generative_models.core.device_manager import configure_for_generative_models, MemoryStrategy

manager = configure_for_generative_models(
    memory_strategy=MemoryStrategy.BALANCED,  # 75% GPU memory
    enable_mixed_precision=True                # BF16/FP16
)

Protocol System¤

Models implement the GenerativeModel protocol:

from typing import Protocol
from flax import nnx
import jax

class GenerativeModel(Protocol):
    """Base protocol for all generative models."""

    def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None) -> dict:
        """Forward pass."""
        ...

    def generate(self, *, n_samples: int, rngs: nnx.Rngs | None = None) -> jax.Array:
        """Generate samples from the model."""
        ...

    def loss_fn(self, batch: dict, outputs: dict) -> dict:
        """Compute loss."""
        ...

This enables:

  • Type checking at development time
  • Generic training loops that work with any model
  • Consistent interfaces across model types
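
For example, a generic training step only needs the protocol methods, so the same loop can drive any model type. A minimal sketch with Flax NNX and Optax, assuming the loss dictionary exposes a total "loss" entry and that the optimizer is an nnx.Optimizer wrapping an Optax transform (batch layout and learning rate are illustrative):

import optax
from flax import nnx

def make_optimizer(model, learning_rate=1e-3):
    return nnx.Optimizer(model, optax.adam(learning_rate))

@nnx.jit
def train_step(model, optimizer, batch):
    def compute_loss(model):
        outputs = model(batch["x"])             # protocol: __call__
        losses = model.loss_fn(batch, outputs)  # protocol: loss_fn
        return losses["loss"]                   # assumed total-loss entry

    loss, grads = nnx.value_and_grad(compute_loss)(model)
    optimizer.update(grads)                     # in-place parameter update
    return loss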

JAX and Flax NNX Basics¤

Workshop is built on JAX and Flax NNX. Here are the key concepts:

JAX: Functional Programming¤

JAX provides functional transformations:

import jax
import jax.numpy as jnp

# JIT compilation for speed
@jax.jit
def fast_function(x):
    return jnp.sum(x ** 2)

# Automatic differentiation
def loss_fn(params, x):
    return jnp.sum((params['w'] * x) ** 2)

params = {'w': jnp.array(2.0)}
x = jnp.arange(3.0)

grad_fn = jax.grad(loss_fn)     # Get gradient function
gradients = grad_fn(params, x)  # Compute gradients w.r.t. params

# Vectorization
batch_fn = jax.vmap(fast_function)  # Apply fast_function across a batch dimension

Flax NNX: Object-Oriented Neural Networks¤

Flax NNX provides a Pythonic API for neural networks:

from flax import nnx

class MyModel(nnx.Module):
    def __init__(self, features: int, *, rngs: nnx.Rngs):
        super().__init__()
        self.dense1 = nnx.Linear(784, features, rngs=rngs)
        self.dense2 = nnx.Linear(features, 10, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.dense1(x)
        x = nnx.relu(x)
        x = self.dense2(x)
        return x

# Create model
rngs = nnx.Rngs(0)
model = MyModel(features=128, rngs=rngs)

# Use model
x = jax.random.normal(jax.random.PRNGKey(0), (32, 784))
y = model(x)

Random Number Generation¤

JAX requires explicit RNG management:

import jax
from flax import nnx

# Create RNG streams from a seed
rngs = nnx.Rngs(42)

# Use for model initialization
model = MyModel(features=128, rngs=rngs)

# Use for sampling (draws a fresh key from the RNG stream)
samples = jax.random.normal(rngs.sample(), (10, 784))

Multi-Modal Support¤

Workshop natively supports multiple data modalities:

Image¤

from workshop.data.datasets.image import CIFAR10Dataset

dataset = CIFAR10Dataset(root='./data', train=True)

Text¤

from workshop.data.datasets.text import WikipediaDataset

dataset = WikipediaDataset(
    tokenizer='bpe',
    max_length=512
)

Audio¤

from workshop.data.datasets.audio import LibriSpeechDataset

dataset = LibriSpeechDataset(
    root='./data',
    sample_rate=16000
)

Protein¤

from workshop.data.protein_dataset import ProteinDataset

dataset = ProteinDataset(
    pdb_dir='./data/pdb',
    with_constraints=True
)

Next Steps¤

Now that you understand the core concepts:

Further Reading¤

Generative Models¤

JAX and Flax¤


Last Updated: 2025-10-13