Quickstart Guide
Get started with Workshop in 5 minutes! This guide walks you through installing Workshop and training your first generative model.
Prerequisites
- Python 3.10 or higher
- 8GB RAM (16GB recommended)
- Optional: NVIDIA GPU with CUDA 12.0+ for faster training
Step 1: Install Workshop
Choose your preferred installation method:
Verify installation:
```bash
python -c "import workshop, jax; print(f'JAX backend: {jax.default_backend()}')"
# Should print: JAX backend: gpu (or cpu)
```
Step 2: Train Your First VAE
Create a new Python file train_vae.py:
```python
import jax
import optax
from flax import nnx
from workshop.generative_models.factories import create_vae
from workshop.generative_models.core.configuration import ModelConfiguration

# 1. Configure the model
config = ModelConfiguration(
    model_type="vae",
    latent_dim=32,
    input_shape=(28, 28, 1),  # MNIST-like images
    encoder_features=[64, 128],
    decoder_features=[128, 64],
    parameters={
        "beta": 1.0,
        "kl_weight": 1.0,
        "reconstruction_loss": "mse",
    },
)

# 2. Create model and optimizer
rngs = nnx.Rngs(0)
model = create_vae(config, rngs=rngs)
# wrt=nnx.Param: the optimizer updates only trainable parameters
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

num_params = sum(p.size for p in jax.tree.leaves(nnx.state(model, nnx.Param)))
print("✓ Model created successfully!")
print(f"  Latent dimension: {model.latent_dim}")
print(f"  Parameters: ~{num_params / 1e6:.2f}M")

# 3. Create synthetic training data (for demo only)
def generate_batch(key, batch_size=32):
    """Generate random synthetic data for a quick demo."""
    return jax.random.normal(key, (batch_size, 28, 28, 1))

# 4. Training step
@nnx.jit
def train_step(model, optimizer, batch, rng):
    """Single training step: forward pass, loss, gradients, parameter update."""
    def loss_fn(model):
        outputs = model(batch, rngs=nnx.Rngs(dropout=rng))
        loss_dict = model.loss_fn(x=batch, outputs=outputs)
        return loss_dict["total_loss"], loss_dict

    # With has_aux=True, value_and_grad returns ((loss, aux), grads)
    (loss, loss_dict), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    # Flax >= 0.11 signature; on older Flax versions call optimizer.update(grads)
    optimizer.update(model, grads)
    return loss, loss_dict

# 5. Train for a few steps
print("\nTraining for 100 steps...")
key = jax.random.PRNGKey(42)
for step in range(100):
    # Split off fresh keys for the batch and for stochastic layers
    key, batch_key, train_key = jax.random.split(key, 3)
    batch = generate_batch(batch_key)

    loss, loss_dict = train_step(model, optimizer, batch, train_key)

    if step % 20 == 0:
        print(f"Step {step:3d} | Loss: {loss:.4f} | "
              f"Recon: {loss_dict['reconstruction_loss']:.4f} | "
              f"KL: {loss_dict['kl_loss']:.4f}")

print("\n✓ Training complete!")

# 6. Generate samples
print("\nGenerating samples...")
samples = model.sample(n_samples=8, rngs=rngs)
print(f"✓ Generated {samples.shape[0]} samples with shape {samples.shape[1:]}")

# 7. Reconstruct an image
test_image = generate_batch(jax.random.PRNGKey(99), batch_size=1)
reconstructed = model.reconstruct(test_image, deterministic=True, rngs=rngs)
print(f"✓ Reconstructed image with shape {reconstructed.shape}")

print("\n🎉 Success! You've trained your first VAE with Workshop!")
```
Run the script:
```bash
python train_vae.py
```
Expected output (your exact numbers may differ):
```text
✓ Model created successfully!
  Latent dimension: 32
  Parameters: ~0.15M

Training for 100 steps...
Step   0 | Loss: 0.5234 | Recon: 0.4891 | KL: 0.0343
Step  20 | Loss: 0.3156 | Recon: 0.2945 | KL: 0.0211
Step  40 | Loss: 0.2498 | Recon: 0.2341 | KL: 0.0157
Step  60 | Loss: 0.2134 | Recon: 0.2002 | KL: 0.0132
Step  80 | Loss: 0.1923 | Recon: 0.1812 | KL: 0.0111

✓ Training complete!

Generating samples...
✓ Generated 8 samples with shape (28, 28, 1)
✓ Reconstructed image with shape (1, 28, 28, 1)

🎉 Success! You've trained your first VAE with Workshop!
```
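The Recon and KL columns correspond to the two terms of the VAE objective. As a rough sketch of what they measure, here is the textbook formulation in plain JAX: the reparameterization trick, a squared-error reconstruction term, and the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior. The function names are hypothetical, and how Workshop reduces each term (sum vs. mean) or combines beta with kl_weight is an assumption, not something this guide specifies.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, logvar):
    """Sample z = mean + sigma * eps with eps ~ N(0, I), keeping z differentiable."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps


def vae_loss(x, x_recon, mean, logvar, beta=1.0):
    """Hypothetical beta-VAE loss: squared error + beta * KL(q(z|x) || N(0, I))."""
    # Reconstruction: squared error summed over pixels, averaged over the batch
    recon = jnp.mean(jnp.sum((x - x_recon) ** 2, axis=(1, 2, 3)))
    # Closed-form KL between a diagonal Gaussian and a standard normal
    kl = jnp.mean(-0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar), axis=-1))
    return {"total_loss": recon + beta * kl, "reconstruction_loss": recon, "kl_loss": kl}


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 28, 28, 1))
mean = jnp.zeros((4, 32))
logvar = jnp.zeros((4, 32))
z = reparameterize(key, mean, logvar)
losses = vae_loss(x, x, mean, logvar)  # perfect reconstruction, posterior == prior
print({name: float(value) for name, value in losses.items()})
```

With mean = 0 and logvar = 0 the posterior equals the prior, so the KL term is exactly zero; moving the means away from zero makes it grow.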
Step 3: Visualize Results (Optional)
Add visualization to the end of your script:
```python
import matplotlib.pyplot as plt

# Visualize generated samples (uses `samples` from Step 2)
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Samples from VAE')
plt.tight_layout()
plt.savefig('vae_samples.png')
print("✓ Saved samples to vae_samples.png")

# Visualize reconstruction (uses `test_image` and `reconstructed` from Step 2)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(test_image[0].squeeze(), cmap='gray')
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(reconstructed[0].squeeze(), cmap='gray')
axes[1].set_title('Reconstructed')
axes[1].axis('off')
plt.tight_layout()
plt.savefig('vae_reconstruction.png')
print("✓ Saved reconstruction to vae_reconstruction.png")
```
What You Just Did
In just a few minutes, you:
- ✅ Installed Workshop - Set up the complete environment
- ✅ Created a VAE - Built a variational autoencoder with configuration
- ✅ Trained the model - Ran 100 training steps with loss monitoring
- ✅ Generated samples - Created new images from the learned distribution
- ✅ Reconstructed images - Tested the encoder-decoder pipeline
Key Concepts
Configuration System
```python
config = ModelConfiguration(
    model_type="vae",            # Model type
    latent_dim=32,               # Latent space dimension
    input_shape=(28, 28, 1),     # Input image shape
    encoder_features=[64, 128],  # Encoder layer sizes
    decoder_features=[128, 64],  # Decoder layer sizes
    parameters={...},            # Model-specific parameters
)
```
Workshop uses a unified configuration system based on Pydantic for type-safe, validated configurations.
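ModelConfiguration's actual fields and validators belong to Workshop. To show what Pydantic-based validation buys you, the snippet below uses a hypothetical stand-in class, ToyVAEConfig, so that bad values fail when the config is created rather than partway through training. It assumes Pydantic is installed and uses only Field constraints available in both Pydantic v1 and v2.

```python
from pydantic import BaseModel, Field, ValidationError


class ToyVAEConfig(BaseModel):
    """Hypothetical stand-in for ModelConfiguration, for illustration only."""
    latent_dim: int = Field(gt=0)        # must be a positive integer
    input_shape: tuple[int, int, int]    # (height, width, channels)
    encoder_features: list[int] = [64, 128]
    parameters: dict = {}


config = ToyVAEConfig(latent_dim=32, input_shape=(28, 28, 1))

# An invalid value is rejected at construction time
try:
    ToyVAEConfig(latent_dim=-4, input_shape=(28, 28, 1))
    rejected = False
except ValidationError as err:
    rejected = True
    print(err)
```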
Model Factories
Factory functions create models from configurations, handling all initialization details.
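The idea behind a factory is that calling code passes a configuration and gets back a fully built model, without knowing each model's constructor arguments. The sketch below is a hypothetical, library-free illustration of that pattern (a registry keyed on model_type); it is not Workshop's implementation, and ToyConfig, ToyVAE, register, and create_model are invented for the example.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class ToyConfig:
    model_type: str
    latent_dim: int = 32
    parameters: dict = field(default_factory=dict)


@dataclass
class ToyVAE:
    latent_dim: int
    beta: float


# Registry: model_type string -> builder function
_BUILDERS: dict[str, Callable[[ToyConfig], object]] = {}


def register(model_type: str):
    """Decorator that records a builder under a model_type name."""
    def decorator(builder):
        _BUILDERS[model_type] = builder
        return builder
    return decorator


@register("vae")
def _build_vae(config: ToyConfig) -> ToyVAE:
    # The factory fills in defaults the caller did not specify
    return ToyVAE(latent_dim=config.latent_dim, beta=config.parameters.get("beta", 1.0))


def create_model(config: ToyConfig):
    """Look up the builder for config.model_type and construct the model."""
    try:
        builder = _BUILDERS[config.model_type]
    except KeyError:
        raise ValueError(f"Unknown model_type: {config.model_type!r}") from None
    return builder(config)


model = create_model(ToyConfig(model_type="vae", parameters={"beta": 4.0}))
print(model)
```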
RNG Management
JAX requires explicit random number generators for reproducibility and functional purity.
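Concretely, JAX has no hidden global random state: every random function takes an explicit key, the same key always produces the same numbers, and you derive fresh keys with jax.random.split instead of reusing one. This is the split-then-use pattern the Step 2 training loop applies on every step; nnx.Rngs manages keys like these for you, handing out a fresh key each time a stream is called.

```python
import jax

key = jax.random.PRNGKey(42)

# Reusing a key reproduces exactly the same numbers
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting derives new keys: keep one to carry forward, consume the others
key, sub1, sub2 = jax.random.split(key, 3)
c = jax.random.normal(sub1, (3,))
d = jax.random.normal(sub2, (3,))

print(a, c, d)
```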
Training Loop
The training loop follows standard JAX/Flax patterns:
- Forward pass: model(batch) → outputs
- Compute loss: model.loss_fn() → loss values
- Backward pass: nnx.value_and_grad() → gradients
- Update weights: optimizer.update() → new parameters
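The same four steps can be seen in isolation in plain JAX, without Workshop or NNX. This sketch fits a tiny linear model to synthetic data and swaps in plain gradient descent (jax.tree_util.tree_map over the parameter pytree) where the VAE example uses an Optax Adam optimizer; the model and data are illustrative only.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def forward(params, x):
    """1. Forward pass of a tiny linear model."""
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    """2. Loss: mean squared error between predictions and targets."""
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # 3. Backward pass: loss value plus gradients for every leaf of params
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # 4. Update: plain gradient descent applied leaf by leaf
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jax.random.normal(key_w, (3,)), "b": jnp.zeros(())}
losses = []
for step in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.2e}")
```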
Try Different Models
Train a Diffusion Model
```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_diffusion

rngs = nnx.Rngs(0)

# Create DDPM configuration
config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(28, 28, 1),
    num_timesteps=1000,
    backbone_type="unet",
    backbone_features=[64, 128, 256],
    parameters={
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "beta_schedule": "linear",
    },
)

# Create the model and draw samples (untrained, so they will not resemble data yet)
model = create_diffusion(config, rngs=rngs)
samples = model.sample(n_samples=8, rngs=rngs)
```
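The beta_start, beta_end, and beta_schedule parameters describe a DDPM noise schedule. To show what those numbers mean, here is the textbook DDPM formulation in plain JAX: a linear schedule of per-step noise variances, their cumulative product, and the closed-form forward (noising) process. That Workshop builds its schedule exactly this way is an assumption; q_sample is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

num_timesteps = 1000
betas = jnp.linspace(1e-4, 2e-2, num_timesteps)  # "linear" beta schedule
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)  # fraction of signal variance left after step t


def q_sample(key, x0, t):
    """Sample x_t ~ q(x_t | x_0) directly: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    eps = jax.random.normal(key, x0.shape)
    a_bar = alpha_bars[t]
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * eps, eps


x0 = jnp.ones((28, 28, 1))
x_t, eps = q_sample(jax.random.PRNGKey(0), x0, t=500)
print(f"alpha_bar at t=500: {float(alpha_bars[500]):.4f}, at the last step: {float(alpha_bars[-1]):.2e}")
```

With these settings the cumulative product ends up around 4e-5, so almost no signal survives at the final step; that is what lets sampling start from pure Gaussian noise.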
Train a GAN
```python
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_gan

rngs = nnx.Rngs(0)

config = ModelConfiguration(
    model_type="dcgan",
    latent_dim=100,
    input_shape=(28, 28, 1),
    generator_features=[256, 128, 64],
    discriminator_features=[64, 128, 256],
    parameters={
        "generator_lr": 2e-4,
        "discriminator_lr": 2e-4,
        "label_smoothing": 0.1,
    },
)

model = create_gan(config, rngs=rngs)
```
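One common use of a label_smoothing value is one-sided label smoothing: the discriminator's target for real images becomes 1.0 - label_smoothing (0.9 here) while fake targets stay at 0, which discourages an overconfident discriminator. The sketch below assumes that variant; whether Workshop's DCGAN smooths one-sided, two-sided, or differently is not stated in this guide, and discriminator_loss is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def discriminator_loss(real_logits, fake_logits, label_smoothing=0.1):
    """Binary cross-entropy with one-sided label smoothing (real targets only)."""
    real_target = 1.0 - label_smoothing  # e.g. 0.9 instead of 1.0
    # BCE with logits: -[t * log(sigmoid(l)) + (1 - t) * log(1 - sigmoid(l))],
    # using log(1 - sigmoid(l)) == log_sigmoid(-l) for numerical stability
    real_loss = -(real_target * jax.nn.log_sigmoid(real_logits)
                  + (1.0 - real_target) * jax.nn.log_sigmoid(-real_logits))
    fake_loss = -jax.nn.log_sigmoid(-fake_logits)  # fake target stays 0
    return jnp.mean(real_loss) + jnp.mean(fake_loss)


# The loss on real images is now minimized at sigmoid(logit) = 0.9, not at 1.0
best_logit = jnp.log(0.9 / 0.1)
fake_logits = jnp.array([-5.0])
print(float(discriminator_loss(jnp.array([best_logit]), fake_logits)))
```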
Next Steps
Now that you have a working setup, explore more:
- Learn Core Concepts: Understand generative modeling fundamentals and the Workshop architecture
- Build Your First Model: Step-by-step tutorial to build a VAE from scratch with real data
- Explore Model Guides: Deep dive into VAEs, GANs, Diffusion, Flows, and more
- Check Examples: Ready-to-run examples for various models and use cases
Common Next Questions
How do I use real data?
See the Data Pipeline Guide for loading CIFAR-10, ImageNet, and custom datasets.
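Until you wire up the data pipeline, the quickest replacement for generate_batch is a small iterator over an in-memory array. The sketch below assumes your images are already loaded as a NumPy uint8 array of shape (N, 28, 28, 1), for example MNIST loaded with any library you like; iterate_batches is a hypothetical helper, not a Workshop API, and random bytes stand in for real images.

```python
import numpy as np


def iterate_batches(images, batch_size, seed=0):
    """Yield shuffled float32 batches scaled to [0, 1], reshuffling every epoch."""
    rng = np.random.default_rng(seed)
    while True:  # infinite stream; each pass over the data is one epoch
        order = rng.permutation(len(images))
        # Drop the last incomplete batch so every batch has the same shape,
        # which avoids recompiling a jit-compiled train_step
        for start in range(0, len(order) - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            yield images[idx].astype(np.float32) / 255.0


# Stand-in for a real dataset: 100 random 28x28 grayscale images
fake_uint8 = np.random.default_rng(1).integers(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
batches = iterate_batches(fake_uint8, batch_size=32)
batch = next(batches)
print(batch.shape, batch.dtype, float(batch.min()), float(batch.max()))
```

In the Step 2 loop you would replace batch = generate_batch(batch_key) with batch = next(batches); train_step itself stays the same.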
How do I save and load models?
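```python
import os
import orbax.checkpoint as ocp

ckpt_path = os.path.abspath("checkpoints/vae_params")  # use an absolute path with Orbax
checkpointer = ocp.StandardCheckpointer()
# Save the model's parameters (model and config come from train_vae.py)
checkpointer.save(ckpt_path, nnx.state(model, nnx.Param))
checkpointer.wait_until_finished()  # saving runs asynchronously
# Load: rebuild the model from the same config, then restore its parameters
model = create_vae(config, rngs=nnx.Rngs(0))
nnx.update(model, checkpointer.restore(ckpt_path, nnx.state(model, nnx.Param)))
```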
See Training Guide for details on checkpointing.
How do I train on multiple GPUs?
Workshop supports distributed training out of the box. See Distributed Training Guide.
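Workshop's own distributed utilities are covered in that guide. For orientation only, this is what the underlying data-parallel idea looks like in plain JAX (not Workshop's API): build a one-axis device mesh and shard the batch dimension across it, so each device holds a slice of the batch. It also runs on a single CPU, where the mesh simply contains one device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device, with its axis named "data"
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Split dimension 0 (the batch) across the "data" axis; other dims are not split
batch_sharding = NamedSharding(mesh, P("data"))

num_devices = len(jax.devices())
batch = jnp.ones((8 * num_devices, 28, 28, 1))  # batch size divisible by device count
batch = jax.device_put(batch, batch_sharding)

# Each device now holds 8 examples
print([shard.data.shape for shard in batch.addressable_shards])
```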
What if I get errors?
If you encounter issues, open an issue on GitHub.
Quick Reference
Model Types
| Type | Factory Function | Use Case |
|---|---|---|
| VAE | create_vae() | Latent representations, data compression |
| GAN | create_gan() | High-quality image generation |
| Diffusion | create_diffusion() | State-of-the-art generation, controllable |
| Flow | create_flow() | Exact likelihood, invertible transformations |
| EBM | create_ebm() | Energy-based modeling, composable |
Key Commands
```bash
# Install all dependencies and extras (run from a source checkout)
uv sync --all-extras

# Run tests
pytest tests/ -v

# Format code
ruff format src/

# Type check
pyright src/

# Serve the docs locally with live reload
mkdocs serve
```
Getting Help
- Documentation: Comprehensive guides and API reference
- Examples: Ready-to-run code in examples/
- Issues: GitHub Issues
- Discussions: GitHub Discussions
Congratulations! 🎉 You've completed the quickstart guide. You're now ready to build more sophisticated generative models with Workshop!
Next recommended step: read Core Concepts to better understand the architecture, or follow the First Model Tutorial to build a complete VAE with real data.