Advanced Training Pipeline¤
- Level: Intermediate
- Runtime: ~2 minutes
- Format: Dual (.py script | .ipynb notebook)
Production-ready training patterns including optimizer configuration, learning rate scheduling, metrics tracking, and checkpointing strategies.
Files¤
- Python Script: examples/generative_models/advanced_training_example.py
- Jupyter Notebook: examples/generative_models/advanced_training_example.ipynb
Quick Start¤
# Run the Python script
python examples/generative_models/advanced_training_example.py
# Or open the Jupyter notebook
jupyter notebook examples/generative_models/advanced_training_example.ipynb
Overview¤
This example demonstrates how to build a complete, production-ready training pipeline using the Workshop framework. You'll learn essential patterns for training deep learning models including configuration management, optimization strategies, metrics tracking, and model checkpointing.
Learning Objectives¤
After completing this example, you will understand:
- How to implement a complete training pipeline with proper validation
- Optimizer and learning rate scheduler configuration
- Metrics tracking and visualization during training
- Checkpoint management and model persistence
- Best practices for training loop organization
Prerequisites¤
- Basic understanding of neural network training
- Familiarity with JAX and Flax NNX
- Understanding of gradient descent and backpropagation
- Knowledge of learning rate scheduling concepts
Theory and Key Concepts¤
Training Loop Components¤
A production training loop requires several key components working together:
- Data Management: Efficient batching and shuffling strategies
- Optimization: Gradient computation and parameter updates
- Metrics Tracking: Monitor training and validation performance
- Checkpointing: Save model state for recovery and deployment
- Validation: Monitor generalization to unseen data
Learning Rate Scheduling¤
Learning rate schedules improve training stability and convergence by adapting the learning rate during training:
- Warmup: Gradually increase the learning rate from zero to avoid early instability
- Decay: Reduce the learning rate as training progresses to enable fine-grained convergence
- Cosine Annealing: Decrease the learning rate smoothly along a cosine curve
The formula for cosine decay is:

\[
\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\frac{t\pi}{T}\right)\right)
\]

where \(\eta_t\) is the learning rate at step \(t\), \(\eta_{\max}\) and \(\eta_{\min}\) are the peak and final learning rates, and \(T\) is the total number of decay steps.
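In practice you rarely implement the schedule by hand. A warmup-plus-cosine schedule like the one configured later in this example can be built directly with Optax; the sketch below uses the same step counts and peak rate as the example's configuration, though whether the framework uses this exact Optax helper internally is an assumption.

```python
import optax

# Linear warmup to the peak rate, then cosine decay toward zero.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,      # learning rate at step 0
    peak_value=1e-3,     # learning rate reached after warmup
    warmup_steps=100,    # length of the linear warmup phase
    decay_steps=1000,    # total steps, including warmup
    end_value=0.0,       # learning rate at the final step
)

# A schedule is just a function of the step index.
print(schedule(0), schedule(100), schedule(1000))
```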
Optimization Algorithms¤
- Adam (Adaptive Moment Estimation): Combines momentum with per-parameter adaptive learning rates derived from running estimates of the gradient's first and second moments
- SGD with Momentum: Accelerates convergence by accumulating an exponentially weighted average of past gradients
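Both are available as Optax gradient transformations; a minimal sketch with illustrative hyperparameters:

```python
import optax

# Adam: per-parameter step sizes from running estimates of the
# first and second moments of the gradients.
adam = optax.adam(learning_rate=1e-3, b1=0.9, b2=0.999)

# SGD with momentum: accumulates an exponential moving average of
# past gradients to smooth and accelerate updates.
sgd_momentum = optax.sgd(learning_rate=1e-2, momentum=0.9)

# AdamW: Adam with decoupled weight decay, often preferable when
# weight decay is used for regularization.
adamw = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
```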
Code Walkthrough¤
1. Configuration Setup¤
The example uses Pydantic-based configuration objects for all settings:
# Model configuration
model_config = ModelConfiguration(
name="classifier",
model_class="simple_classifier",
input_dim=784,
hidden_dims=[256, 128],
output_dim=10,
dropout_rate=0.1,
)
# Optimizer configuration
optimizer_config = OptimizerConfiguration(
name="training_optimizer",
optimizer_type="adam",
learning_rate=1e-3,
beta1=0.9,
beta2=0.999,
weight_decay=1e-4,
)
# Scheduler configuration
scheduler_config = SchedulerConfiguration(
name="cosine_scheduler",
scheduler_type="cosine",
total_steps=1000,
warmup_steps=100,
)
# Training configuration
training_config = TrainingConfiguration(
name="training",
batch_size=32,
num_epochs=10,
optimizer=optimizer_config,
scheduler=scheduler_config,
checkpoint_dir="./checkpoints/advanced_example",
save_frequency=5,
)
This approach centralizes all hyperparameters, making experiments reproducible and configuration management straightforward.
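Because the configurations are Pydantic models, they can also be serialized next to a run's outputs and reloaded later; a minimal sketch, assuming the classes expose the standard Pydantic v2 API:

```python
from pathlib import Path

# Dump the full training configuration (including the nested optimizer
# and scheduler configs) to JSON next to the checkpoints.
config_path = Path(training_config.checkpoint_dir) / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(training_config.model_dump_json(indent=2))

# Rebuild the exact same configuration when reproducing the run.
restored = TrainingConfiguration.model_validate_json(config_path.read_text())
assert restored.batch_size == training_config.batch_size
```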
2. Data Loading¤
The example implements a simple data loader with shuffling:
def create_data_loader(data, batch_size=32, shuffle=True):
"""Create a simple data loader."""
x, y = data
num_samples = len(x)
indices = jnp.arange(num_samples)
if shuffle:
key = jax.random.key(np.random.randint(0, 10000))
indices = jax.random.permutation(key, indices)
for i in range(0, num_samples, batch_size):
batch_indices = indices[i : i + batch_size]
yield x[batch_indices], y[batch_indices]
In production, you would use more sophisticated data loading strategies like TensorFlow Datasets or PyTorch DataLoader equivalents.
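A quick usage sketch of the loader with synthetic data shaped like the example's 784-dimensional inputs:

```python
import jax
import jax.numpy as jnp

# 1000 synthetic samples with 784 features and 10 classes.
key = jax.random.key(0)
x = jax.random.normal(key, (1000, 784))
y = jax.random.randint(key, (1000,), 0, 10)

for batch_x, batch_y in create_data_loader((x, y), batch_size=32):
    print(batch_x.shape, batch_y.shape)  # (32, 784) (32,)
    break  # inspect only the first batch
```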
3. Model Definition¤
A simple classifier using Flax NNX demonstrates proper module patterns:
class SimpleClassifier(nnx.Module):
def __init__(self, input_dim, hidden_dims, num_classes, *, rngs: nnx.Rngs):
super().__init__() # Always call this
layers = []
prev_dim = input_dim
# Build hidden layers
for hidden_dim in hidden_dims:
layers.append(nnx.Linear(prev_dim, hidden_dim, rngs=rngs))
layers.append(nnx.relu)
layers.append(nnx.Dropout(rate=0.1, rngs=rngs))
prev_dim = hidden_dim
# Output layer
layers.append(nnx.Linear(prev_dim, num_classes, rngs=rngs))
self.net = nnx.Sequential(*layers)
def __call__(self, x, *, training=False):
return self.net(x)
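Instantiating the classifier and running a forward pass might look like this (the dummy batch is illustrative):

```python
import jax.numpy as jnp
from flax import nnx

model = SimpleClassifier(
    input_dim=784,
    hidden_dims=[256, 128],
    num_classes=10,
    rngs=nnx.Rngs(0),  # seeds both parameter init and dropout
)

x = jnp.ones((4, 784))   # dummy batch of 4 flattened 28x28 inputs
logits = model(x)        # shape (4, 10)
print(logits.shape)
```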
4. Training Step¤
The core training step computes loss, gradients, and updates parameters:
def train_step(model, optimizer, batch_x, batch_y, loss_fn):
def compute_loss(model):
logits = model(batch_x, training=True)
# Cross-entropy loss
loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y)
loss = jnp.mean(loss)
# Accuracy
predictions = jnp.argmax(logits, axis=-1)
accuracy = jnp.mean(predictions == batch_y)
return loss, accuracy
(loss, accuracy), grads = nnx.value_and_grad(compute_loss, has_aux=True)(model)
optimizer.update(model, grads)
return loss, accuracy
This pattern uses NNX's value_and_grad for efficient gradient computation with auxiliary outputs (accuracy).
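To avoid re-tracing on every call, the step can also be compiled with nnx.jit; a minimal sketch of a jitted variant (whether the example script jit-compiles its step is an assumption):

```python
import jax.numpy as jnp
import optax
from flax import nnx


@nnx.jit  # compile the step; model and optimizer state are handled by NNX
def jitted_train_step(model, optimizer, batch_x, batch_y):
    def compute_loss(model):
        logits = model(batch_x, training=True)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch_y)
        return loss, accuracy

    (loss, accuracy), grads = nnx.value_and_grad(compute_loss, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss, accuracy
```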
5. Main Training Loop¤
The main loop orchestrates all components:
for epoch in range(training_config.num_epochs):
# Training
train_loss = 0
train_acc = 0
num_train_batches = 0
train_loader = create_data_loader(
train_data, batch_size=training_config.batch_size, shuffle=True
)
for batch_x, batch_y in train_loader:
loss, acc = train_step(model, optimizer, batch_x, batch_y, None)
train_loss += loss
train_acc += acc
num_train_batches += 1
train_loss /= num_train_batches
train_acc /= num_train_batches
# Validation
val_loader = create_data_loader(
val_data, batch_size=training_config.batch_size, shuffle=False
)
val_loss, val_acc = evaluate(model, val_loader)
# Update metrics
metrics.update({
"train_loss": train_loss,
"train_acc": train_acc,
"val_loss": val_loss,
"val_acc": val_acc,
})
# Save checkpoint
if (epoch + 1) % training_config.save_frequency == 0:
save_checkpoint(model, optimizer, epoch + 1, training_config.checkpoint_dir)
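The loop also relies on an evaluate helper that mirrors train_step without gradients or parameter updates; a minimal sketch (the exact implementation lives in the example script):

```python
def evaluate(model, data_loader):
    """Average loss and accuracy over a data loader, with no parameter updates."""
    total_loss, total_acc, num_batches = 0.0, 0.0, 0
    for batch_x, batch_y in data_loader:
        logits = model(batch_x, training=False)  # evaluation forward pass
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()
        acc = jnp.mean(jnp.argmax(logits, axis=-1) == batch_y)
        total_loss += float(loss)
        total_acc += float(acc)
        num_batches += 1
    return total_loss / num_batches, total_acc / num_batches
```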
6. Metrics Tracking¤
The example includes a custom metrics tracker with visualization:
class TrainingMetrics:
def __init__(self):
self.history = {
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
}
def update(self, metrics: dict[str, float]):
for key, value in metrics.items():
if key in self.history:
self.history[key].append(float(value))
def plot(self, save_path=None):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# Plot loss and accuracy curves
# ...
This enables real-time monitoring and post-training analysis.
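The plotting body is elided above; one way the two panels could be drawn with Matplotlib (axis labels and styling are assumptions, and `import matplotlib.pyplot as plt` is assumed):

```python
def plot(self, save_path=None):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: loss curves
    ax1.plot(self.history["train_loss"], label="train")
    ax1.plot(self.history["val_loss"], label="val")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("loss")
    ax1.legend()

    # Right panel: accuracy curves
    ax2.plot(self.history["train_acc"], label="train")
    ax2.plot(self.history["val_acc"], label="val")
    ax2.set_xlabel("epoch")
    ax2.set_ylabel("accuracy")
    ax2.legend()

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
```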
Expected Output¤
When you run the example, you should see:
============================================================
Advanced Training Example
============================================================
1. Setting up configuration...
Model: classifier
Optimizer: adam
Scheduler: cosine
Epochs: 10
Batch size: 32
2. Creating synthetic dataset...
Train samples: 1000
Validation samples: 200
Test samples: 200
3. Creating model...
Model created with 2 hidden layers
4. Setting up optimizer and scheduler...
5. Starting training...
----------------------------------------
Epoch 1/10
Train - Loss: 2.3965, Acc: 0.1025
Val - Loss: 2.3957, Acc: 0.1071
...
Epoch 10/10
Train - Loss: 0.0137, Acc: 1.0000
Val - Loss: 4.3435, Acc: 0.0670
----------------------------------------
6. Evaluating on test set...
Test Loss: 4.1130
Test Accuracy: 0.1205
7. Plotting training curves...
Training curves saved to examples_output/training_curves.png
✅ Advanced training example completed successfully!
The example will also save a visualization of the training curves to examples_output/training_curves.png.
Experiments to Try¤
- Different Optimizers: Compare Adam, SGD with momentum, and AdamW
- Scheduler Variations: Test exponential decay vs cosine annealing
scheduler_config.scheduler_type = "exponential"
scheduler_config.decay_steps = 100
scheduler_config.decay_rate = 0.96
- Architecture Changes: Experiment with different hidden layer configurations
- Regularization: Adjust dropout and weight decay
- Early Stopping: Implement early stopping based on validation loss (see the sketch below)
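A minimal early-stopping sketch keyed on validation loss (the patience value and variable names are illustrative):

```python
best_val_loss = float("inf")
patience, patience_counter = 5, 0

for epoch in range(training_config.num_epochs):
    # ... training and validation exactly as in the main loop ...
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        save_checkpoint(model, optimizer, epoch + 1, training_config.checkpoint_dir)
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break
```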
Troubleshooting¤
High Validation Loss¤
If validation loss is much higher than training loss:
- Reduce model complexity or add regularization
- Increase dropout rate
- Add weight decay to optimizer
- Use more training data
Slow Convergence¤
If training is slow to converge:
- Increase learning rate (carefully)
- Use a learning rate warmup
- Try a different optimizer (e.g., Adam instead of SGD)
- Check gradient magnitudes
Numerical Instability¤
If you encounter NaN or Inf values:
- Reduce learning rate
- Add gradient clipping (see the sketch after this list)
- Use mixed precision training
- Check for exploding/vanishing gradients
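Gradient clipping composes cleanly with Optax by chaining transformations; a minimal sketch with an illustrative clipping threshold:

```python
import optax

# Clip the global gradient norm to 1.0 before the Adam update.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)
```

The chained transformation can then be passed wherever the example constructs its optimizer.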
Next Steps¤
- VAE Training: Learn to train Variational Autoencoders with the ELBO loss
- GAN Training: Master adversarial training with generator and discriminator
- Advanced Optimization: Explore gradient clipping, mixed precision, and distributed training
- Model Deployment: Learn to export and deploy trained models