# Workshop Framework Features Demonstration

**Level:** Intermediate | **Runtime:** ~1-2 minutes (CPU) | **Format:** Python + Jupyter

**Prerequisites:** Basic understanding of generative models and JAX | **Target Audience:** Users learning the framework's architecture
## Overview
This example provides a comprehensive tour of the Workshop framework's core features and design patterns. Learn how to leverage the unified configuration system, factory pattern, composable losses, sampling methods, and modality adapters for building production-ready generative models.
## What You'll Learn

- **Unified Configuration** - Type-safe model, training, and data configurations with Pydantic validation
- **Factory Pattern** - Consistent model creation interface across all model types (VAE, GAN, diffusion)
- **Composable Losses** - Flexible loss composition with weighted components and tracking
- **Sampling Methods** - MCMC and SDE sampling for energy-based and diffusion models
- **Modality System** - Domain-specific adapters for images, text, audio, proteins, and 3D data
## Files

This example is available in two formats:

- Python Script: `framework_features_demo.py`
- Jupyter Notebook: `framework_features_demo.ipynb`
## Quick Start

### Run the Python Script

```bash
# Activate environment
source activate.sh

# Run the example
python examples/generative_models/framework_features_demo.py
```

### Run the Jupyter Notebook

```bash
# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/framework_features_demo.ipynb
```
## Key Concepts

### 1. Unified Configuration System

Workshop uses Pydantic-based configuration classes for type-safe, validated model definitions:

```python
from workshop.generative_models.core.configuration import (
    ModelConfiguration,
    TrainingConfiguration,
    OptimizerConfiguration,
    DataConfiguration,
)

# Create a model configuration
config = ModelConfiguration(
    name="my_vae",
    model_class="workshop.generative_models.models.vae.VAE",
    input_dim=(28, 28, 1),
    hidden_dims=[256, 128],
    output_dim=32,
    parameters={
        "latent_dim": 32,
        "beta": 1.0,
    },
)
```
**Benefits:**
- Automatic validation of types and ranges
- Serialization to JSON/YAML for reproducibility
- Self-documenting through type hints
- Easy parameter sweeps for hyperparameter tuning
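Under the hood, this style of validation amounts to checking types and ranges at construction time and supporting round-trip serialization. A minimal stdlib-only sketch of the idea (illustrative only, not the framework's actual `ModelConfiguration`):

```python
from dataclasses import dataclass, asdict
import json

@dataclass
class MiniModelConfig:
    """Illustrative stand-in for a validated config class."""
    name: str
    latent_dim: int
    beta: float = 1.0

    def __post_init__(self):
        # Validate types and ranges up front, as Pydantic does
        if not self.name:
            raise ValueError("name must be non-empty")
        if self.latent_dim <= 0:
            raise ValueError("latent_dim must be positive")
        if self.beta < 0:
            raise ValueError("beta must be non-negative")

    def to_json(self) -> str:
        # Serialization for reproducibility
        return json.dumps(asdict(self))

config = MiniModelConfig(name="my_vae", latent_dim=32)
print(config.to_json())
```

Invalid values fail fast at construction, so a typo in a sweep script surfaces immediately instead of mid-training.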
### 2. Factory Pattern

The factory pattern provides unified model creation:

```python
from workshop.generative_models.factory import create_model
from flax import nnx

# Setup RNGs
rngs = nnx.Rngs(params=42, dropout=42)

# Create any model from configuration
model = create_model(config, rngs=rngs)

# Test forward pass
outputs = model(test_data, rngs=rngs)
```
**Why Use Factories?**
- Consistency across all model types
- Validation before instantiation
- Easy to swap models for experimentation
- Proper RNG management
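At its core, a factory like this resolves a dotted class path (such as the `model_class` string in the configuration) and instantiates it with validated parameters. A rough stdlib sketch of the pattern, using a standard-library class so it runs anywhere (`create_from_path` is an illustrative name, not the framework API):

```python
import importlib

def create_from_path(class_path: str, **kwargs):
    """Resolve 'pkg.module.ClassName' and instantiate it."""
    module_path, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    return cls(**kwargs)

# Example: build a stdlib class from its dotted path
counter = create_from_path("collections.Counter", a=2, b=1)
```

Because models are addressed by string, swapping a VAE for a GAN in an experiment is a one-line configuration change rather than a code change.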
### 3. Composable Loss System

Combine multiple loss functions with different weights:

```python
from workshop.generative_models.core.losses import (
    CompositeLoss,
    WeightedLoss,
    mse_loss,
    mae_loss,
)

# Create composite loss
composite = CompositeLoss([
    WeightedLoss(mse_loss, weight=1.0, name="reconstruction"),
    WeightedLoss(mae_loss, weight=0.5, name="l1_penalty"),
], return_components=True)

# Compute loss with component tracking
total_loss, components = composite(predictions, targets)
# components = {"reconstruction": 0.15, "l1_penalty": 0.08}
```
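Mechanically, a composite loss is just a weighted sum that also reports the per-component values. A minimal plain-Python sketch of that idea (not the framework's `CompositeLoss` API):

```python
def composite_loss(losses, predictions, targets):
    """losses: list of (name, weight, fn) tuples.
    Returns (total, {name: unweighted component value})."""
    components = {}
    total = 0.0
    for name, weight, fn in losses:
        value = fn(predictions, targets)
        components[name] = value          # track unweighted value
        total += weight * value           # accumulate weighted sum
    return total, components

# Toy element-wise losses over plain lists
mse = lambda p, t: sum((pi - ti) ** 2 for pi, ti in zip(p, t)) / len(p)
mae = lambda p, t: sum(abs(pi - ti) for pi, ti in zip(p, t)) / len(p)

total, parts = composite_loss(
    [("reconstruction", 1.0, mse), ("l1_penalty", 0.5, mae)],
    predictions=[0.5, 1.0], targets=[0.0, 1.0],
)
# total = 1.0 * 0.125 + 0.5 * 0.25 = 0.25
```

Tracking unweighted components makes it easy to see which term dominates training without re-deriving it from the weighted total.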
### 4. Sampling Methods

#### MCMC Sampling (Energy-Based Models)

```python
import jax
import jax.numpy as jnp

from workshop.generative_models.core.sampling import mcmc_sampling

def log_prob_fn(x):
    return -0.5 * jnp.sum(x**2)  # Log probability (standard normal, up to a constant)

samples = mcmc_sampling(
    log_prob_fn=log_prob_fn,
    init_state=jnp.zeros(10),
    key=jax.random.key(42),
    n_samples=1000,
    n_burnin=200,
    step_size=0.1,
)
```
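The call above treats the sampler as a black box. The underlying idea can be shown with a tiny random-walk Metropolis sampler in plain Python (a 1-D, stdlib-only illustration of MCMC in general, not the framework's actual algorithm):

```python
import math
import random

def metropolis(log_prob, x0, n_samples, n_burnin, step_size, seed=0):
    """Random-walk Metropolis: propose x + Gaussian noise, accept
    with probability min(1, p(x') / p(x))."""
    rng = random.Random(seed)
    x = x0
    samples = []
    for i in range(n_burnin + n_samples):
        proposal = x + rng.gauss(0.0, step_size)
        accept_prob = math.exp(min(0.0, log_prob(proposal) - log_prob(x)))
        if rng.random() < accept_prob:
            x = proposal  # accept the move
        if i >= n_burnin:  # discard burn-in samples
            samples.append(x)
    return samples

# Target: standard normal, log p(x) = -x^2 / 2 (up to a constant)
samples = metropolis(lambda x: -0.5 * x * x, x0=0.0,
                     n_samples=5000, n_burnin=500, step_size=1.0)
mean = sum(samples) / len(samples)
```

For an energy-based model the log probability is the negative energy, so the same loop samples from the learned distribution.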
#### SDE Sampling (Diffusion Models)

```python
from workshop.generative_models.core.sampling import sde_sampling

def drift_fn(x, t):
    return -x  # Drift function

def diffusion_fn(x, t):
    return jnp.ones_like(x) * 0.1  # Diffusion coefficient

sample = sde_sampling(
    drift_fn=drift_fn,
    diffusion_fn=diffusion_fn,
    init_state=x0,       # initial state, e.g. noise drawn from the prior
    t_span=(0.0, 1.0),
    key=key,             # a JAX PRNG key
    n_steps=100,
)
```
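The standard discretization behind this kind of sampler is the Euler-Maruyama scheme: step the drift term forward in time and add Gaussian noise scaled by `sqrt(dt)`. A stdlib-only sketch (illustrative scalar version, not the framework's `sde_sampling`):

```python
import math
import random

def euler_maruyama(drift, diffusion, x0, t0, t1, n_steps, seed=0):
    """Integrate dx = drift(x, t) dt + diffusion(x, t) dW
    with the Euler-Maruyama scheme."""
    rng = random.Random(seed)
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment
        x = x + drift(x, t) * dt + diffusion(x, t) * dw
        t += dt
    return x

# Ornstein-Uhlenbeck-style example: drift pulls the state toward 0
x_final = euler_maruyama(lambda x, t: -x, lambda x, t: 0.1,
                         x0=1.0, t0=0.0, t1=1.0, n_steps=100)
```

With drift `-x`, the deterministic part decays toward `exp(-1) * x0` over the unit interval, with small fluctuations from the 0.1 diffusion term.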
### 5. Modality System

Domain-specific features for different data types:

```python
from workshop.generative_models.modalities import get_modality

# Get image modality
image_modality = get_modality('image', rngs=rngs)

# Create dataset
dataset = image_modality.create_dataset(data_config)

# Compute metrics
metrics = image_modality.evaluate(model, test_data)
# metrics = {"fid": 12.5, "is_score": 8.3, ...}

# Get modality adapter
adapter = image_modality.get_adapter('vae')
adapted_model = adapter.adapt(base_model, config)
```
**Available Modalities:**

- `image`: Convolutional layers, FID and IS metrics
- `text`: Tokenization, perplexity, BLEU
- `audio`: Spectrograms, MFCCs
- `protein`: Structure prediction, sequence modeling
- `geometric`: Point clouds, mesh processing
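A string-keyed lookup like `get_modality` is typically backed by a simple registry mapping names to implementations. A hypothetical sketch of the pattern (all names here are illustrative, not the framework's classes):

```python
# Hypothetical stand-ins for real modality implementations
class ImageModality:
    metrics = ("fid", "is_score")

class AudioModality:
    metrics = ("spectral_distance",)

# The registry maps modality names to implementations
MODALITY_REGISTRY = {
    "image": ImageModality,
    "audio": AudioModality,
}

def get_modality_sketch(name):
    """Look up a modality by name and instantiate it."""
    try:
        return MODALITY_REGISTRY[name]()
    except KeyError:
        raise ValueError(
            f"Unknown modality {name!r}; available: {sorted(MODALITY_REGISTRY)}"
        ) from None

modality = get_modality_sketch("image")
```

Failing with the list of available names makes typos in modality strings easy to diagnose.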
## Code Structure
The example demonstrates framework features in five sections:
1. **Configuration System** - Create type-safe configs for models, training, data
2. **Factory Pattern** - Instantiate models from configurations
3. **Composable Losses** - Combine weighted loss functions
4. **Sampling Methods** - MCMC and SDE sampling for generation
5. **Modality System** - Domain-specific adapters and evaluation
Each section is self-contained and can be run independently.
## Features Demonstrated
- ✅ Type-safe configuration with automatic validation
- ✅ Unified model creation across all types (VAE, GAN, diffusion, flow, EBM)
- ✅ Flexible loss composition with component tracking
- ✅ MCMC sampling for energy-based models
- ✅ SDE sampling for diffusion models
- ✅ Modality-specific dataset loading and evaluation
- ✅ Proper RNG management with `nnx.Rngs`
- ✅ JIT compilation for performance
## Experiments to Try

1. **Create Different Model Types**

   ```python
   # Try creating different models
   gan_config = ModelConfiguration(
       name="my_gan",
       model_class="workshop.generative_models.models.gan.GAN",
       # ...
   )
   gan = create_model(gan_config, rngs=rngs)
   ```

2. **Custom Loss Combinations**

   ```python
   # Add perceptual loss to composite
   from workshop.generative_models.core.losses import PerceptualLoss

   composite = CompositeLoss([
       WeightedLoss(mse_loss, weight=1.0, name="recon"),
       WeightedLoss(PerceptualLoss(), weight=0.1, name="perceptual"),
   ])
   ```

3. **Adjust Sampling Parameters**

   ```python
   # Try different MCMC settings
   samples = mcmc_sampling(
       log_prob_fn=log_prob_fn,
       init_state=x0,
       key=key,
       n_samples=5000,   # More samples
       step_size=0.01,   # Smaller steps
   )
   ```

4. **Experiment with Modalities**

   ```python
   # Try different modalities
   audio_modality = get_modality('audio', rngs=rngs)
   audio_dataset = audio_modality.create_dataset(audio_config)
   ```
## Next Steps

- **VAE Examples** - Learn VAE implementation patterns
- **GAN Examples** - Explore GAN training
- **Diffusion Examples** - Understand diffusion models
- **Loss Examples** - Deep dive into loss functions
## Troubleshooting

### Missing Configuration Fields

**Symptom:** `ValidationError` when creating a configuration

**Solution:** Check the required fields in the configuration class:

```python
# View required fields
from workshop.generative_models.core.configuration import ModelConfiguration

print(ModelConfiguration.model_fields)
```

### Factory Creation Fails

**Symptom:** `TypeError` or `AttributeError` during model creation

**Solution:** Verify the model class path and required parameters:

```python
# Ensure the model class path is valid
config = ModelConfiguration(
    model_class="workshop.generative_models.models.vae.VAE",  # Full path
    parameters={"latent_dim": 32},  # Required VAE parameters
)
```

### RNG Key Errors

**Symptom:** `KeyError` for missing RNG streams

**Solution:** Initialize all required RNG streams:

```python
# A VAE needs params, dropout, and sample streams
rngs = nnx.Rngs(
    params=42,
    dropout=43,
    sample=44,
)
```
## Additional Resources

### Documentation

- Configuration System Guide - Deep dive into configurations
- Factory Pattern Guide - Advanced factory usage
- Loss Functions API - Complete loss function reference
- Sampling Methods API - Sampling algorithm details

### Related Examples

- Loss Examples - Complete loss function showcase
- VAE MNIST Tutorial - Step-by-step VAE implementation
- Simple Diffusion - Diffusion model basics

### References

- Pydantic Documentation - Configuration validation
- JAX Documentation - Array programming and JIT compilation
- Flax NNX Guide - Neural network library