Training Guide¤
This guide provides practical examples and patterns for training generative models with Workshop. From basic training to advanced techniques, you'll learn how to effectively train models for your specific use case.
Quick Start¤
The simplest way to train a model:
from workshop.generative_models.core.configuration import (
ModelConfiguration,
TrainingConfiguration,
OptimizerConfiguration,
)
from workshop.generative_models.factory import create_model
from workshop.generative_models.training import Trainer
from flax import nnx
import jax.numpy as jnp
# Create model
model_config = ModelConfiguration(
name="simple_vae",
model_class="workshop.generative_models.models.vae.base.VAE",
input_dim=(28, 28, 1),
hidden_dims=[256, 128],
output_dim=32,
)
rngs = nnx.Rngs(42)
model = create_model(config=model_config, rngs=rngs)
# Configure training
optimizer_config = OptimizerConfiguration(
name="adam",
optimizer_type="adam",
learning_rate=1e-3,
)
training_config = TrainingConfiguration(
name="quick_train",
batch_size=128,
num_epochs=10,
optimizer=optimizer_config,
)
# Create trainer
trainer = Trainer(
model=model,
training_config=training_config,
train_data_loader=train_loader,  # a data loader; see "Data Loading" below
)
# Train
for epoch in range(training_config.num_epochs):
metrics = trainer.train_epoch()
print(f"Epoch {epoch + 1}: Loss = {metrics['loss']:.4f}")
Setting Up Training¤
Data Loading¤
Create efficient data loaders for your models:
import numpy as np
import jax
import jax.numpy as jnp
def create_data_loader(data, batch_size=128, shuffle=True):
"""Create a data-loader factory that yields batches from a dict (or pytree) of arrays."""
def data_loader(batch_size=batch_size):
num_samples = len(jax.tree_util.tree_leaves(data)[0])
num_batches = num_samples // batch_size
# Shuffle if requested
if shuffle:
indices = np.random.permutation(num_samples)
data_shuffled = jax.tree_util.tree_map(lambda x: x[indices], data)
else:
data_shuffled = data
# Yield batches
for i in range(num_batches):
batch_start = i * batch_size
batch_end = min(batch_start + batch_size, num_samples)
batch = jax.tree_util.tree_map(
lambda x: x[batch_start:batch_end],
data_shuffled
)
yield batch
return data_loader
# Example usage with MNIST
import tensorflow_datasets as tfds
# Load MNIST
ds_train = tfds.load('mnist', split='train', as_supervised=True)
ds_val = tfds.load('mnist', split='test', as_supervised=True)
# Convert to numpy arrays
train_pairs = list(tfds.as_numpy(ds_train))
val_pairs = list(tfds.as_numpy(ds_val))
train_images = np.array([img for img, _ in train_pairs])
train_labels = np.array([label for _, label in train_pairs])
val_images = np.array([img for img, _ in val_pairs])
val_labels = np.array([label for _, label in val_pairs])
# Normalize to [0, 1]
train_images = train_images.astype(np.float32) / 255.0
val_images = val_images.astype(np.float32) / 255.0
# Create data dictionaries
train_data = {"images": train_images, "labels": train_labels}
val_data = {"images": val_images, "labels": val_labels}
# Create data loaders
train_loader = create_data_loader(train_data, batch_size=128, shuffle=True)
val_loader = create_data_loader(val_data, batch_size=128, shuffle=False)
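A quick sanity check before training: pull one batch from the loader and inspect its shapes:
first_batch = next(train_loader(batch_size=128))
print({k: v.shape for k, v in first_batch.items()})
# e.g. {'images': (128, 28, 28, 1), 'labels': (128,)}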
Preprocessing¤
Apply preprocessing to your data:
def preprocess_images(images):
"""Preprocess images for training."""
# Normalize to [-1, 1]
images = (images - 0.5) * 2.0
# Add channel dimension if needed
if images.ndim == 3:
images = images[..., None]
return images
def dequantize(images, rng):
"""Add uniform noise to discrete images."""
noise = jax.random.uniform(rng, images.shape, minval=0.0, maxval=1/256.0)
return images + noise
# Apply preprocessing
train_images = preprocess_images(train_images)
val_images = preprocess_images(val_images)
# Apply dequantization during training
def train_step_with_dequantization(state, batch, rng):
"""Training step with dequantization."""
rng, dequant_rng = jax.random.split(rng)
# Dequantize images
images = dequantize(batch["images"], dequant_rng)
batch = {**batch, "images": images}
# Regular training step
return train_step(state, batch, rng)
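When visualizing reconstructions or samples later, remember to undo the [-1, 1] scaling applied above; a small helper:
import numpy as np
def postprocess_images(images):
    """Map images from [-1, 1] back to [0, 1] for visualization."""
    return np.clip(images / 2.0 + 0.5, 0.0, 1.0)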
Model Initialization¤
Properly initialize your models:
from flax import nnx
from workshop.generative_models.factory import create_model
def initialize_model(model_config, seed=0):
"""Initialize a model with proper RNG handling."""
rngs = nnx.Rngs(seed)
# Create model
model = create_model(config=model_config, rngs=rngs)
# Verify model is initialized
dummy_input = jnp.ones((1, *model_config.input_dim))
try:
output = model(dummy_input, rngs=rngs, training=False)
output_shapes = jax.tree_util.tree_map(lambda x: getattr(x, "shape", x), output)
print(f"Model initialized successfully. Output shapes: {output_shapes}")
except Exception as e:
print(f"Model initialization failed: {e}")
raise
return model
# Initialize model
model = initialize_model(model_config, seed=42)
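It is also useful to confirm the model has the expected number of trainable parameters; a small sketch that sums over the model's nnx.Param state:
import jax
from flax import nnx
def count_parameters(model):
    """Count trainable parameters in the model's nnx.Param state."""
    params = nnx.state(model, nnx.Param)
    return sum(int(x.size) for x in jax.tree_util.tree_leaves(params))
print(f"Trainable parameters: {count_parameters(model):,}")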
Custom Training Loops¤
Basic Custom Loop¤
Create a custom training loop for full control:
import jax
import jax.numpy as jnp
import optax
from flax import nnx
def custom_training_loop(
model,
train_loader,
val_loader,
num_epochs,
learning_rate=1e-3,
):
"""Custom training loop with full control."""
# Create optimizer
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
# Training state
rng = jax.random.PRNGKey(0)
step = 0
# Define training step
@nnx.jit
def train_step(model, opt_state, batch, rng):
def loss_fn(model):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
return outputs["loss"], outputs
# Compute gradients
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, outputs), grads = grad_fn(model)
# Update parameters (optax operates on the Param state)
params = nnx.state(model, nnx.Param)
updates, opt_state = optimizer.update(grads, opt_state, params)
nnx.update(model, optax.apply_updates(params, updates))
return model, opt_state, loss, outputs
# Training loop
for epoch in range(num_epochs):
epoch_losses = []
# Train epoch
for batch in train_loader(batch_size=128):
rng, step_rng = jax.random.split(rng)
model, opt_state, loss, outputs = train_step(
model, opt_state, batch, step_rng
)
epoch_losses.append(float(loss))
step += 1
if step % 100 == 0:
print(f"Step {step}: Loss = {loss:.4f}")
# Validation
val_losses = []
for batch in val_loader(batch_size=128):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=False)
val_losses.append(float(outputs["loss"]))
print(f"Epoch {epoch + 1}:")
print(f" Train Loss: {np.mean(epoch_losses):.4f}")
print(f" Val Loss: {np.mean(val_losses):.4f}")
return model
# Train with custom loop
model = custom_training_loop(
model=model,
train_loader=train_loader,
val_loader=val_loader,
num_epochs=10,
learning_rate=1e-3,
)
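As an alternative to threading optax state by hand, Flax NNX bundles the model and optimizer state in nnx.Optimizer. The sketch below shows the same step in that style; the exact Optimizer signature varies between Flax versions, so treat this as a pattern rather than a drop-in:
import optax
from flax import nnx
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # newer Flax versions may also require wrt=nnx.Param
@nnx.jit
def train_step(optimizer, batch, rng):
    def loss_fn(model):
        outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
        return outputs["loss"], outputs
    (loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(optimizer.model)
    optimizer.update(grads)  # applies the update to the model in place
    return loss, outputs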
Advanced Custom Loop with Metrics¤
Track detailed metrics during training:
from collections import defaultdict
def advanced_training_loop(
model,
train_loader,
val_loader,
num_epochs,
optimizer_config,
scheduler_config=None,
):
"""Advanced training loop with metrics tracking."""
# Create optimizer
base_lr = optimizer_config.learning_rate
if scheduler_config:
schedule = create_schedule(scheduler_config, base_lr)
optimizer = optax.adam(learning_rate=schedule)
else:
optimizer = optax.adam(learning_rate=base_lr)
# Apply gradient clipping if configured
if optimizer_config.gradient_clip_norm:
optimizer = optax.chain(
optax.clip_by_global_norm(optimizer_config.gradient_clip_norm),
optimizer,
)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
# Metrics tracking
history = defaultdict(list)
rng = jax.random.PRNGKey(0)
step = 0
@nnx.jit
def train_step(model, opt_state, batch, rng):
def loss_fn(model):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
return outputs["loss"], outputs
(loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
# Compute gradient norm
grad_norm = optax.global_norm(grads)
# Update (optax operates on the Param state)
params = nnx.state(model, nnx.Param)
updates, opt_state = optimizer.update(grads, opt_state, params)
nnx.update(model, optax.apply_updates(params, updates))
# Add gradient norm to metrics
metrics = {**outputs, "grad_norm": grad_norm}
return model, opt_state, loss, metrics
# Training loop
for epoch in range(num_epochs):
# Train epoch
for batch in train_loader(batch_size=128):
rng, step_rng = jax.random.split(rng)
model, opt_state, loss, metrics = train_step(
model, opt_state, batch, step_rng
)
# Track metrics
for key, value in metrics.items():
history[f"train_{key}"].append(float(value))
step += 1
# Periodic logging
if step % 100 == 0:
recent_loss = np.mean(history["train_loss"][-100:])
recent_grad_norm = np.mean(history["train_grad_norm"][-100:])
print(f"Step {step}:")
print(f" Loss: {recent_loss:.4f}")
print(f" Grad Norm: {recent_grad_norm:.4f}")
# Validation
val_metrics = defaultdict(list)
for batch in val_loader(batch_size=128):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=False)
for key, value in outputs.items():
val_metrics[key].append(float(value))
# Log validation metrics
print(f"\nEpoch {epoch + 1}:")
for key, values in val_metrics.items():
mean_value = np.mean(values)
history[f"val_{key}"].append(mean_value)
print(f" Val {key}: {mean_value:.4f}")
return model, history
# Train with advanced loop
model, history = advanced_training_loop(
model=model,
train_loader=train_loader,
val_loader=val_loader,
num_epochs=10,
optimizer_config=optimizer_config,
scheduler_config=scheduler_config,  # e.g. the warmup_cosine SchedulerConfiguration from "Learning Rate Schedules" below
)
# Plot training curves
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history["train_loss"], label="Train")
steps_per_epoch = len(history["train_loss"]) // max(len(history["val_loss"]), 1)
plt.plot(np.arange(1, len(history["val_loss"]) + 1) * steps_per_epoch, history["val_loss"], label="Val")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Loss")
plt.subplot(1, 2, 2)
plt.plot(history["train_grad_norm"])
plt.xlabel("Step")
plt.ylabel("Gradient Norm")
plt.title("Gradient Norm")
plt.tight_layout()
plt.show()
Learning Rate Schedules¤
Warmup Schedule¤
Gradually increase learning rate at the start:
from workshop.generative_models.core.configuration import SchedulerConfiguration
# Cosine schedule with warmup (recommended)
warmup_cosine = SchedulerConfiguration(
name="warmup_cosine",
scheduler_type="cosine",
warmup_steps=1000, # 1000 steps of warmup
cycle_length=50000, # Cosine cycle length
min_lr_ratio=0.1, # End at 10% of peak LR
)
training_config = TrainingConfiguration(
name="warmup_training",
batch_size=128,
num_epochs=100,
optimizer=optimizer_config,
scheduler=warmup_cosine,
)
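If you prefer to build the optax schedule directly instead of going through the configuration system, optax provides an equivalent warmup-plus-cosine schedule; the values below mirror the configuration above:
import optax
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,      # start at zero during warmup
    peak_value=1e-3,     # peak learning rate after warmup
    warmup_steps=1000,
    decay_steps=50000,   # total length of the cosine cycle
    end_value=1e-4,      # about 10% of the peak, matching min_lr_ratio=0.1
)
optimizer = optax.adam(learning_rate=schedule)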
Custom Schedules¤
Create custom learning rate schedules:
import optax
def create_custom_schedule(
base_lr,
warmup_steps,
hold_steps,
decay_steps,
end_lr_ratio=0.1,
):
"""Create a custom learning rate schedule.
Schedule: warmup → hold → decay
"""
schedules = [
# Warmup
optax.linear_schedule(
init_value=0.0,
end_value=base_lr,
transition_steps=warmup_steps,
),
# Hold
optax.constant_schedule(base_lr),
# Decay
optax.cosine_decay_schedule(
init_value=base_lr,
decay_steps=decay_steps,
alpha=end_lr_ratio,
),
]
boundaries = [warmup_steps, warmup_steps + hold_steps]
return optax.join_schedules(schedules, boundaries)
# Use custom schedule
custom_schedule = create_custom_schedule(
base_lr=1e-3,
warmup_steps=1000,
hold_steps=5000,
decay_steps=44000,
end_lr_ratio=0.1,
)
optimizer = optax.adam(learning_rate=custom_schedule)
One-Cycle Schedule¤
Implement one-cycle learning rate policy:
def create_one_cycle_schedule(
max_lr,
total_steps,
pct_start=0.3,
div_factor=25.0,
final_div_factor=1e4,
):
"""Create a one-cycle learning rate schedule.
Args:
max_lr: Maximum learning rate
total_steps: Total training steps
pct_start: Percentage of cycle spent increasing LR
div_factor: Initial LR = max_lr / div_factor
final_div_factor: Final LR = max_lr / final_div_factor
"""
initial_lr = max_lr / div_factor
final_lr = max_lr / final_div_factor
step_up = int(total_steps * pct_start)
step_down = total_steps - step_up
schedules = [
# Increase phase
optax.linear_schedule(
init_value=initial_lr,
end_value=max_lr,
transition_steps=step_up,
),
# Decrease phase
optax.cosine_decay_schedule(
init_value=max_lr,
decay_steps=step_down,
alpha=final_lr / max_lr,
),
]
return optax.join_schedules(schedules, [step_up])
# Use one-cycle schedule
one_cycle_schedule = create_one_cycle_schedule(
max_lr=1e-3,
total_steps=50000,
pct_start=0.3,
)
optimizer = optax.adam(learning_rate=one_cycle_schedule)
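Schedules are plain functions of the step count, so it is worth plotting one before committing to a long run:
import numpy as np
import matplotlib.pyplot as plt
steps = np.arange(0, 50000, 100)
lrs = [float(one_cycle_schedule(step)) for step in steps]
plt.plot(steps, lrs)
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.title("One-cycle schedule")
plt.show()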
Gradient Accumulation¤
Accumulate gradients to simulate larger batch sizes:
def training_with_gradient_accumulation(
model,
train_loader,
num_epochs,
accumulation_steps=4,
learning_rate=1e-3,
):
"""Training with gradient accumulation."""
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
rng = jax.random.PRNGKey(0)
@nnx.jit
def compute_gradients(model, batch, rng):
"""Compute gradients for a batch."""
def loss_fn(model):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
return outputs["loss"], outputs
(loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
return grads, loss, outputs
@nnx.jit
def apply_accumulated_gradients(model, opt_state, accumulated_grads):
"""Apply accumulated gradients."""
# Average gradients
averaged_grads = jax.tree_util.tree_map(
lambda g: g / accumulation_steps,
accumulated_grads
)
# Update model (apply the averaged gradients to the Param state)
params = nnx.state(model, nnx.Param)
updates, opt_state = optimizer.update(averaged_grads, opt_state, params)
nnx.update(model, optax.apply_updates(params, updates))
return model, opt_state
# Training loop
for epoch in range(num_epochs):
accumulated_grads = None
step = 0
for batch in train_loader(batch_size=32): # Smaller batch size
rng, step_rng = jax.random.split(rng)
# Compute gradients
grads, loss, outputs = compute_gradients(model, batch, step_rng)
# Accumulate gradients
if accumulated_grads is None:
accumulated_grads = grads
else:
accumulated_grads = jax.tree_util.tree_map(
lambda acc, g: acc + g,
accumulated_grads,
grads
)
step += 1
# Apply accumulated gradients
if step % accumulation_steps == 0:
model, opt_state = apply_accumulated_gradients(
model, opt_state, accumulated_grads
)
accumulated_grads = None
if step % 100 == 0:
print(f"Step {step // accumulation_steps}: Loss = {loss:.4f}")
return model
# Train with gradient accumulation
model = training_with_gradient_accumulation(
model=model,
train_loader=train_loader,
num_epochs=10,
accumulation_steps=4, # Effective batch size = 32 * 4 = 128
)
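optax can also handle accumulation for you: optax.MultiSteps wraps an optimizer so that a real update is only applied every k calls, with gradients accumulated in between. A sketch:
import optax
# Accumulate over 4 micro-batches before each real update
optimizer = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
# optimizer.update(...) returns zero updates until the 4th call, when the accumulated gradients are applied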
Early Stopping¤
Implement early stopping to prevent overfitting:
class EarlyStopping:
"""Early stopping handler."""
def __init__(self, patience=10, min_delta=0.0, mode="min"):
"""Initialize early stopping.
Args:
patience: Number of epochs to wait before stopping
min_delta: Minimum change to qualify as improvement
mode: 'min' or 'max' for loss or accuracy
"""
self.patience = patience
self.min_delta = min_delta
self.mode = mode
self.counter = 0
self.best_value = float('inf') if mode == 'min' else float('-inf')
self.should_stop = False
def __call__(self, current_value):
"""Check if training should stop."""
if self.mode == 'min':
improved = current_value < (self.best_value - self.min_delta)
else:
improved = current_value > (self.best_value + self.min_delta)
if improved:
self.best_value = current_value
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
self.should_stop = True
return self.should_stop
def training_with_early_stopping(
model,
train_loader,
val_loader,
max_epochs,
patience=10,
):
"""Training with early stopping."""
early_stopping = EarlyStopping(patience=patience, mode='min')
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(model))
rng = jax.random.PRNGKey(0)
best_model_state = None
best_val_loss = float('inf')
for epoch in range(max_epochs):
# Train epoch
for batch in train_loader(batch_size=128):
rng, step_rng = jax.random.split(rng)
# Training step (simplified)
model, opt_state, loss = train_step(model, opt_state, batch, step_rng)
# Validation
val_losses = []
for batch in val_loader(batch_size=128):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=False)
val_losses.append(float(outputs["loss"]))
val_loss = np.mean(val_losses)
print(f"Epoch {epoch + 1}: Val Loss = {val_loss:.4f}")
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model_state = nnx.state(model)
# Check early stopping
if early_stopping(val_loss):
print(f"Early stopping at epoch {epoch + 1}")
break
# Restore best model parameters
if best_model_state is not None:
nnx.update(model, best_model_state)
return model
# Train with early stopping
model = training_with_early_stopping(
model=model,
train_loader=train_loader,
val_loader=val_loader,
max_epochs=100,
patience=10,
)
Mixed Precision Training¤
Use mixed precision for faster training:
def mixed_precision_training(model, train_loader, num_epochs):
"""Training with mixed precision (bfloat16)."""
# Convert model to bfloat16
def convert_to_bfloat16(x):
if isinstance(x, jnp.ndarray) and x.dtype == jnp.float32:
return x.astype(jnp.bfloat16)
return x
params = nnx.state(model, nnx.Param)
nnx.update(model, jax.tree_util.tree_map(convert_to_bfloat16, params))
# Use mixed precision optimizer
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.scale_by_adam(),
optax.scale(-1e-3), # Learning rate
)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
rng = jax.random.PRNGKey(0)
@nnx.jit
def train_step(model, opt_state, batch, rng):
# Convert batch to bfloat16
batch = jax.tree_util.tree_map(convert_to_bfloat16, batch)
def loss_fn(model):
outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
# Keep loss in float32 for numerical stability
return outputs["loss"].astype(jnp.float32), outputs
(loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
# Update (gradients are bfloat16; apply them to the Param state)
params = nnx.state(model, nnx.Param)
updates, opt_state = optimizer.update(grads, opt_state, params)
nnx.update(model, optax.apply_updates(params, updates))
return model, opt_state, loss
# Training loop
for epoch in range(num_epochs):
for batch in train_loader(batch_size=128):
rng, step_rng = jax.random.split(rng)
model, opt_state, loss = train_step(model, opt_state, batch, step_rng)
print(f"Epoch {epoch + 1}: Loss = {loss:.4f}")
return model
# Train with mixed precision
model = mixed_precision_training(
model=model,
train_loader=train_loader,
num_epochs=10,
)
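Lower-precision arithmetic occasionally produces non-finite gradients. Rather than letting a single bad step corrupt the parameters, you can wrap the optimizer in optax.apply_if_finite, which skips updates containing NaN or Inf:
import optax
optimizer = optax.apply_if_finite(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(1e-3),
    ),
    max_consecutive_errors=5,  # abort after 5 non-finite updates in a row
)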
Model Checkpointing¤
Save and Load Checkpoints¤
import pickle
from pathlib import Path
def save_checkpoint(model, opt_state, step, path):
"""Save training checkpoint."""
checkpoint = {
"model_state": nnx.state(model),
"opt_state": opt_state,
"step": step,
}
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "wb") as f:
pickle.dump(checkpoint, f)
print(f"Checkpoint saved to {path}")
def load_checkpoint(model, path):
"""Load training checkpoint."""
with open(path, "rb") as f:
checkpoint = pickle.load(f)
# Restore model state in place
nnx.update(model, checkpoint["model_state"])
return model, checkpoint["opt_state"], checkpoint["step"]
# Save checkpoint during training
if step % 1000 == 0:
save_checkpoint(
model=model,
opt_state=opt_state,
step=step,
path=f"./checkpoints/step_{step}.pkl"
)
# Load checkpoint
model, opt_state, step = load_checkpoint(
model=model,
path="./checkpoints/step_5000.pkl"
)
print(f"Resumed from step {step}")
Best Model Checkpointing¤
Save the best model based on validation metrics:
class BestModelCheckpoint:
"""Save best model based on validation metric."""
def __init__(self, save_path, mode='min'):
"""Initialize best model checkpoint.
Args:
save_path: Path to save best model
mode: 'min' for loss, 'max' for accuracy
"""
self.save_path = Path(save_path)
self.mode = mode
self.best_value = float('inf') if mode == 'min' else float('-inf')
def __call__(self, model, opt_state, step, current_value):
"""Check and save if best model."""
improved = False
if self.mode == 'min':
improved = current_value < self.best_value
else:
improved = current_value > self.best_value
if improved:
self.best_value = current_value
save_checkpoint(model, opt_state, step, self.save_path)
print(f"New best model! Value: {current_value:.4f}")
return improved
# Use during training
best_checkpoint = BestModelCheckpoint(
save_path="./checkpoints/best_model.pkl",
mode='min'
)
for epoch in range(num_epochs):
# Train epoch
train_metrics = train_epoch(model)
# Validate
val_metrics = validate(model, val_loader)
# Save if best
best_checkpoint(
model=model,
opt_state=opt_state,
step=epoch,
current_value=val_metrics['loss']
)
Logging and Monitoring¤
Weights & Biases Integration¤
import wandb
def train_with_wandb(
model,
train_loader,
val_loader,
training_config,
project_name="generative-models",
):
"""Training with W&B logging."""
# Initialize wandb
wandb.init(
project=project_name,
config={
"learning_rate": training_config.optimizer.learning_rate,
"batch_size": training_config.batch_size,
"num_epochs": training_config.num_epochs,
"optimizer": training_config.optimizer.optimizer_type,
}
)
optimizer = optax.adam(training_config.optimizer.learning_rate)
opt_state = optimizer.init(nnx.state(model))
rng = jax.random.PRNGKey(0)
step = 0
for epoch in range(training_config.num_epochs):
# Train epoch
for batch in train_loader(batch_size=training_config.batch_size):
rng, step_rng = jax.random.split(rng)
model, opt_state, loss, metrics = train_step(
model, opt_state, batch, step_rng
)
# Log to wandb
wandb.log({
"train/loss": float(loss),
"train/step": step,
**{f"train/{k}": float(v) for k, v in metrics.items()}
})
step += 1
# Validation
val_metrics = validate(model, val_loader)
wandb.log({
"val/loss": val_metrics['loss'],
"epoch": epoch,
})
print(f"Epoch {epoch + 1}: Val Loss = {val_metrics['loss']:.4f}")
wandb.finish()
return model
# Train with wandb
model = train_with_wandb(
model=model,
train_loader=train_loader,
val_loader=val_loader,
training_config=training_config,
project_name="vae-experiments",
)
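W&B can also log generated samples alongside scalar metrics, which is often more telling than the loss curve alone; a sketch assuming the model's generate method used elsewhere in this guide:
import numpy as np
import wandb
samples = model.generate(num_samples=16, rngs=nnx.Rngs(0))
wandb.log({"samples": [wandb.Image(np.asarray(img)) for img in samples]})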
TensorBoard Integration¤
from torch.utils.tensorboard import SummaryWriter
def train_with_tensorboard(
model,
train_loader,
val_loader,
num_epochs,
log_dir="./logs",
):
"""Training with TensorBoard logging."""
writer = SummaryWriter(log_dir=log_dir)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(model))
rng = jax.random.PRNGKey(0)
step = 0
for epoch in range(num_epochs):
# Train epoch
for batch in train_loader(batch_size=128):
rng, step_rng = jax.random.split(rng)
model, opt_state, loss, metrics = train_step(
model, opt_state, batch, step_rng
)
# Log to tensorboard
writer.add_scalar("Loss/train", float(loss), step)
for key, value in metrics.items():
writer.add_scalar(f"Metrics/{key}", float(value), step)
step += 1
# Validation
val_metrics = validate(model, val_loader)
writer.add_scalar("Loss/val", val_metrics['loss'], epoch)
# Log images
if epoch % 5 == 0:
samples = model.generate(num_samples=16, rngs=nnx.Rngs(0))
writer.add_images("Samples", np.asarray(samples), epoch, dataformats="NHWC")
writer.close()
return model
# Train with tensorboard
model = train_with_tensorboard(
model=model,
train_loader=train_loader,
val_loader=val_loader,
num_epochs=10,
log_dir="./tensorboard_logs",
)
Common Training Patterns¤
Progressive Training¤
Train with progressively increasing complexity:
def progressive_training(model, train_loader, stages):
"""Train with progressive stages.
Args:
model: Model to train
train_loader: Data loader
stages: List of (num_epochs, learning_rate, batch_size) tuples
"""
optimizer_state = None
for stage_idx, (num_epochs, learning_rate, batch_size) in enumerate(stages):
print(f"\nStage {stage_idx + 1}: LR={learning_rate}, BS={batch_size}")
# Create optimizer for this stage
optimizer = optax.adam(learning_rate)
# Initialize or reuse optimizer state
if optimizer_state is None:
optimizer_state = optimizer.init(nnx.state(model))
# Train for this stage
rng = jax.random.PRNGKey(stage_idx)
for epoch in range(num_epochs):
for batch in train_loader(batch_size=batch_size):
rng, step_rng = jax.random.split(rng)
model, optimizer_state, loss = train_step(
model, optimizer_state, batch, step_rng
)
print(f" Epoch {epoch + 1}: Loss = {loss:.4f}")
return model
# Define progressive stages
stages = [
(10, 1e-3, 32), # Stage 1: High LR, small batch
(20, 5e-4, 64), # Stage 2: Medium LR, medium batch
(30, 1e-4, 128), # Stage 3: Low LR, large batch
]
model = progressive_training(model, train_loader, stages)
Curriculum Learning¤
Train with increasing data difficulty:
def curriculum_learning(model, data_loader_fn, difficulty_schedule):
"""Train with curriculum learning.
Args:
model: Model to train
data_loader_fn: Function that returns data loader for difficulty level
difficulty_schedule: List of (difficulty_level, num_epochs) tuples
"""
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(model))
rng = jax.random.PRNGKey(0)
for difficulty, num_epochs in difficulty_schedule:
print(f"\nTraining on difficulty level: {difficulty}")
# Get data loader for this difficulty
train_loader = data_loader_fn(difficulty)
# Train
for epoch in range(num_epochs):
for batch in train_loader(batch_size=128):
rng, step_rng = jax.random.split(rng)
model, opt_state, loss = train_step(
model, opt_state, batch, step_rng
)
print(f" Epoch {epoch + 1}: Loss = {loss:.4f}")
return model
# Define curriculum
difficulty_schedule = [
("easy", 10), # Train on easy examples first
("medium", 20), # Then medium difficulty
("hard", 30), # Finally hard examples
("all", 40), # Train on all data
]
model = curriculum_learning(model, data_loader_fn, difficulty_schedule)
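The data_loader_fn above is left abstract. A hypothetical implementation, assuming a precomputed per-example difficulty_scores array in [0, 1] and the create_data_loader helper from the Data Loading section:
def make_data_loader_fn(data, difficulty_scores):
    """Return a data_loader_fn(difficulty) that filters examples by difficulty bucket."""
    buckets = {"easy": (0.0, 0.33), "medium": (0.33, 0.66), "hard": (0.66, 1.0)}
    def data_loader_fn(difficulty):
        if difficulty == "all":
            subset = data
        else:
            low, high = buckets[difficulty]
            mask = (difficulty_scores >= low) & (difficulty_scores <= high)
            subset = {key: value[mask] for key, value in data.items()}
        return create_data_loader(subset, batch_size=128, shuffle=True)
    return data_loader_fn
data_loader_fn = make_data_loader_fn(train_data, difficulty_scores)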
Multi-Task Training¤
Train on multiple tasks simultaneously:
def multi_task_training(
model,
task_loaders,
task_weights,
num_epochs,
):
"""Train on multiple tasks.
Args:
model: Model to train
task_loaders: Dict of task_name -> data_loader
task_weights: Dict of task_name -> weight
num_epochs: Number of epochs
"""
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
rng = jax.random.PRNGKey(0)
@nnx.jit
def multi_task_step(model, opt_state, batches, rng):
"""Training step with multiple tasks."""
def loss_fn(model):
total_loss = 0.0
metrics = {}
for task_name, batch in batches.items():
# Task-specific forward pass
outputs = model(
batch,
task=task_name,
rngs=nnx.Rngs(rng),
training=True
)
# Weighted loss
task_loss = outputs["loss"] * task_weights[task_name]
total_loss += task_loss
# Track metrics
metrics[f"{task_name}_loss"] = outputs["loss"]
return total_loss, metrics
(loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
params = nnx.state(model, nnx.Param)
updates, opt_state = optimizer.update(grads, opt_state, params)
nnx.update(model, optax.apply_updates(params, updates))
return model, opt_state, loss, metrics
# Training loop
for epoch in range(num_epochs):
# Get batches from all tasks
task_iters = {
name: loader(batch_size=32)
for name, loader in task_loaders.items()
}
for step in range(1000): # Fixed steps per epoch
# Get batch from each task
batches = {
name: next(task_iter)
for name, task_iter in task_iters.items()
}
rng, step_rng = jax.random.split(rng)
model, opt_state, loss, metrics = multi_task_step(
model, opt_state, batches, step_rng
)
if step % 100 == 0:
print(f"Step {step}: Total Loss = {loss:.4f}")
for task_name, task_loss in metrics.items():
print(f" {task_name}: {task_loss:.4f}")
return model
# Train on multiple tasks
task_loaders = {
"reconstruction": reconstruction_loader,
"generation": generation_loader,
"classification": classification_loader,
}
task_weights = {
"reconstruction": 1.0,
"generation": 0.5,
"classification": 0.3,
}
model = multi_task_training(
model=model,
task_loaders=task_loaders,
task_weights=task_weights,
num_epochs=50,
)
Troubleshooting¤
NaN Loss¤
If you encounter NaN loss:
# 1. Add gradient clipping
optimizer_config = OptimizerConfiguration(
name="clipped_adam",
optimizer_type="adam",
learning_rate=1e-3,
gradient_clip_norm=1.0, # Clip gradients
)
# 2. Lower learning rate
optimizer_config = OptimizerConfiguration(
name="lower_lr",
optimizer_type="adam",
learning_rate=1e-4, # Lower LR
)
# 3. Check for numerical instability
def check_for_nans(metrics, step):
"""Check for NaNs in metrics."""
for key, value in metrics.items():
if np.isnan(value):
print(f"NaN detected at step {step} in {key}")
# Save checkpoint before crash
save_checkpoint(model, opt_state, step, "./emergency_checkpoint.pkl")
raise ValueError(f"NaN in {key}")
# 4. Use mixed precision with care
# Avoid bfloat16 for loss computation
loss = loss.astype(jnp.float32) # Keep loss in float32
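JAX can also locate the first operation that produces a NaN by re-running it eagerly; enable this only while debugging, since it slows training considerably:
import jax
# 5. Enable NaN debugging (debug-only)
jax.config.update("jax_debug_nans", True)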
Slow Training¤
If training is slow:
# 1. Use JIT compilation
@nnx.jit
def train_step(model, opt_state, batch, rng):
# Training step logic
pass
# 2. Profile your code
with jax.profiler.trace("./tensorboard_logs"):
for _ in range(100):
model, opt_state, loss = train_step(model, opt_state, batch, rng)
# 3. Increase batch size (if memory allows)
training_config = TrainingConfiguration(
name="large_batch",
batch_size=256, # Larger batch size
num_epochs=50, # Fewer epochs needed
optimizer=optimizer_config,
)
# 4. Use data prefetching
from concurrent.futures import ThreadPoolExecutor
def prefetch_data_loader(data_loader, batch_size=128, prefetch_size=2):
"""Prefetch batches in a background thread."""
iterator = iter(data_loader(batch_size=batch_size))
with ThreadPoolExecutor(max_workers=1) as executor:
# Use a None sentinel so loader exhaustion is handled without raising StopIteration
futures = [executor.submit(next, iterator, None) for _ in range(prefetch_size)]
while futures:
# Get the next prefetched batch
batch = futures.pop(0).result()
if batch is None:
break
# Submit the next prefetch before yielding
futures.append(executor.submit(next, iterator, None))
yield batch
Memory Issues¤
If you run out of memory:
# 1. Reduce batch size
training_config = TrainingConfiguration(
name="small_batch",
batch_size=32, # Smaller batch
num_epochs=200, # More epochs
optimizer=optimizer_config,
)
# 2. Use gradient accumulation
# See "Gradient Accumulation" section above
# 3. Clear cache periodically
import jax
# Clear compilation cache
jax.clear_caches()
# 4. Use checkpointing for large models
from flax import nnx
@nnx.remat
def expensive_forward_pass(model, x):
"""Forward pass with gradient checkpointing (rematerialization)."""
return model(x)
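On GPU, XLA preallocates most of the device memory by default. These environment variables, set before importing JAX, give finer control over allocation:
import os
# 5. Tune XLA GPU memory behaviour (set before importing jax)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand instead of preallocating
# or: os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"  # preallocate only 80% of GPU memory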
Best Practices¤
DO¤
- ✅ Use type-safe configuration with validation
- ✅ JIT-compile training steps for performance
- ✅ Save checkpoints regularly
- ✅ Monitor training metrics (loss, gradients)
- ✅ Use gradient clipping for stability
- ✅ Start with a small learning rate and increase it gradually
- ✅ Validate periodically during training
- ✅ Save best model based on validation metrics
- ✅ Use warmup for learning rate schedules
- ✅ Profile code to find bottlenecks
DON'T¤
- ❌ Skip validation - always validate your model
- ❌ Use too high a learning rate initially
- ❌ Forget to shuffle training data
- ❌ Ignore NaN or infinite losses
- ❌ Train without gradient clipping
- ❌ Overwrite checkpoints without backup
- ❌ Use mixed precision for all operations
- ❌ Forget to split RNG keys properly
- ❌ Mutate training state in-place
- ❌ Skip warmup for large learning rates
Summary¤
This guide covered:
- Basic Training: Quick start and setup
- Custom Loops: Full control over training
- Learning Rate Schedules: Warmup, cosine, one-cycle
- Advanced Techniques: Gradient accumulation, early stopping, mixed precision
- Checkpointing: Save and load model state
- Logging: W&B, TensorBoard integration
- Common Patterns: Progressive training, curriculum learning, multi-task
- Troubleshooting: NaN loss, slow training, memory issues
Next Steps¤
- Deep dive into the configuration system and best practices
- Architecture and core concepts of the training system
- Complete API reference for the Trainer class
See the Configuration Guide for detailed configuration options and patterns.