Training System Overview¤
Workshop provides a robust training system built on JAX/Flax NNX that is working toward production readiness. The training infrastructure handles the complete training lifecycle, from model initialization to checkpointing and evaluation.
- Easy to Use: Simple, intuitive API for training any generative model with sensible defaults
- Highly Configurable: Type-safe configuration system with Pydantic validation and YAML support
- Flexible: Support for custom training loops, optimizers, and learning rate schedules
- Research-Focused: Checkpointing, resuming, logging, and metrics tracking built in for experimentation
Training Architecture¤
The Workshop training system follows a modular, composable architecture:
graph TB
subgraph "Training Configuration"
TC[TrainingConfiguration]
OC[OptimizerConfiguration]
SC[SchedulerConfiguration]
TC -->|contains| OC
TC -->|optional| SC
end
subgraph "Trainer"
T[Trainer]
TS[TrainingState]
OPT[Optimizer]
T -->|manages| TS
T -->|uses| OPT
end
subgraph "Training Loop"
STEP[train_step]
VAL[validate_step]
EPOCH[train_epoch]
EVAL[evaluate]
EPOCH -->|calls| STEP
EPOCH -->|periodic| VAL
EVAL -->|uses| VAL
end
subgraph "Persistence"
CKPT[Checkpointing]
LOG[Logging]
METRICS[MetricsLogger]
T -->|saves| CKPT
T -->|writes| LOG
T -->|tracks| METRICS
end
TC -->|configures| T
T -->|executes| EPOCH
STEP -->|updates| TS
CKPT -->|saves| TS
Core Components¤
1. Trainer¤
The Trainer class is the central component for training generative models:
from workshop.generative_models.training import Trainer
from workshop.generative_models.core.configuration import (
TrainingConfiguration,
OptimizerConfiguration,
)
# Create optimizer configuration
optimizer_config = OptimizerConfiguration(
name="adam_optimizer",
optimizer_type="adam",
learning_rate=1e-3,
gradient_clip_norm=1.0,
)
# Create training configuration
training_config = TrainingConfiguration(
name="vae_training",
batch_size=128,
num_epochs=100,
optimizer=optimizer_config,
save_frequency=1000,
log_frequency=100,
)
# Initialize trainer
trainer = Trainer(
model=model,
training_config=training_config,
train_data_loader=train_loader,
val_data_loader=val_loader,
workdir="./experiments/vae",
)
Key Features:
- Type-Safe Configuration: Uses Pydantic-based configuration with validation
- Automatic Setup: Handles optimizer initialization, state management, and checkpointing
- Flexible Training: Supports custom loss functions and training loops
- Built-in Logging: Integrates with Workshop's logging system and external trackers
- Checkpointing: Automatic saving and loading of training state
2. TrainingState¤
The TrainingState is a PyTree that holds all training state:
from workshop.generative_models.training.trainer import TrainingState
state = TrainingState.create(
params=model_params,
opt_state=optimizer.init(model_params),
rng=jax.random.PRNGKey(42),
step=0,
best_loss=float("inf"),
)
Components:
- params: Model parameters
- opt_state: Optimizer state
- rng: JAX random number generator
- step: Current training step
- best_loss: Best validation loss (for early stopping)
Benefits:
- JAX-compatible PyTree structure
- Easy to save/load with checkpointing
- Immutable updates for functional programming
- Integrates with JAX transformations (jit, grad, etc.)
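For instance, a PyTree state can be inspected with standard JAX utilities and updated functionally. A minimal sketch, assuming TrainingState exposes its fields as attributes and supports a flax.struct-style replace (an assumption, not confirmed above):
import jax
# Inspect the leaves of the PyTree, e.g. to count parameters (sketch).
num_params = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
# Immutable update: build a new state rather than mutating the old one
# (assumes a flax.struct-style `replace` method).
new_state = state.replace(step=state.step + 1)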
3. Configuration System¤
Workshop uses a unified, type-safe configuration system based on Pydantic:
from workshop.generative_models.core.configuration import (
TrainingConfiguration,
OptimizerConfiguration,
SchedulerConfiguration,
)
# Optimizer configuration
optimizer = OptimizerConfiguration(
name="adamw_optimizer",
optimizer_type="adamw",
learning_rate=3e-4,
weight_decay=0.01,
beta1=0.9,
beta2=0.999,
)
# Learning rate scheduler
scheduler = SchedulerConfiguration(
name="cosine_scheduler",
scheduler_type="cosine",
warmup_steps=1000,
cycle_length=10000,
min_lr_ratio=0.1,
)
# Complete training configuration
training_config = TrainingConfiguration(
name="diffusion_training",
batch_size=64,
num_epochs=200,
optimizer=optimizer,
scheduler=scheduler,
gradient_clip_norm=1.0,
save_frequency=5000,
)
Advantages:
- Type Safety: Pydantic validates all fields at creation
- IDE Support: Full autocompletion and type checking
- Serialization: Easy YAML/JSON save/load
- Validation: Built-in constraints and custom validators
- Documentation: Self-documenting with field descriptions
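For the serialization point above, a minimal round-trip sketch, assuming Pydantic v2-style model_dump() and model_validate() on the configuration classes:
import yaml
# Dump the configuration (including the nested optimizer/scheduler) to YAML.
with open("training_config.yaml", "w") as f:
    yaml.safe_dump(training_config.model_dump(), f)
# Load it back and re-validate into a typed configuration object.
with open("training_config.yaml") as f:
    restored = TrainingConfiguration.model_validate(yaml.safe_load(f))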
See Configuration Guide for complete details.
Training Loop Mechanics¤
Basic Training Flow¤
sequenceDiagram
participant User
participant Trainer
participant Model
participant Optimizer
participant Storage
User->>Trainer: trainer.train_epoch()
loop For each batch
Trainer->>Model: compute loss
Model-->>Trainer: loss + metrics
Trainer->>Trainer: compute gradients
Trainer->>Optimizer: update parameters
Optimizer-->>Trainer: new parameters
alt Step % save_frequency == 0
Trainer->>Storage: save checkpoint
end
alt Step % log_frequency == 0
Trainer->>Storage: log metrics
end
end
Trainer-->>User: epoch metrics
Training Step¤
The core training step is JIT-compiled for performance:
def _train_step(state, batch):
"""Single training step (JIT-compiled)."""
rng, step_rng = jax.random.split(state["rng"])
# Define loss function
def loss_fn(params):
loss, metrics = model.loss_fn(params, batch, step_rng)
return loss, metrics
# Compute gradients
(loss, metrics), grads = jax.value_and_grad(
loss_fn, has_aux=True
)(state["params"])
# Update parameters
updates, opt_state = optimizer.update(
grads, state["opt_state"], state["params"]
)
params = optax.apply_updates(state["params"], updates)
# Create new state
new_state = {
"step": state["step"] + 1,
"params": params,
"opt_state": opt_state,
"rng": rng,
}
return new_state, metrics
Key Points:
- Functional Style: Pure function with immutable updates
- JIT Compilation: Compiled once, runs fast
- RNG Splitting: Proper random number handling
- Gradient Computation: Uses jax.value_and_grad for efficiency
- State Updates: Returns new state (immutable)
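In practice the step is wrapped with jax.jit once and then reused for every batch; a minimal sketch (the Trainer presumably does this internally):
import jax
# Compile once; later calls with the same shapes reuse the compiled function.
train_step_fn = jax.jit(_train_step)
state, metrics = train_step_fn(state, batch)       # first call triggers compilation
state, metrics = train_step_fn(state, next_batch)  # subsequent calls are fast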
Validation Step¤
Validation uses the same loss function without updates:
def _validate_step(state, batch):
"""Single validation step."""
_, val_rng = jax.random.split(state["rng"])
# Compute validation loss (no gradients)
loss, metrics = model.loss_fn(state["params"], batch, val_rng)
metrics["loss"] = loss
return metrics
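A full evaluation pass then just averages these per-batch metrics. A minimal sketch built on _validate_step, assuming the same loader protocol used by train_epoch below:
import numpy as np
def evaluate(state, val_loader, num_batches):
    """Average validation metrics over a number of batches (sketch)."""
    batch_metrics = [_validate_step(state, next(val_loader)) for _ in range(num_batches)]
    return {
        key: float(np.mean([m[key] for m in batch_metrics]))
        for key in batch_metrics[0]
    }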
Epoch Training¤
An epoch iterates over the entire dataset:
def train_epoch(trainer):
"""Train for one epoch."""
data_iter = trainer.train_data_loader(trainer.training_config.batch_size)
epoch_metrics = []
for _ in range(trainer.steps_per_epoch):
batch = next(data_iter)
# Training step
trainer.state, metrics = trainer.train_step_fn(trainer.state, batch)
epoch_metrics.append(metrics)
# Periodic checkpointing
if trainer.state["step"] % trainer.training_config.save_frequency == 0:
trainer.save_checkpoint()
# Average metrics
avg_metrics = {
key: sum(m[key] for m in epoch_metrics) / len(epoch_metrics)
for key in epoch_metrics[0].keys()
if key != "step"
}
return avg_metrics
Checkpointing¤
Automatic Checkpointing¤
Checkpoints are automatically saved during training:
# Configure checkpointing
training_config = TrainingConfiguration(
name="my_training",
batch_size=32,
num_epochs=100,
optimizer=optimizer_config,
checkpoint_dir="./checkpoints",
save_frequency=1000, # Save every 1000 steps
max_checkpoints=5, # Keep last 5 checkpoints
)
trainer = Trainer(
model=model,
training_config=training_config,
)
# Training automatically saves checkpoints
trainer.train_epoch()
Manual Checkpointing¤
Save and load checkpoints manually:
# Save checkpoint
trainer.save_checkpoint("./checkpoints/my_checkpoint.pkl")
# Load checkpoint
trainer.load_checkpoint("./checkpoints/my_checkpoint.pkl")
# Resume training
trainer.train_epoch() # Continues from loaded state
Checkpoint Contents¤
Each checkpoint contains the complete training state:
{
"step": 5000,
"params": {...}, # Model parameters
"opt_state": {...}, # Optimizer state
"rng": Array(...), # RNG state
}
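The .pkl extension used elsewhere on this page suggests a pickle-based format. As a sketch of what saving and loading such a state dictionary can look like (the save_state/load_state helpers are illustrative, not Workshop APIs):
import pickle
import jax
def save_state(path, state):
    """Pickle the training state after moving arrays to host memory (sketch)."""
    with open(path, "wb") as f:
        pickle.dump(jax.device_get(state), f)
def load_state(path):
    """Load a pickled training state (sketch)."""
    with open(path, "rb") as f:
        return pickle.load(f)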
Best Practices:
- Save checkpoints to fast storage (SSD) for quick I/O
- Use max_checkpoints to limit disk usage
- Save best model separately based on validation metrics (see the sketch after this list)
- Include step number in checkpoint filenames
- Test checkpoint loading before long training runs
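A short sketch combining these practices, using save_checkpoint and evaluate as shown elsewhere on this page (paths are illustrative):
best_val_loss = float("inf")
# Include the step number in the filename so runs are easy to inspect and resume.
step = trainer.state["step"]
trainer.save_checkpoint(f"./checkpoints/step_{step:07d}.pkl")
# Keep the best model separately, based on validation loss.
val_metrics = trainer.evaluate(val_data, batch_size=64)
if val_metrics["loss"] < best_val_loss:
    best_val_loss = val_metrics["loss"]
    trainer.save_checkpoint("./checkpoints/best_model.pkl")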
Logging and Monitoring¤
Built-in Logging¤
Workshop includes structured logging:
from workshop.generative_models.utils.logging import Logger, MetricsLogger
# Create loggers
logger = Logger(log_dir="./logs")
metrics_logger = MetricsLogger(log_dir="./logs/metrics")
# Initialize trainer with loggers
trainer = Trainer(
model=model,
training_config=training_config,
logger=logger,
metrics_logger=metrics_logger,
)
# Logs are written automatically during training
trainer.train_epoch()
Custom Logging Callbacks¤
Implement custom logging with callbacks:
def custom_log_callback(step, metrics, prefix="train"):
"""Custom logging function."""
print(f"[{prefix}] Step {step}: Loss = {metrics['loss']:.4f}")
# Log to external system (e.g., wandb, tensorboard)
if wandb_enabled:
wandb.log({f"{prefix}/{k}": v for k, v in metrics.items()}, step=step)
trainer = Trainer(
model=model,
training_config=training_config,
log_callback=custom_log_callback,
)
Metrics Tracking¤
The trainer automatically tracks:
- Training Loss: Loss value for each batch
- Validation Loss: Periodic validation metrics
- Learning Rate: Current learning rate (with schedulers)
- Gradient Norms: L2 norm of gradients
- Model-Specific Metrics: KL divergence, reconstruction loss, etc.
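For the gradient-norm metric, a one-liner is usually enough; a sketch of what this looks like inside the training step, where grads is available:
import optax
# Inside _train_step, after computing `grads` (sketch):
metrics["grad_norm"] = optax.global_norm(grads)  # L2 norm over the whole gradient PyTree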
Example metrics access:
# Train for an epoch
metrics = trainer.train_epoch()
print(f"Epoch loss: {metrics['loss']:.4f}")
# Access training history
for step, metric in enumerate(trainer.train_metrics):
print(f"Step {step}: {metric['loss']:.4f}")
# Validation metrics
val_metrics = trainer.evaluate(val_data, batch_size=64)
print(f"Validation loss: {val_metrics['loss']:.4f}")
Optimizers¤
Workshop supports multiple optimizers through Optax:
Available Optimizers¤
| Optimizer | Best For | Key Parameters |
|---|---|---|
| Adam | General purpose, most models | learning_rate, beta1, beta2 |
| AdamW | Transformers, weight decay needed | learning_rate, weight_decay |
| SGD | Large batch training, momentum | learning_rate, momentum |
| RMSProp | RNNs, non-stationary objectives | learning_rate, decay |
| AdaGrad | Sparse gradients, NLP | learning_rate |
Optimizer Configuration¤
# Adam optimizer
adam_config = OptimizerConfiguration(
name="adam",
optimizer_type="adam",
learning_rate=1e-3,
beta1=0.9,
beta2=0.999,
eps=1e-8,
)
# AdamW with weight decay
adamw_config = OptimizerConfiguration(
name="adamw",
optimizer_type="adamw",
learning_rate=3e-4,
weight_decay=0.01,
)
# SGD with momentum
sgd_config = OptimizerConfiguration(
name="sgd",
optimizer_type="sgd",
learning_rate=0.1,
momentum=0.9,
nesterov=True,
)
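For orientation, these configurations correspond roughly to the following raw Optax constructors (a sketch; Workshop presumably assembles the optimizer from the configuration internally, so the exact wiring may differ):
import optax
adam = optax.adam(learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8)
adamw = optax.adamw(learning_rate=3e-4, weight_decay=0.01)
sgd = optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True)
rmsprop = optax.rmsprop(learning_rate=1e-3, decay=0.9)
adagrad = optax.adagrad(learning_rate=1e-2)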
Gradient Clipping¤
Prevent gradient explosion with clipping:
# Clip by global norm (recommended)
optimizer_config = OptimizerConfiguration(
name="clipped_adam",
optimizer_type="adam",
learning_rate=1e-3,
gradient_clip_norm=1.0, # Clip to norm of 1.0
)
# Clip by value
optimizer_config = OptimizerConfiguration(
name="value_clipped_adam",
optimizer_type="adam",
learning_rate=1e-3,
gradient_clip_value=0.5, # Clip values to [-0.5, 0.5]
)
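In raw Optax terms, clipping is a gradient transformation chained in front of the optimizer; a sketch of both variants:
import optax
clipped_by_norm = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
clipped_by_value = optax.chain(optax.clip(0.5), optax.adam(1e-3))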
Learning Rate Schedules¤
Available Schedules¤
| Schedule | Description | Use Case |
|---|---|---|
| Constant | Fixed learning rate | Simple training, debugging |
| Linear | Linear decay | Short training runs |
| Cosine | Cosine annealing | Most deep learning (recommended) |
| Exponential | Exponential decay | Traditional ML |
| Step | Step-wise decay | Milestone-based training |
| MultiStep | Multiple milestones | Fine-grained control |
Schedule Configuration¤
# Cosine schedule with warmup (recommended)
cosine_schedule = SchedulerConfiguration(
name="cosine_warmup",
scheduler_type="cosine",
warmup_steps=1000,
cycle_length=50000,
min_lr_ratio=0.1, # End at 10% of initial LR
)
# Linear schedule
linear_schedule = SchedulerConfiguration(
name="linear_decay",
scheduler_type="linear",
warmup_steps=500,
total_steps=10000,
min_lr_ratio=0.0, # Decay to 0
)
# Step schedule
step_schedule = SchedulerConfiguration(
name="step_decay",
scheduler_type="step",
step_size=5000, # Decay every 5000 steps
gamma=0.1, # Multiply LR by 0.1
)
# MultiStep schedule
multistep_schedule = SchedulerConfiguration(
name="multistep",
scheduler_type="multistep",
milestones=[10000, 20000, 30000],
gamma=0.1,
)
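For orientation, the configurations above map roughly onto standard Optax schedules (a sketch; Workshop's internal construction may differ):
import optax
base_lr = 3e-4
# Cosine annealing with linear warmup, ending at 10% of the peak LR.
cosine = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=base_lr,
    warmup_steps=1000, decay_steps=50_000, end_value=0.1 * base_lr,
)
# Linear decay to zero.
linear = optax.linear_schedule(init_value=base_lr, end_value=0.0, transition_steps=10_000)
# Step decay: multiply the LR by 0.1 every 5000 steps.
step = optax.exponential_decay(
    init_value=base_lr, transition_steps=5_000, decay_rate=0.1, staircase=True,
)
# MultiStep: multiply the LR by 0.1 at explicit milestones.
multistep = optax.piecewise_constant_schedule(
    init_value=base_lr,
    boundaries_and_scales={10_000: 0.1, 20_000: 0.1, 30_000: 0.1},
)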
Schedule Visualization¤
import matplotlib.pyplot as plt
import numpy as np
def visualize_schedule(scheduler_config, base_lr=1e-3, total_steps=10000):
"""Visualize learning rate schedule."""
schedule = create_schedule(scheduler_config, base_lr)
steps = np.arange(total_steps)
lrs = [schedule(step) for step in steps]
plt.figure(figsize=(10, 4))
plt.plot(steps, lrs)
plt.xlabel("Step")
plt.ylabel("Learning Rate")
plt.title(f"{scheduler_config.scheduler_type} Schedule")
plt.grid(True)
plt.show()
# Visualize cosine schedule
visualize_schedule(cosine_schedule)
Complete Training Example¤
Here's a complete training workflow:
from workshop.generative_models.core.configuration import (
ModelConfiguration,
TrainingConfiguration,
OptimizerConfiguration,
SchedulerConfiguration,
)
from workshop.generative_models.factory import create_model
from workshop.generative_models.training import Trainer
from flax import nnx
# 1. Create model configuration
model_config = ModelConfiguration(
name="vae_mnist",
model_class="workshop.generative_models.models.vae.base.VAE",
input_dim=(28, 28, 1),
hidden_dims=[512, 256],
output_dim=64,
parameters={"beta": 1.0},
)
# 2. Initialize model
rngs = nnx.Rngs(42)
model = create_model(config=model_config, rngs=rngs)
# 3. Configure optimizer
optimizer_config = OptimizerConfiguration(
name="adamw",
optimizer_type="adamw",
learning_rate=3e-4,
weight_decay=0.01,
gradient_clip_norm=1.0,
)
# 4. Configure learning rate schedule
scheduler_config = SchedulerConfiguration(
name="cosine_warmup",
scheduler_type="cosine",
warmup_steps=1000,
cycle_length=50000,
min_lr_ratio=0.1,
)
# 5. Create training configuration
training_config = TrainingConfiguration(
name="vae_training",
batch_size=128,
num_epochs=100,
optimizer=optimizer_config,
scheduler=scheduler_config,
save_frequency=5000,
log_frequency=100,
checkpoint_dir="./checkpoints/vae",
)
# 6. Initialize trainer
trainer = Trainer(
model=model,
training_config=training_config,
train_data_loader=train_loader,
val_data_loader=val_loader,
workdir="./experiments/vae",
)
# 7. Train
for epoch in range(training_config.num_epochs):
# Train epoch
train_metrics = trainer.train_epoch()
print(f"Epoch {epoch + 1}: Train Loss = {train_metrics['loss']:.4f}")
# Validate
val_metrics = trainer.evaluate(val_data, batch_size=128)
print(f"Epoch {epoch + 1}: Val Loss = {val_metrics['loss']:.4f}")
# Save best model
if val_metrics['loss'] < trainer.state.get('best_loss', float('inf')):
trainer.save_checkpoint(f"./checkpoints/vae/best_model.pkl")
# 8. Generate samples
samples = trainer.generate_samples(num_samples=16)
Key Design Principles¤
1. Functional Programming¤
The training system uses functional programming principles:
- Immutable Updates: States are never modified in-place
- Pure Functions: Training steps are deterministic
- JAX Transformations: Compatible with jit, grad, vmap, pmap
# Immutable state updates
new_state = {**old_state, "step": old_state["step"] + 1}
# Pure function (same inputs → same outputs)
@jax.jit
def train_step(state, batch):
# No side effects
return new_state, metrics
2. Type Safety¤
All configurations use Pydantic for type safety:
# Type-safe configuration
config = TrainingConfiguration(
name="my_training",
batch_size=32, # int: validated
num_epochs=100, # int: validated
optimizer=opt_config, # OptimizerConfiguration: validated
)
# Invalid configuration raises a ValidationError at creation
from pydantic import ValidationError
try:
    bad_config = TrainingConfiguration(
        name="bad",
        batch_size="invalid",  # fails validation: not an int
    )
except ValidationError as e:
    print(e)
3. Composability¤
Components are designed to be composable:
# Compose optimizer with gradient clipping
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(learning_rate=1e-3),
)
# Compose learning rate schedules
schedule = optax.join_schedules(
schedules=[warmup_schedule, cosine_schedule],
boundaries=[warmup_steps],
)
# Compose training callbacks
def composed_callback(step, metrics, prefix="train"):
log_to_file(step, metrics, prefix)
log_to_wandb(step, metrics, prefix)
update_progress_bar(step, metrics)
4. JAX-First Design¤
Leverages JAX for performance and scalability:
- JIT Compilation: Training steps are JIT-compiled
- Automatic Differentiation: Gradients computed with jax.grad
- Device Agnostic: Runs on CPU, GPU, or TPU (see the quick check after this list)
- Parallelization: Ready for data and model parallelism
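A quick check of which backend the jitted training step will run on:
import jax
print(jax.devices())          # e.g. [CudaDevice(id=0), ...] or [CpuDevice(id=0)]
print(jax.default_backend())  # "cpu", "gpu", or "tpu"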
Summary¤
The Workshop training system provides:
- ✅ Type-Safe Configuration: Pydantic-based with validation
- ✅ Flexible Training: Custom loops, optimizers, and schedules
- ✅ Research-Ready Features: Checkpointing, logging, monitoring
- ✅ High Performance: JIT compilation and JAX optimizations
- ✅ Easy to Use: Simple API with sensible defaults
- ✅ Well-Tested: Comprehensive test coverage
Next Steps¤
- Training Guide: Practical guide with examples for common training scenarios
- Configuration Guide: Deep dive into the configuration system and best practices
- Trainer API Reference: Complete API reference for the Trainer class
Continue to the Training Guide for practical examples and advanced patterns.