Quickstart Guide

Get started with Workshop in 5 minutes! This guide walks you through installing Workshop and training your first generative model.

Prerequisites

  • Python 3.10 or higher
  • 8GB RAM (16GB recommended)
  • Optional: NVIDIA GPU with CUDA 12.0+ for faster training

Step 1: Install Workshop

Choose your preferred installation method:

# Clone repository
git clone https://github.com/mahdi-shafiei/workshop.git
cd workshop

# Install with uv (fastest)
uv venv && source .venv/bin/activate
uv sync --all-extras

# Or with pip (editable install from the clone)
python -m venv .venv && source .venv/bin/activate
pip install -e '.[dev]'

# Or install the released package with pip (no clone needed)
pip install workshop-generative

Verify the installation:

python -c "import workshop, jax; print(f'JAX backend: {jax.default_backend()}')"
# Should print: JAX backend: gpu (or cpu)

Step 2: Train Your First VAE

Create a new Python file train_vae.py:

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from workshop.generative_models.factories import create_vae
from workshop.generative_models.core.configuration import ModelConfiguration

# 1. Configure the model
config = ModelConfiguration(
    model_type="vae",
    latent_dim=32,
    input_shape=(28, 28, 1),  # MNIST-like images
    encoder_features=[64, 128],
    decoder_features=[128, 64],
    parameters={
        "beta": 1.0,
        "kl_weight": 1.0,
        "reconstruction_loss": "mse"
    }
)

# 2. Create model and optimizer
rngs = nnx.Rngs(0)
model = create_vae(config, rngs=rngs)
# Flax >= 0.11 API; on older Flax, drop wrt= and call optimizer.update(grads) below
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

print("✓ Model created successfully!")
print(f"  Latent dimension: {model.latent_dim}")
# Count only trainable parameters, not RNG keys or other model state
n_params = sum(p.size for p in jax.tree.leaves(nnx.state(model, nnx.Param)))
print(f"  Parameters: ~{n_params / 1e6:.2f}M")

# 3. Create synthetic training data (for demo)
def generate_batch(key, batch_size=32):
    """Generate random synthetic data for quick demo."""
    return jax.random.normal(key, (batch_size, 28, 28, 1))

# 4. Training step
@nnx.jit
def train_step(model, optimizer, batch, rng):
    """Single training step."""
    def loss_fn(model):
        outputs = model(batch, rngs=nnx.Rngs(dropout=rng))
        loss_dict = model.loss_fn(x=batch, outputs=outputs)
        return loss_dict['total_loss'], loss_dict

    # With has_aux=True, value_and_grad returns ((loss, aux), grads)
    (loss, loss_dict), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)  # applies the Adam step to the model in place
    return loss, loss_dict

# 5. Train for a few steps
print("\nTraining for 100 steps...")
key = jax.random.PRNGKey(42)

for step in range(100):
    # Generate batch
    key, batch_key, train_key = jax.random.split(key, 3)
    batch = generate_batch(batch_key)

    # Train step
    loss, loss_dict = train_step(model, optimizer, batch, train_key)

    if step % 20 == 0:
        print(f"Step {step:3d} | Loss: {loss:.4f} | "
              f"Recon: {loss_dict['reconstruction_loss']:.4f} | "
              f"KL: {loss_dict['kl_loss']:.4f}")

print("\n✓ Training complete!")

# 6. Generate samples
print("\nGenerating samples...")
samples = model.sample(n_samples=8, rngs=rngs)
print(f"✓ Generated {samples.shape[0]} samples with shape {samples.shape[1:]}")

# 7. Reconstruct an image
test_image = generate_batch(jax.random.PRNGKey(99), batch_size=1)
reconstructed = model.reconstruct(test_image, deterministic=True, rngs=rngs)
print(f"✓ Reconstructed image with shape {reconstructed.shape}")

print("\n🎉 Success! You've trained your first VAE with Workshop!")

Run the script:

python train_vae.py

Example output (exact loss values and parameter count will vary with library versions, hardware, and random seed):

✓ Model created successfully!
  Latent dimension: 32
  Parameters: ~0.15M

Training for 100 steps...
Step   0 | Loss: 0.5234 | Recon: 0.4891 | KL: 0.0343
Step  20 | Loss: 0.3156 | Recon: 0.2945 | KL: 0.0211
Step  40 | Loss: 0.2498 | Recon: 0.2341 | KL: 0.0157
Step  60 | Loss: 0.2134 | Recon: 0.2002 | KL: 0.0132
Step  80 | Loss: 0.1923 | Recon: 0.1812 | KL: 0.0111

✓ Training complete!

Generating samples...
✓ Generated 8 samples with shape (28, 28, 1)
✓ Reconstructed image with shape (1, 28, 28, 1)

🎉 Success! You've trained your first VAE with Workshop!
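In this output the total loss equals Recon + KL at every step (for example, 0.4891 + 0.0343 = 0.5234), which is what a β-weighted VAE objective gives with beta = 1.0. The sketch below shows how such a loss is commonly computed for a diagonal-Gaussian encoder that outputs a mean `mu` and log-variance `logvar`, with MSE reconstruction. It uses plain Python lists for clarity; `vae_loss` and its averaging over elements are assumptions, and Workshop's own `loss_fn` may reduce over batches and pixels differently.

```python
import math


def gaussian_kl(mu: list[float], logvar: list[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)) for a diagonal Gaussian, summed over latent dims."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def mse(x: list[float], x_hat: list[float]) -> float:
    """Mean squared reconstruction error over all elements."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def vae_loss(
    x: list[float], x_hat: list[float], mu: list[float], logvar: list[float], beta: float = 1.0
) -> dict[str, float]:
    """Hypothetical beta-VAE objective: reconstruction + beta * KL."""
    recon = mse(x, x_hat)
    kl = gaussian_kl(mu, logvar)
    return {"total_loss": recon + beta * kl, "reconstruction_loss": recon, "kl_loss": kl}


# Shifting one latent mean from 0 to 1 (unit variance) costs KL = 0.5 nats.
print(vae_loss([1.0, 0.0], [0.5, 0.0], mu=[1.0, 0.0], logvar=[0.0, 0.0]))
# {'total_loss': 0.625, 'reconstruction_loss': 0.125, 'kl_loss': 0.5}
```

Raising `beta` above 1.0 pushes the posterior toward the prior at the cost of reconstruction quality; lowering it does the opposite.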

Step 3: Visualize Results (Optional)

Append the following to train_vae.py (requires matplotlib, e.g. pip install matplotlib):

import matplotlib.pyplot as plt

# Visualize generated samples
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Samples from VAE')
plt.tight_layout()
plt.savefig('vae_samples.png')
print("✓ Saved samples to vae_samples.png")

# Visualize reconstruction
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(test_image[0].squeeze(), cmap='gray')
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(reconstructed[0].squeeze(), cmap='gray')
axes[1].set_title('Reconstructed')
axes[1].axis('off')
plt.tight_layout()
plt.savefig('vae_reconstruction.png')
print("✓ Saved reconstruction to vae_reconstruction.png")

What You Just Did

In just a few minutes, you:

  1. Installed Workshop - Set up the complete environment
  2. Created a VAE - Built a variational autoencoder from a configuration object
  3. Trained the model - Ran 100 training steps with loss monitoring
  4. Generated samples - Created new images from the learned distribution
  5. Reconstructed images - Tested the encoder-decoder pipeline

Key Concepts

Configuration System

config = ModelConfiguration(
    model_type="vae",           # Model type
    latent_dim=32,              # Latent space dimension
    input_shape=(28, 28, 1),    # Input image shape
    encoder_features=[64, 128], # Encoder layer sizes
    decoder_features=[128, 64], # Decoder layer sizes
    parameters={...}            # Model-specific parameters
)

Workshop uses a unified configuration system based on Pydantic for type-safe, validated configurations.
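To illustrate what validation buys you, here is a minimal stand-in using only the standard library. `VAEConfig` is a hypothetical class, not Workshop's `ModelConfiguration` (which, per the text above, is built on Pydantic and also coerces and checks types); it simply rejects bad values at construction time instead of letting them fail later inside the model.

```python
from dataclasses import dataclass, field


@dataclass
class VAEConfig:
    """Hypothetical stand-in for a validated configuration (not Workshop's class)."""

    latent_dim: int
    input_shape: tuple[int, ...]
    encoder_features: list[int] = field(default_factory=lambda: [64, 128])
    beta: float = 1.0

    def __post_init__(self) -> None:
        # Fail fast here instead of deep inside model construction.
        if not isinstance(self.latent_dim, int) or self.latent_dim <= 0:
            raise ValueError(f"latent_dim must be a positive int, got {self.latent_dim!r}")
        if any(d <= 0 for d in self.input_shape):
            raise ValueError(f"input_shape entries must be positive, got {self.input_shape}")
        if any(f <= 0 for f in self.encoder_features):
            raise ValueError(f"encoder_features must be positive widths, got {self.encoder_features}")
        if self.beta < 0:
            raise ValueError(f"beta must be non-negative, got {self.beta}")


config = VAEConfig(latent_dim=32, input_shape=(28, 28, 1))
try:
    VAEConfig(latent_dim=0, input_shape=(28, 28, 1))
except ValueError as err:
    print(err)  # latent_dim must be a positive int, got 0
```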

Model Factories

model = create_vae(config, rngs=rngs)

Factory functions create models from configurations, handling all initialization details.
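One common way to implement such factories is a registry that maps the `model_type` string to a constructor. The sketch below assumes that design purely for illustration; `register`, `create_model`, `ToyVAE`, and `ToyGAN` are hypothetical names and do not describe Workshop's internals.

```python
from typing import Callable

# Maps a model_type string to a constructor that takes the config dict.
_REGISTRY: dict[str, Callable[[dict], object]] = {}


def register(model_type: str):
    """Class decorator that records a constructor under model_type."""
    def wrap(cls):
        _REGISTRY[model_type] = cls
        return cls
    return wrap


@register("vae")
class ToyVAE:
    def __init__(self, config: dict):
        self.latent_dim = config["latent_dim"]


@register("dcgan")
class ToyGAN:
    def __init__(self, config: dict):
        self.latent_dim = config["latent_dim"]


def create_model(config: dict) -> object:
    """Look up the constructor for config['model_type'] and build the model."""
    model_type = config["model_type"]
    try:
        constructor = _REGISTRY[model_type]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"unknown model_type {model_type!r}; known: {known}") from None
    return constructor(config)


model = create_model({"model_type": "vae", "latent_dim": 32})
print(type(model).__name__, model.latent_dim)  # ToyVAE 32
```

The registry keeps the dispatch logic in one place, so adding a model type means registering one class rather than editing a chain of if/else branches.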

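RNG Management

rngs = nnx.Rngs(0)  # Seeded source of PRNG keys

JAX has no global random state: every random operation takes an explicit PRNG key, which keeps functions pure and makes runs reproducible. nnx.Rngs(0) wraps a seed and hands out fresh keys to Flax modules for parameter initialization, dropout, and sampling.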
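JAX's PRNG is splittable: instead of reusing a key, you split it into independent child keys, as the training loop above does with `jax.random.split(key, 3)`. The toy model below mimics that discipline in pure Python by deriving child keys with SHA-256. It only illustrates the idea (JAX's default generator is Threefry, not a hash), and `make_key`, `split`, and `uniform` are hypothetical helpers, not `jax.random` functions.

```python
import hashlib


def make_key(seed: int) -> bytes:
    """Turn an integer seed into a root key (toy)."""
    return hashlib.sha256(seed.to_bytes(8, "little")).digest()


def split(key: bytes, num: int = 2) -> list[bytes]:
    """Derive `num` independent child keys from a parent key (toy, not Threefry)."""
    return [hashlib.sha256(key + i.to_bytes(4, "little")).digest() for i in range(num)]


def uniform(key: bytes) -> float:
    """Map a key deterministically to a float in [0, 1) using its top 53 bits."""
    return (int.from_bytes(key[:8], "little") >> 11) / 2**53


key = make_key(42)
key, init_key, sample_key = split(key, 3)
# Same seed -> same keys -> same numbers, with no hidden global state.
assert uniform(sample_key) == uniform(split(make_key(42), 3)[2])
print(uniform(init_key), uniform(sample_key))
```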

Training Loop

The training loop follows standard JAX/Flax patterns:

  1. Forward pass: model(batch) → outputs
  2. Compute loss: model.loss_fn() → loss values
  3. Backward pass: nnx.value_and_grad() → loss and gradients
  4. Update weights: optimizer.update(model, grads) → parameters updated in place
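The same four steps, stripped of JAX and Flax, look like this for a one-parameter linear model `y = w * x` fitted with mean squared error and plain gradient descent. The gradient is derived by hand here; in the real script `nnx.value_and_grad` computes it automatically and the Optax Adam optimizer applies the update.

```python
def forward(w: float, xs: list[float]) -> list[float]:
    """1. Forward pass: predictions for a batch."""
    return [w * x for x in xs]


def loss_fn(preds: list[float], ys: list[float]) -> float:
    """2. Loss: mean squared error."""
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)


def grad_fn(w: float, xs: list[float], ys: list[float]) -> float:
    """3. Backward pass: d(loss)/dw = mean(2 * (w*x - y) * x), derived by hand."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def train(steps: int = 50, lr: float = 0.1) -> tuple[float, list[float]]:
    xs = [0.0, 0.5, 1.0, 1.5]
    ys = [3.0 * x for x in xs]  # the true weight is 3.0
    w, history = 0.0, []
    for _ in range(steps):
        history.append(loss_fn(forward(w, xs), ys))
        w -= lr * grad_fn(w, xs, ys)  # 4. Update: one gradient-descent step
    return w, history


w, history = train()
print(f"w = {w:.3f}, loss {history[0]:.3f} -> {history[-1]:.6f}")
```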

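Try Different Models

Train a Diffusion Model

from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.factories import create_diffusion

rngs = nnx.Rngs(0)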

# Create DDPM configuration
config = ModelConfiguration(
    model_type="ddpm",
    input_shape=(28, 28, 1),
    num_timesteps=1000,
    backbone_type="unet",
    backbone_features=[64, 128, 256],
    parameters={
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "beta_schedule": "linear"
    }
)

# Create the model and draw samples (untrained, so samples will look like noise)
model = create_diffusion(config, rngs=rngs)
samples = model.sample(n_samples=8, rngs=rngs)
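The `beta_start`, `beta_end`, and `beta_schedule` parameters above describe the DDPM noise schedule. Assuming `"linear"` means betas evenly spaced from `beta_start` to `beta_end` over `num_timesteps` steps (the usual DDPM convention; check Workshop's implementation to confirm), this sketch computes the cumulative signal fraction alpha_bar_t = prod over s <= t of (1 - beta_s) and applies the closed-form forward noising x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps to a single pixel value. The helper names are hypothetical.

```python
import math
import random


def linear_beta_schedule(num_timesteps: int, beta_start: float, beta_end: float) -> list[float]:
    """Betas evenly spaced from beta_start to beta_end (inclusive)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]


def cumulative_alphas(betas: list[float]) -> list[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s): fraction of signal variance kept at step t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0: float, t: int, alpha_bars: list[float], rng: random.Random) -> float:
    """Sample x_t ~ q(x_t | x_0) in one shot, without iterating over earlier steps."""
    eps = rng.gauss(0.0, 1.0)
    return math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1.0 - alpha_bars[t]) * eps


betas = linear_beta_schedule(1000, 1e-4, 2e-2)
alpha_bars = cumulative_alphas(betas)
rng = random.Random(0)
for t in (0, 250, 500, 999):
    print(f"t={t:4d}  alpha_bar={alpha_bars[t]:.5f}  x_t={q_sample(1.0, t, alpha_bars, rng):+.3f}")
```

By the last step almost no signal survives, which is why sampling can start from pure Gaussian noise and denoise backward.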

Train a GAN

from workshop.generative_models.factories import create_gan

config = ModelConfiguration(
    model_type="dcgan",
    latent_dim=100,
    input_shape=(28, 28, 1),
    generator_features=[256, 128, 64],
    discriminator_features=[64, 128, 256],
    parameters={
        "generator_lr": 2e-4,
        "discriminator_lr": 2e-4,
        "label_smoothing": 0.1
    }
)

model = create_gan(config, rngs=rngs)
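The `label_smoothing: 0.1` parameter most likely refers to one-sided label smoothing, where the discriminator is trained toward 0.9 instead of 1.0 on real images so it does not become overconfident. That reading is an assumption; the sketch below shows the idea with a numerically stable binary cross-entropy on logits, and `bce_with_logits` and `discriminator_loss` are hypothetical helpers.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, stable for large |logit|."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(
    real_logits: list[float], fake_logits: list[float], label_smoothing: float = 0.1
) -> float:
    """One-sided smoothing: real targets become 1 - label_smoothing, fake targets stay 0."""
    real_target = 1.0 - label_smoothing
    real = sum(bce_with_logits(z, real_target) for z in real_logits) / len(real_logits)
    fake = sum(bce_with_logits(z, 0.0) for z in fake_logits) / len(fake_logits)
    return real + fake


# A very confident discriminator (logit 8 on real, -8 on fake):
print(discriminator_loss([8.0], [-8.0], label_smoothing=0.0))  # ~0.0007
print(discriminator_loss([8.0], [-8.0], label_smoothing=0.1))  # ~0.80, overconfidence is penalized
```

With smoothing, the real-sample loss is minimized at logit log(9), about 2.2, rather than growing without bound.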

Next Steps

Now that you have a working setup, explore more:

  • Learn Core Concepts: understand generative modeling fundamentals and the Workshop architecture. See Core Concepts.
  • Build Your First Model: follow a step-by-step tutorial that builds a VAE from scratch with real data. See First Model Tutorial.
  • Explore Model Guides: dive deep into VAEs, GANs, diffusion models, flows, and more. See the VAE Guide and Model Implementations.
  • Check Examples: browse ready-to-run examples for various models and use cases. See Examples.

Common Next Questions

How do I use real data?

See the Data Pipeline Guide for loading CIFAR-10, ImageNet, and custom datasets.

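How do I save and load models?

import orbax.checkpoint as ocp
from pathlib import Path

# Save the trained parameters (Orbax requires an absolute path)
ckpt_dir = Path('checkpoints').resolve()
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / 'params', nnx.state(model, nnx.Param))

# Load: rebuild the model from the same config, then restore its parameters
model = create_vae(config, rngs=nnx.Rngs(0))
restored = checkpointer.restore(ckpt_dir / 'params', nnx.state(model, nnx.Param))
nnx.update(model, restored)

See the Training Guide for details on checkpointing.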

How do I train on multiple GPUs?

Workshop supports distributed training out of the box. See the Distributed Training Guide.

What if I get errors?

If you encounter issues, open an issue on the GitHub repository: https://github.com/mahdi-shafiei/workshop

Quick Reference

Model Types

| Type      | Factory Function   | Use Case                                     |
|-----------|--------------------|----------------------------------------------|
| VAE       | create_vae()       | Latent representations, data compression     |
| GAN       | create_gan()       | High-quality image generation                |
| Diffusion | create_diffusion() | State-of-the-art generation, controllable    |
| Flow      | create_flow()      | Exact likelihood, invertible transformations |
| EBM       | create_ebm()       | Energy-based modeling, composable            |

Key Commands

# Install
uv sync --all-extras

# Run tests
pytest tests/ -v

# Format code
ruff format src/

# Type check
pyright src/

# Preview docs locally (use mkdocs build for a static build)
mkdocs serve


Congratulations! 🎉 You've completed the quickstart guide. You're now ready to build more sophisticated generative models with Workshop!

Next recommended step: read Core Concepts to better understand the architecture, or work through the First Model Tutorial to build a complete VAE with real data.