Workshop Framework Features Demonstration¤

Level: Intermediate | Runtime: ~1-2 minutes (CPU) | Format: Python + Jupyter

Prerequisites: Basic understanding of generative models and JAX | Target Audience: Users learning the framework's architecture

Overview¤

This example provides a comprehensive tour of the Workshop framework's core features and design patterns. Learn how to leverage the unified configuration system, factory pattern, composable losses, sampling methods, and modality adapters for building production-ready generative models.

What You'll Learn¤

  • Unified Configuration: Type-safe model, training, and data configurations with Pydantic validation
  • Factory Pattern: Consistent model creation interface across all model types (VAE, GAN, diffusion)
  • Composable Losses: Flexible loss composition with weighted components and tracking
  • Sampling Methods: MCMC and SDE sampling for energy-based and diffusion models
  • Modality System: Domain-specific adapters for images, text, audio, proteins, and 3D data

Files¤

This example is available in two formats: a standalone Python script (framework_features_demo.py) and a Jupyter notebook (framework_features_demo.ipynb), both under examples/generative_models/.

Quick Start¤

Run the Python Script¤

# Activate environment
source activate.sh

# Run the example
python examples/generative_models/framework_features_demo.py

Run the Jupyter Notebook¤

# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/framework_features_demo.ipynb

Key Concepts¤

1. Unified Configuration System¤

Workshop uses Pydantic-based configuration classes for type-safe, validated model definitions:

from workshop.generative_models.core.configuration import (
    ModelConfiguration,
    TrainingConfiguration,
    OptimizerConfiguration,
    DataConfiguration,
)

# Create a model configuration
config = ModelConfiguration(
    name="my_vae",
    model_class="workshop.generative_models.models.vae.VAE",
    input_dim=(28, 28, 1),
    hidden_dims=[256, 128],
    output_dim=32,
    parameters={
        "latent_dim": 32,
        "beta": 1.0,
    }
)

Benefits:

  • Automatic validation of types and ranges
  • Serialization to JSON/YAML for reproducibility
  • Self-documenting through type hints
  • Easy parameter sweeps for hyperparameter tuning
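
The serialization point above can be sketched with standard Pydantic round-trip methods (assuming ModelConfiguration behaves like a plain Pydantic v2 model; the framework may also ship its own export helpers):

# Sketch: round-trip a configuration through JSON for reproducibility.
json_blob = config.model_dump_json(indent=2)          # serialize to JSON text

with open("my_vae_config.json", "w") as f:            # store next to experiment outputs
    f.write(json_blob)

restored = ModelConfiguration.model_validate_json(json_blob)  # re-validated on load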

2. Factory Pattern¤

The factory pattern provides unified model creation:

from workshop.generative_models.factory import create_model
from flax import nnx

# Setup RNGs
rngs = nnx.Rngs(params=42, dropout=42)

# Create any model from configuration
model = create_model(config, rngs=rngs)

# Test forward pass (test_data: a batch shaped to match config.input_dim)
outputs = model(test_data, rngs=rngs)

Why Use Factories?

  • Consistency across all model types
  • Validation before instantiation
  • Easy to swap models for experimentation
  • Proper RNG management
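
Because every model goes through the same create_model entry point, swapping architectures is mostly a configuration edit. A minimal sketch, reusing the config from section 1; the GAN entry is illustrative and each model class defines its own required parameters:

# Sketch: compare model families by editing only the configuration.
candidate_configs = {
    "vae": config,  # the VAE configuration from section 1
    "gan": config.model_copy(update={
        "name": "my_gan",
        "model_class": "workshop.generative_models.models.gan.GAN",
    }),
}

models = {name: create_model(cfg, rngs=rngs) for name, cfg in candidate_configs.items()}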

3. Composable Loss System¤

Combine multiple loss functions with different weights:

\[L_{\text{total}} = \sum_{i=1}^{n} w_i \cdot L_i(\text{pred}, \text{target})\]

from workshop.generative_models.core.losses import (
    CompositeLoss,
    WeightedLoss,
    mse_loss,
    mae_loss,
)

# Create composite loss
composite = CompositeLoss([
    WeightedLoss(mse_loss, weight=1.0, name="reconstruction"),
    WeightedLoss(mae_loss, weight=0.5, name="l1_penalty"),
], return_components=True)

# Compute loss with component tracking
total_loss, components = composite(predictions, targets)
# components = {"reconstruction": 0.15, "l1_penalty": 0.08}
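
The formula above expands to the following plain-JAX computation (a sketch of the math only, reusing the predictions/targets placeholders; it is not the library's CompositeLoss implementation):

# Plain-JAX sketch of L_total = sum_i w_i * L_i(pred, target).
import jax.numpy as jnp

def mse(pred, target):
    return jnp.mean((pred - target) ** 2)

def mae(pred, target):
    return jnp.mean(jnp.abs(pred - target))

weighted_terms = {"reconstruction": (1.0, mse), "l1_penalty": (0.5, mae)}
components = {name: w * fn(predictions, targets) for name, (w, fn) in weighted_terms.items()}
total_loss = sum(components.values())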

4. Sampling Methods¤

MCMC Sampling (Energy-Based Models)¤

import jax
import jax.numpy as jnp

from workshop.generative_models.core.sampling import mcmc_sampling

def log_prob_fn(x):
    return -0.5 * jnp.sum(x**2)  # Log probability (standard Gaussian, up to a constant)

samples = mcmc_sampling(
    log_prob_fn=log_prob_fn,
    init_state=jnp.zeros(10),
    key=jax.random.key(42),
    n_samples=1000,
    n_burnin=200,
    step_size=0.1,
)
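
For intuition, the same log_prob_fn can drive a hand-rolled random-walk Metropolis chain in plain JAX. This is only a sketch of the general idea; mcmc_sampling's actual algorithm and options may differ:

# Sketch: random-walk Metropolis driven by the log_prob_fn defined above
# (illustrative only; not the library's mcmc_sampling implementation).
def metropolis_step(carry, key):
    x, logp = carry
    key_prop, key_accept = jax.random.split(key)
    proposal = x + 0.1 * jax.random.normal(key_prop, x.shape)   # step_size = 0.1
    logp_prop = log_prob_fn(proposal)
    accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
    x_new = jnp.where(accept, proposal, x)
    return (x_new, jnp.where(accept, logp_prop, logp)), x_new

init = jnp.zeros(10)
keys = jax.random.split(jax.random.key(42), 1200)               # 200 burn-in + 1000 kept
(_, _), chain = jax.lax.scan(metropolis_step, (init, log_prob_fn(init)), keys)
hand_rolled_samples = chain[200:]                                # discard burn-in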

SDE Sampling (Diffusion Models)¤

import jax
import jax.numpy as jnp

from workshop.generative_models.core.sampling import sde_sampling

def drift_fn(x, t):
    return -x  # Drift function

def diffusion_fn(x, t):
    return jnp.ones_like(x) * 0.1  # Diffusion coefficient

x0 = jnp.zeros(8)        # initial state (example shape)
key = jax.random.key(0)  # PRNG key

sample = sde_sampling(
    drift_fn=drift_fn,
    diffusion_fn=diffusion_fn,
    init_state=x0,
    t_span=(0.0, 1.0),
    key=key,
    n_steps=100,
)
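
Conceptually, the sampler integrates dx = drift(x, t) dt + diffusion(x, t) dW. A plain-JAX Euler–Maruyama sketch under that assumption, reusing the functions and placeholders above (sde_sampling itself may use a different discretization):

# Sketch: Euler–Maruyama discretization
#   x_{k+1} = x_k + drift(x_k, t_k) dt + diffusion(x_k, t_k) sqrt(dt) * N(0, I)
# Illustrative only; not the library's sde_sampling scheme.
def euler_maruyama(x_init, t_span, key, n_steps):
    t0, t1 = t_span
    dt = (t1 - t0) / n_steps

    def step(carry, step_key):
        x, t = carry
        noise = jax.random.normal(step_key, x.shape)
        x_next = x + drift_fn(x, t) * dt + diffusion_fn(x, t) * jnp.sqrt(dt) * noise
        return (x_next, t + dt), None

    (x_final, _), _ = jax.lax.scan(step, (x_init, t0), jax.random.split(key, n_steps))
    return x_final

hand_rolled_sample = euler_maruyama(x0, (0.0, 1.0), key, n_steps=100)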

5. Modality System¤

Domain-specific features for different data types:

from workshop.generative_models.modalities import get_modality

# Get image modality
image_modality = get_modality('image', rngs=rngs)

# Create dataset
dataset = image_modality.create_dataset(data_config)

# Compute metrics
metrics = image_modality.evaluate(model, test_data)
# metrics = {"fid": 12.5, "is_score": 8.3, ...}

# Get modality adapter
adapter = image_modality.get_adapter('vae')
adapted_model = adapter.adapt(base_model, config)

Available Modalities:

  • image: Convolutional layers, FID, IS metrics
  • text: Tokenization, perplexity, BLEU
  • audio: Spectrograms, MFCCs
  • protein: Structure prediction, sequence modeling
  • geometric: Point clouds, mesh processing

Code Structure¤

The example demonstrates framework features in five sections:

  1. Configuration System - Create type-safe configs for models, training, data
  2. Factory Pattern - Instantiate models from configurations
  3. Composable Losses - Combine weighted loss functions
  4. Sampling Methods - MCMC and SDE sampling for generation
  5. Modality System - Domain-specific adapters and evaluation

Each section is self-contained and can be run independently.

Features Demonstrated¤

  • ✅ Type-safe configuration with automatic validation
  • ✅ Unified model creation across all types (VAE, GAN, diffusion, flow, EBM)
  • ✅ Flexible loss composition with component tracking
  • ✅ MCMC sampling for energy-based models
  • ✅ SDE sampling for diffusion models
  • ✅ Modality-specific dataset loading and evaluation
  • ✅ Proper RNG management with nnx.Rngs
  • ✅ JIT compilation for performance
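
The last two items combine naturally: an nnx model and its RNG streams can be passed straight into a jitted function. A minimal sketch, assuming the model and rngs from section 2 and the call signature shown there:

# Sketch: JIT-compile the forward pass while passing the model and RNG streams
# explicitly (assumes the `model` and `rngs` created in section 2).
from flax import nnx
import jax.numpy as jnp

@nnx.jit
def forward(model, batch, rngs):
    return model(batch, rngs=rngs)

batch = jnp.zeros((8, 28, 28, 1))        # dummy batch matching input_dim
outputs = forward(model, batch, rngs)    # compiled on first call, cached afterwards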

Experiments to Try¤

  1. Create Different Model Types

# Try creating different models
gan_config = ModelConfiguration(
    name="my_gan",
    model_class="workshop.generative_models.models.gan.GAN",
    # ...
)
gan = create_model(gan_config, rngs=rngs)

  2. Custom Loss Combinations

# Add perceptual loss to composite
from workshop.generative_models.core.losses import PerceptualLoss

composite = CompositeLoss([
    WeightedLoss(mse_loss, weight=1.0, name="recon"),
    WeightedLoss(PerceptualLoss(), weight=0.1, name="perceptual"),
])

  3. Adjust Sampling Parameters

# Try different MCMC settings
samples = mcmc_sampling(
    log_prob_fn=log_prob_fn,
    init_state=x0,
    key=key,
    n_samples=5000,  # More samples
    step_size=0.01,  # Smaller steps
)

  4. Experiment with Modalities

# Try different modalities
audio_modality = get_modality('audio', rngs=rngs)
audio_dataset = audio_modality.create_dataset(audio_config)

Next Steps¤

Troubleshooting¤

Missing Configuration Fields¤

Symptom: ValidationError when creating configuration

Solution: Check required fields in the configuration class

# View required fields
from workshop.generative_models.core.configuration import ModelConfiguration
print(ModelConfiguration.model_fields)
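
Building on the snippet above, you can also trigger validation deliberately to see exactly which fields are missing. A sketch, assuming ModelConfiguration requires more than just a name:

# Sketch: let Pydantic report the missing/invalid fields directly.
from pydantic import ValidationError

try:
    ModelConfiguration(name="broken")    # intentionally incomplete
except ValidationError as err:
    print(err)                           # names each offending field and the reason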

Factory Creation Fails¤

Symptom: TypeError or AttributeError during model creation

Solution: Verify model class path and required parameters

# Ensure model class is valid
config = ModelConfiguration(
    model_class="workshop.generative_models.models.vae.VAE",  # Full path
    parameters={"latent_dim": 32},  # Required VAE parameters
)

RNG Key Errors¤

Symptom: KeyError for missing RNG streams

Solution: Initialize all required RNG streams

# VAE needs params, dropout, and sample streams
rngs = nnx.Rngs(
    params=42,
    dropout=43,
    sample=44,
)
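
Each named stream then hands out fresh keys on demand, which is how the model draws randomness during sampling. A quick sketch of standard nnx.Rngs usage:

# Sketch: drawing keys from named streams.
k1 = rngs.sample()   # a key drawn from the "sample" stream
k2 = rngs.sample()   # a different key; the stream's counter advances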

Additional Resources¤

Documentation¤

Papers¤