Skip to content

Training Guide

This guide provides practical examples and patterns for training generative models with Workshop. From basic training to advanced techniques, you'll learn how to effectively train models for your specific use case.

Quick Start

The simplest way to train a model (this assumes a `train_loader` such as the one built in the Data Loading section below):

from workshop.generative_models.core.configuration import (
    ModelConfiguration,
    TrainingConfiguration,
    OptimizerConfiguration,
)
from workshop.generative_models.factory import create_model
from workshop.generative_models.training import Trainer
from flax import nnx
import jax.numpy as jnp

# Create model from a declarative configuration; `model_class` is the import
# path of the class to build (presumably resolved by create_model).
model_config = ModelConfiguration(
    name="simple_vae",
    model_class="workshop.generative_models.models.vae.base.VAE",
    input_dim=(28, 28, 1),
    hidden_dims=[256, 128],
    output_dim=32,
)

# Seeded RNG container used when creating the model.
rngs = nnx.Rngs(42)
model = create_model(config=model_config, rngs=rngs)

# Configure training
optimizer_config = OptimizerConfiguration(
    name="adam",
    optimizer_type="adam",
    learning_rate=1e-3,
)

training_config = TrainingConfiguration(
    name="quick_train",
    batch_size=128,
    num_epochs=10,
    optimizer=optimizer_config,
)

# Create trainer
# NOTE: `train_loader` is not defined in this snippet; build one first
# (see "Data Loading" below).
trainer = Trainer(
    model=model,
    training_config=training_config,
    train_data_loader=train_loader,
)

# Train
# train_epoch() is assumed to run one pass over the data and return a metrics
# dict containing "loss".
for epoch in range(training_config.num_epochs):
    metrics = trainer.train_epoch()
    print(f"Epoch {epoch + 1}: Loss = {metrics['loss']:.4f}")

Setting Up Training

Data Loading

Create efficient data loaders for your models:

import numpy as np
import jax
import jax.numpy as jnp

def create_data_loader(data, batch_size, shuffle=True):
    """Create a data loader that yields batches.

    Args:
        data: A pytree (e.g. a dict) of arrays that all share the same
            leading (sample) dimension.
        batch_size: Default number of samples per batch.
        shuffle: If True, draw a fresh random permutation on every pass.

    Returns:
        A generator function ``data_loader(batch_size=batch_size)``. Each call
        starts a new pass over the data and yields pytrees with the same
        structure as ``data``. Trailing samples that do not fill a whole
        batch are dropped.
    """
    leaves = jax.tree_util.tree_leaves(data)
    # len(data) would count dict keys, not samples; use the leading
    # dimension of the first array instead.
    num_samples = len(leaves[0]) if leaves else 0

    # The default binds the outer batch_size so callers may omit it.
    def data_loader(batch_size=batch_size):
        num_batches = num_samples // batch_size

        # Shuffle if requested
        if shuffle:
            indices = np.random.permutation(num_samples)
            data_shuffled = jax.tree_util.tree_map(lambda x: x[indices], data)
        else:
            data_shuffled = data

        # Yield full batches only
        for i in range(num_batches):
            batch_start = i * batch_size
            batch_end = batch_start + batch_size
            yield jax.tree_util.tree_map(
                lambda x: x[batch_start:batch_end], data_shuffled
            )

    return data_loader

# Example usage with MNIST
# The package is `tensorflow_datasets` (there is no `tensorflow.datasets`).
import tensorflow_datasets as tfds

# Load each split as one in-memory batch (batch_size=-1) and convert the
# TensorFlow tensors to NumPy arrays in a single pass.
train_images, train_labels = tfds.as_numpy(
    tfds.load('mnist', split='train', as_supervised=True, batch_size=-1)
)
val_images, val_labels = tfds.as_numpy(
    tfds.load('mnist', split='test', as_supervised=True, batch_size=-1)
)

# Normalize to [0, 1] (pixel values are 0-255)
train_images = train_images.astype(np.float32) / 255.0
val_images = val_images.astype(np.float32) / 255.0

# Create data dictionaries
train_data = {"images": train_images, "labels": train_labels}
val_data = {"images": val_images, "labels": val_labels}

# Create data loaders
train_loader = create_data_loader(train_data, batch_size=128, shuffle=True)
val_loader = create_data_loader(val_data, batch_size=128, shuffle=False)

Preprocessing

Apply preprocessing to your data:

def preprocess_images(images):
    """Rescale images from [0, 1] to [-1, 1] and ensure a trailing channel axis.

    A 3-D input (no channel axis) gets a singleton channel appended. Inputs
    of any other rank keep their shape.
    """
    scaled = 2.0 * (images - 0.5)
    if scaled.ndim != 3:
        return scaled
    return scaled[..., np.newaxis]

def dequantize(images, rng, bin_width=1/256.0):
    """Add uniform noise to discrete images.

    Args:
        images: Array of pixel values.
        rng: JAX PRNG key for the noise.
        bin_width: Width of one quantization bin in the units of ``images``.
            The default 1/256 matches pixels scaled to [0, 1]. For images
            rescaled to [-1, 1] (as ``preprocess_images`` does), one bin
            spans 2/256.

    Returns:
        ``images + u`` with ``u ~ Uniform[0, bin_width)`` drawn elementwise.
    """
    noise = jax.random.uniform(rng, images.shape, minval=0.0, maxval=bin_width)
    return images + noise

# Apply preprocessing
train_images = preprocess_images(train_images)
val_images = preprocess_images(val_images)

# The loaders created earlier captured the raw arrays, so rebuild them.
# Otherwise training would still see the un-preprocessed [0, 1] images.
train_data = {"images": train_images, "labels": train_labels}
val_data = {"images": val_images, "labels": val_labels}
train_loader = create_data_loader(train_data, batch_size=128, shuffle=True)
val_loader = create_data_loader(val_data, batch_size=128, shuffle=False)

# Apply dequantization during training
def train_step_with_dequantization(state, batch, rng):
    """Dequantize the batch images, then delegate to ``train_step``.

    The incoming key is split in two: the first half is forwarded to
    ``train_step`` and the second half drives the dequantization noise.
    """
    step_rng, noise_rng = jax.random.split(rng)
    noisy_images = dequantize(batch["images"], noise_rng)

    noisy_batch = dict(batch)
    noisy_batch["images"] = noisy_images
    return train_step(state, noisy_batch, step_rng)

Model Initialization

Properly initialize your models:

from flax import nnx
from workshop.generative_models.factory import create_model

def initialize_model(model_config, seed=0):
    """Create a model and smoke-test it with one dummy forward pass.

    Args:
        model_config: Configuration passed to ``create_model``. Its
            ``input_dim`` gives the per-example input shape.
        seed: Seed for ``nnx.Rngs``.

    Returns:
        The initialized model.

    Raises:
        Exception: Whatever the dummy forward pass raises, after printing it.
    """
    rngs = nnx.Rngs(seed)

    # Create model
    model = create_model(config=model_config, rngs=rngs)

    # A batch of one example with the configured input shape.
    dummy_input = jnp.ones((1, *model_config.input_dim))

    # Only the forward pass belongs in the try block.
    try:
        output = model(dummy_input, rngs=rngs, training=False)
    except Exception as e:
        print(f"Model initialization failed: {e}")
        raise

    # Models in this guide return a dict of outputs (e.g. outputs["loss"]),
    # which has no `.shape`, so report the shape of every array leaf instead.
    output_shapes = jax.tree_util.tree_map(jnp.shape, output)
    print(f"Model initialized successfully. Output shape: {output_shapes}")

    return model

# Initialize model (reuses `model_config` from the Quick Start)
model = initialize_model(model_config, seed=42)

Custom Training Loops

Basic Custom Loop

Create a custom training loop for full control:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

def custom_training_loop(
    model,
    train_loader,
    val_loader,
    num_epochs,
    learning_rate=1e-3,
    batch_size=128,
):
    """Custom training loop with full control.

    Args:
        model: NNX module whose call returns a dict containing ``"loss"``.
        train_loader: Callable ``loader(batch_size=...)`` yielding dict
            batches with an ``"images"`` entry.
        val_loader: Same protocol as ``train_loader``; evaluated each epoch.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        batch_size: Batch size requested from both loaders.

    Returns:
        The trained model. NNX modules are updated in place; the same object
        is returned for convenience.
    """
    optimizer = optax.adam(learning_rate)
    # Optimize only trainable parameters. nnx.state(model) would also pick
    # up non-parameter state (e.g. RNG counters) that Adam must not touch.
    opt_state = optimizer.init(nnx.state(model, nnx.Param))

    rng = jax.random.PRNGKey(0)
    step = 0

    @nnx.jit
    def train_step(model, opt_state, batch, rng):
        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            return outputs["loss"], outputs

        # nnx.value_and_grad differentiates w.r.t. nnx.Param by default.
        grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
        (loss, outputs), grads = grad_fn(model)

        # NNX has no `apply_updates`: apply the optax update to the parameter
        # state and write the result back into the (mutable) module.
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))

        return model, opt_state, loss, outputs

    for epoch in range(num_epochs):
        epoch_losses = []

        for batch in train_loader(batch_size=batch_size):
            rng, step_rng = jax.random.split(rng)

            model, opt_state, loss, outputs = train_step(
                model, opt_state, batch, step_rng
            )

            epoch_losses.append(float(loss))
            step += 1

            if step % 100 == 0:
                print(f"Step {step}: Loss = {epoch_losses[-1]:.4f}")

        # Validation gets its own key instead of reusing the training key.
        rng, val_rng = jax.random.split(rng)
        val_losses = []
        for batch in val_loader(batch_size=batch_size):
            outputs = model(batch["images"], rngs=nnx.Rngs(val_rng), training=False)
            val_losses.append(float(outputs["loss"]))

        print(f"Epoch {epoch + 1}:")
        print(f"  Train Loss: {np.mean(epoch_losses):.4f}")
        print(f"  Val Loss: {np.mean(val_losses):.4f}")

    return model

# Train with custom loop
# train_loader / val_loader are the create_data_loader callables built above.
model = custom_training_loop(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    learning_rate=1e-3,
)

Advanced Custom Loop with Metrics

Track detailed metrics during training:

from collections import defaultdict

def advanced_training_loop(
    model,
    train_loader,
    val_loader,
    num_epochs,
    optimizer_config,
    scheduler_config=None,
    batch_size=128,
):
    """Advanced training loop with metrics tracking.

    Args:
        model: NNX module whose call returns a dict of outputs including
            ``"loss"``.
        train_loader: Callable ``loader(batch_size=...)`` yielding dict
            batches with an ``"images"`` entry.
        val_loader: Same protocol as ``train_loader``; evaluated each epoch.
        num_epochs: Number of passes over ``train_loader``.
        optimizer_config: Provides ``learning_rate`` and, optionally,
            ``gradient_clip_norm``.
        scheduler_config: Optional schedule description passed to
            ``create_schedule``.
        batch_size: Batch size requested from both loaders.

    Returns:
        ``(model, history)``. ``history`` maps ``"train_<metric>"`` to a
        per-step list and ``"val_<metric>"`` to a per-epoch list. Only scalar
        outputs are recorded.
    """
    base_lr = optimizer_config.learning_rate

    if scheduler_config:
        # NOTE(review): `create_schedule` is not defined in this guide; supply
        # one that maps (scheduler_config, base_lr) to an optax schedule.
        schedule = create_schedule(scheduler_config, base_lr)
        optimizer = optax.adam(learning_rate=schedule)
    else:
        optimizer = optax.adam(learning_rate=base_lr)

    # getattr: not every optimizer configuration sets a clipping norm.
    clip_norm = getattr(optimizer_config, "gradient_clip_norm", None)
    if clip_norm:
        optimizer = optax.chain(
            optax.clip_by_global_norm(clip_norm),
            optimizer,
        )

    # Only trainable parameters are optimized.
    opt_state = optimizer.init(nnx.state(model, nnx.Param))

    history = defaultdict(list)
    rng = jax.random.PRNGKey(0)
    step = 0

    @nnx.jit
    def train_step(model, opt_state, batch, rng):
        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            return outputs["loss"], outputs

        (loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

        # Norm of the raw gradients, measured before any clipping.
        grad_norm = optax.global_norm(grads)

        # NNX has no `apply_updates`; update the Param state and write it back.
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))

        metrics = {**outputs, "grad_norm": grad_norm}
        return model, opt_state, loss, metrics

    for epoch in range(num_epochs):
        for batch in train_loader(batch_size=batch_size):
            rng, step_rng = jax.random.split(rng)

            model, opt_state, loss, metrics = train_step(
                model, opt_state, batch, step_rng
            )

            # Outputs may include array-valued entries (e.g. reconstructions);
            # float() only accepts scalars, so those are skipped.
            for key, value in metrics.items():
                if np.ndim(value) == 0:
                    history[f"train_{key}"].append(float(value))

            step += 1

            if step % 100 == 0:
                recent_loss = np.mean(history["train_loss"][-100:])
                recent_grad_norm = np.mean(history["train_grad_norm"][-100:])
                print(f"Step {step}:")
                print(f"  Loss: {recent_loss:.4f}")
                print(f"  Grad Norm: {recent_grad_norm:.4f}")

        # Validation with its own key
        rng, val_rng = jax.random.split(rng)
        val_metrics = defaultdict(list)
        for batch in val_loader(batch_size=batch_size):
            outputs = model(batch["images"], rngs=nnx.Rngs(val_rng), training=False)
            for key, value in outputs.items():
                if np.ndim(value) == 0:
                    val_metrics[key].append(float(value))

        print(f"\nEpoch {epoch + 1}:")
        for key, values in val_metrics.items():
            mean_value = np.mean(values)
            history[f"val_{key}"].append(mean_value)
            print(f"  Val {key}: {mean_value:.4f}")

    return model, history

# Train with advanced loop
# NOTE(review): `scheduler_config` is not defined before this point. Pass None
# for a constant learning rate, or a SchedulerConfiguration (see "Learning
# Rate Schedules") together with a `create_schedule` helper.
model, history = advanced_training_loop(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config,
)

# Plot training curves
import matplotlib.pyplot as plt

train_loss = history["train_loss"]
val_loss = history["val_loss"]
# train_loss has one entry per step; val_loss has one per epoch. Put each
# validation point at the last step of its epoch so both curves share the
# same x-axis.
steps_per_epoch = len(train_loss) // max(len(val_loss), 1)
val_steps = (np.arange(len(val_loss)) + 1) * steps_per_epoch - 1

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_loss, label="Train")
plt.plot(val_steps, val_loss, label="Val")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Loss")

plt.subplot(1, 2, 2)
plt.plot(history["train_grad_norm"])
plt.xlabel("Step")
plt.ylabel("Gradient Norm")
plt.title("Gradient Norm")

plt.tight_layout()
plt.show()

Learning Rate Schedules

Warmup Schedule

Gradually increase learning rate at the start:

from workshop.generative_models.core.configuration import SchedulerConfiguration

# Cosine schedule with warmup (recommended)
# Step counts are presumably optimizer steps, not epochs; confirm against
# SchedulerConfiguration.
warmup_cosine = SchedulerConfiguration(
    name="warmup_cosine",
    scheduler_type="cosine",
    warmup_steps=1000,      # 1000 steps of warmup
    cycle_length=50000,     # Cosine cycle length
    min_lr_ratio=0.1,       # End at 10% of peak LR
)

# Attach the schedule to a training run; the peak LR presumably comes from
# optimizer_config.learning_rate.
training_config = TrainingConfiguration(
    name="warmup_training",
    batch_size=128,
    num_epochs=100,
    optimizer=optimizer_config,
    scheduler=warmup_cosine,
)

Custom Schedules

Create custom learning rate schedules:

import optax

def create_custom_schedule(
    base_lr,
    warmup_steps,
    hold_steps,
    decay_steps,
    end_lr_ratio=0.1,
):
    """Build a three-phase learning-rate schedule: warmup → hold → decay.

    Phases:
        1. Linear ramp from 0 to ``base_lr`` over ``warmup_steps``.
        2. Constant ``base_lr`` for ``hold_steps``.
        3. Cosine decay from ``base_lr`` to ``end_lr_ratio * base_lr`` over
           ``decay_steps``.

    Returns:
        An optax schedule mapping the step count to a learning rate.
    """
    warmup = optax.linear_schedule(
        init_value=0.0,
        end_value=base_lr,
        transition_steps=warmup_steps,
    )
    hold = optax.constant_schedule(base_lr)
    decay = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=decay_steps,
        alpha=end_lr_ratio,
    )

    hold_start = warmup_steps
    decay_start = hold_start + hold_steps
    return optax.join_schedules(
        schedules=[warmup, hold, decay],
        boundaries=[hold_start, decay_start],
    )

# Use custom schedule
# 1000 warmup + 5000 hold + 44000 decay = 50000 steps; after that the LR
# stays at 0.1 * base_lr.
custom_schedule = create_custom_schedule(
    base_lr=1e-3,
    warmup_steps=1000,
    hold_steps=5000,
    decay_steps=44000,
    end_lr_ratio=0.1,
)

optimizer = optax.adam(learning_rate=custom_schedule)

One-Cycle Schedule

Implement one-cycle learning rate policy:

def create_one_cycle_schedule(
    max_lr,
    total_steps,
    pct_start=0.3,
    div_factor=25.0,
    final_div_factor=1e4,
):
    """Build a one-cycle schedule: linear ramp up, then cosine anneal down.

    Args:
        max_lr: Peak learning rate, reached at the end of the ramp.
        total_steps: Length of the whole cycle in steps.
        pct_start: Fraction of ``total_steps`` spent ramping up.
        div_factor: The ramp starts at ``max_lr / div_factor``.
        final_div_factor: The anneal ends at ``max_lr / final_div_factor``.

    Returns:
        An optax schedule mapping the step count to a learning rate.
    """
    warmup_steps = int(total_steps * pct_start)
    anneal_steps = total_steps - warmup_steps

    ramp_up = optax.linear_schedule(
        init_value=max_lr / div_factor,
        end_value=max_lr,
        transition_steps=warmup_steps,
    )

    final_lr = max_lr / final_div_factor
    anneal = optax.cosine_decay_schedule(
        init_value=max_lr,
        decay_steps=anneal_steps,
        alpha=final_lr / max_lr,
    )

    return optax.join_schedules([ramp_up, anneal], [warmup_steps])

# Use one-cycle schedule
# 15000 steps ramping from 4e-5 (1e-3 / 25) to 1e-3, then 35000 steps of
# cosine decay down to 1e-7 (1e-3 / 1e4).
one_cycle_schedule = create_one_cycle_schedule(
    max_lr=1e-3,
    total_steps=50000,
    pct_start=0.3,
)

optimizer = optax.adam(learning_rate=one_cycle_schedule)

Gradient Accumulation

Accumulate gradients to simulate larger batch sizes:

def training_with_gradient_accumulation(
    model,
    train_loader,
    num_epochs,
    accumulation_steps=4,
    learning_rate=1e-3,
    micro_batch_size=32,
):
    """Training with gradient accumulation.

    Gradients from ``accumulation_steps`` consecutive micro-batches are summed
    and averaged before a single optimizer update. The effective batch size
    is therefore ``micro_batch_size * accumulation_steps``. A trailing partial
    window at the end of an epoch is discarded.

    Args:
        model: NNX module whose call returns a dict containing ``"loss"``.
        train_loader: Callable ``loader(batch_size=...)`` yielding dict
            batches with an ``"images"`` entry.
        num_epochs: Number of passes over ``train_loader``.
        accumulation_steps: Micro-batches per optimizer update.
        learning_rate: Adam learning rate.
        micro_batch_size: Batch size requested from ``train_loader``.

    Returns:
        The trained model.
    """
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(nnx.state(model, nnx.Param))
    rng = jax.random.PRNGKey(0)

    @nnx.jit
    def compute_gradients(model, batch, rng):
        """Compute parameter gradients for one micro-batch."""
        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            return outputs["loss"], outputs

        (loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
        return grads, loss, outputs

    @nnx.jit
    def apply_accumulated_gradients(model, opt_state, accumulated_grads):
        """Average the summed gradients and apply one optimizer update."""
        averaged_grads = jax.tree_util.tree_map(
            lambda g: g / accumulation_steps,
            accumulated_grads,
        )

        # NNX has no `apply_updates`; update the Param state and write it back.
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(averaged_grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))

        return model, opt_state

    for epoch in range(num_epochs):
        accumulated_grads = None
        window_losses = []
        step = 0

        for batch in train_loader(batch_size=micro_batch_size):
            rng, step_rng = jax.random.split(rng)

            grads, loss, outputs = compute_gradients(model, batch, step_rng)

            if accumulated_grads is None:
                accumulated_grads = grads
            else:
                accumulated_grads = jax.tree_util.tree_map(
                    lambda acc, g: acc + g,
                    accumulated_grads,
                    grads,
                )
            window_losses.append(float(loss))

            step += 1

            if step % accumulation_steps == 0:
                model, opt_state = apply_accumulated_gradients(
                    model, opt_state, accumulated_grads
                )
                accumulated_grads = None

                if step % 100 == 0:
                    # Report the mean over the whole window, not just the
                    # last micro-batch.
                    window_loss = np.mean(window_losses)
                    print(f"Step {step // accumulation_steps}: Loss = {window_loss:.4f}")
                window_losses = []

    return model

# Train with gradient accumulation
# One optimizer update every 4 micro-batches of 32 samples.
model = training_with_gradient_accumulation(
    model=model,
    train_loader=train_loader,
    num_epochs=10,
    accumulation_steps=4,  # Effective batch size = 32 * 4 = 128
)

Early Stopping

Implement early stopping to prevent overfitting:

class EarlyStopping:
    """Early stopping handler.

    Tracks the best value of a monitored metric and signals a stop once
    ``patience`` consecutive calls fail to improve on it by more than
    ``min_delta``. Once set, ``should_stop`` stays True.
    """

    def __init__(self, patience=10, min_delta=0.0, mode="min"):
        """Initialize early stopping.

        Args:
            patience: Number of non-improving calls tolerated before stopping.
            min_delta: Minimum change to qualify as improvement.
            mode: 'min' when lower is better (loss), 'max' when higher is
                better (accuracy).

        Raises:
            ValueError: If ``mode`` is not 'min' or 'max'.
        """
        # Previously any other string silently behaved like 'max'.
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_value = float('inf') if mode == 'min' else float('-inf')
        self.should_stop = False

    def __call__(self, current_value):
        """Record ``current_value`` and return whether training should stop."""
        if self.mode == 'min':
            improved = current_value < (self.best_value - self.min_delta)
        else:
            improved = current_value > (self.best_value + self.min_delta)

        if improved:
            self.best_value = current_value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop

def training_with_early_stopping(
    model,
    train_loader,
    val_loader,
    max_epochs,
    patience=10,
    learning_rate=1e-3,
    batch_size=128,
):
    """Training with early stopping.

    Trains until the validation loss has not improved for ``patience``
    epochs (or ``max_epochs`` is reached). The model is then restored to its
    best-validation weights.

    Args:
        model: NNX module whose call returns a dict containing ``"loss"``.
        train_loader: Callable ``loader(batch_size=...)`` yielding dict
            batches with an ``"images"`` entry.
        val_loader: Same protocol as ``train_loader``.
        max_epochs: Upper bound on the number of epochs.
        patience: Epochs without improvement tolerated by ``EarlyStopping``.
        learning_rate: Adam learning rate.
        batch_size: Batch size requested from both loaders.

    Returns:
        The model, carrying the best-validation weights.
    """
    early_stopping = EarlyStopping(patience=patience, mode='min')

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(nnx.state(model, nnx.Param))
    rng = jax.random.PRNGKey(0)

    # Defined locally so this example does not depend on a `train_step` from
    # elsewhere; those return different numbers of values.
    @nnx.jit
    def train_step(model, opt_state, batch, rng):
        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            return outputs["loss"]

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))
        return model, opt_state, loss

    best_model = None
    best_val_loss = float('inf')

    for epoch in range(max_epochs):
        for batch in train_loader(batch_size=batch_size):
            rng, step_rng = jax.random.split(rng)
            model, opt_state, loss = train_step(model, opt_state, batch, step_rng)

        rng, val_rng = jax.random.split(rng)
        val_losses = []
        for batch in val_loader(batch_size=batch_size):
            outputs = model(batch["images"], rngs=nnx.Rngs(val_rng), training=False)
            val_losses.append(float(outputs["loss"]))

        val_loss = np.mean(val_losses)
        print(f"Epoch {epoch + 1}: Val Loss = {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: the live model keeps being updated in place.
            best_model = nnx.clone(model)

        if early_stopping(val_loss):
            print(f"Early stopping at epoch {epoch + 1}")
            break

    # `nnx.GraphDef.from_state` does not exist; copy the best weights back
    # into the model instead.
    if best_model is not None:
        nnx.update(model, nnx.state(best_model))

    return model

# Train with early stopping
# Runs at most 100 epochs and stops after 10 epochs without validation
# improvement.
model = training_with_early_stopping(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    max_epochs=100,
    patience=10,
)

Mixed Precision Training

Use bfloat16 mixed precision to reduce memory use and speed up training on hardware that supports it:

def mixed_precision_training(
    model,
    train_loader,
    num_epochs,
    learning_rate=1e-3,
    batch_size=128,
):
    """Training with mixed precision (bfloat16).

    Args:
        model: NNX module whose call returns a dict containing ``"loss"``.
        train_loader: Callable ``loader(batch_size=...)`` yielding dict
            batches with an ``"images"`` entry.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Step size of the Adam-style update.
        batch_size: Batch size requested from ``train_loader``.

    Returns:
        The trained model, whose float32 parameters are now bfloat16.

    NOTE: this stores the parameters themselves in bfloat16 (~8 bits of
    precision), so updates much smaller than the weights are rounded away.
    For long runs, prefer float32 master weights and cast to bfloat16 only
    for the forward/backward pass.
    """
    def convert_to_bfloat16(x):
        if isinstance(x, jnp.ndarray) and x.dtype == jnp.float32:
            return x.astype(jnp.bfloat16)
        return x

    # NNX modules are graph nodes rather than plain pytrees, so
    # jax.tree_map(fn, model) is not a reliable way to reach the weights.
    # Convert the parameter state and write it back instead.
    params = nnx.state(model, nnx.Param)
    nnx.update(model, jax.tree_util.tree_map(convert_to_bfloat16, params))

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.scale_by_adam(),
        optax.scale(-learning_rate),  # Negative: gradient descent
    )

    opt_state = optimizer.init(nnx.state(model, nnx.Param))
    rng = jax.random.PRNGKey(0)

    @nnx.jit
    def train_step(model, opt_state, batch, rng):
        # Convert float32 batch entries to bfloat16 (integer labels untouched).
        batch = jax.tree_util.tree_map(convert_to_bfloat16, batch)

        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            # Keep loss in float32 for numerical stability
            return outputs["loss"].astype(jnp.float32), outputs

        (loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

        # NNX has no `apply_updates`. optax.apply_updates keeps the params'
        # dtype, so the weights stay bfloat16.
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))

        return model, opt_state, loss

    for epoch in range(num_epochs):
        for batch in train_loader(batch_size=batch_size):
            rng, step_rng = jax.random.split(rng)
            model, opt_state, loss = train_step(model, opt_state, batch, step_rng)

        print(f"Epoch {epoch + 1}: Loss = {loss:.4f}")

    return model

# Train with mixed precision
# (bfloat16 compute; see mixed_precision_training for the precision caveats)
model = mixed_precision_training(
    model=model,
    train_loader=train_loader,
    num_epochs=10,
)

Model Checkpointing

Save and Load Checkpoints

import pickle
from pathlib import Path

def save_checkpoint(model, opt_state, step, path):
    """Save a training checkpoint with pickle.

    The file is first written to a temporary sibling and then renamed over
    ``path``. A crash mid-write therefore never leaves a truncated checkpoint
    at ``path``.

    Args:
        model: NNX module whose full state is saved.
        opt_state: Optimizer state to save alongside the model.
        step: Training step recorded in the checkpoint.
        path: Destination file; parent directories are created if needed.
    """
    checkpoint = {
        "model_state": nnx.state(model),
        "opt_state": opt_state,
        "step": step,
    }

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(checkpoint, f)
    # Rename over the target; atomic on POSIX when both are on one filesystem.
    tmp_path.replace(path)

    print(f"Checkpoint saved to {path}")

def load_checkpoint(model, path):
    """Load a checkpoint written by ``save_checkpoint`` into ``model``.

    Security: ``pickle.load`` can execute arbitrary code. Only load
    checkpoint files you created yourself.

    Args:
        model: NNX module with the same structure as the saved one; it is
            updated in place.
        path: Checkpoint file to read.

    Returns:
        ``(model, opt_state, step)`` as stored in the checkpoint.
    """
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)

    # `nnx.GraphDef.from_state` does not exist. The model passed in already
    # carries the graph structure, so copy the saved state into it.
    nnx.update(model, checkpoint["model_state"])

    return model, checkpoint["opt_state"], checkpoint["step"]

# Save checkpoint during training
# Meant to sit inside a training loop where `step` and `opt_state` are defined.
if step % 1000 == 0:
    save_checkpoint(
        model=model,
        opt_state=opt_state,
        step=step,
        path=f"./checkpoints/step_{step}.pkl"
    )

# Load checkpoint
# The file is unpickled, so only load checkpoints you wrote yourself.
model, opt_state, step = load_checkpoint(
    model=model,
    path="./checkpoints/step_5000.pkl"
)
print(f"Resumed from step {step}")

Best Model Checkpointing

Save the best model based on validation metrics:

class BestModelCheckpoint:
    """Write a checkpoint each time a monitored validation metric improves."""

    def __init__(self, save_path, mode='min'):
        """Set up the tracker.

        Args:
            save_path: File the best checkpoint is written to.
            mode: 'min' when lower is better (loss), anything else treats
                higher as better (accuracy).
        """
        self.save_path = Path(save_path)
        self.mode = mode
        if mode == 'min':
            self.best_value = float('inf')
        else:
            self.best_value = float('-inf')

    def __call__(self, model, opt_state, step, current_value):
        """Save via ``save_checkpoint`` when ``current_value`` beats the best.

        Returns:
            The result of the improvement comparison (truthy when a new best
            was recorded and saved).
        """
        if self.mode == 'min':
            is_better = current_value < self.best_value
        else:
            is_better = current_value > self.best_value

        if not is_better:
            return is_better

        self.best_value = current_value
        save_checkpoint(model, opt_state, step, self.save_path)
        print(f"New best model! Value: {current_value:.4f}")
        return is_better

# Use during training
# NOTE(review): `num_epochs`, `train_epoch`, `validate` and `opt_state` are
# placeholders for your own training code; they are not defined here.
best_checkpoint = BestModelCheckpoint(
    save_path="./checkpoints/best_model.pkl",
    mode='min'
)

for epoch in range(num_epochs):
    # Train epoch
    train_metrics = train_epoch(model)

    # Validate
    val_metrics = validate(model, val_loader)

    # Save if best (step=epoch stores the epoch index in the "step" field)
    best_checkpoint(
        model=model,
        opt_state=opt_state,
        step=epoch,
        current_value=val_metrics['loss']
    )

Logging and Monitoring

Weights &amp; Biases Integration

import wandb

def train_with_wandb(
    model,
    train_loader,
    val_loader,
    training_config,
    project_name="generative-models",
):
    """Training with W&B logging.

    Args:
        model: Model to train.
        train_loader: Callable ``loader(batch_size=...)`` yielding batches.
        val_loader: Loader passed to ``validate`` once per epoch.
        training_config: Supplies batch size, epoch count and optimizer
            settings; these are also recorded in the W&B run config.
        project_name: W&B project to log into.

    Returns:
        The trained model.

    NOTE(review): relies on a ``train_step(model, opt_state, batch, rng)``
    that returns ``(model, opt_state, loss, metrics)`` and on
    ``validate(model, loader)`` returning a dict with ``"loss"``. Neither is
    defined in this example, and ``train_step`` must use an optimizer whose
    state matches the ``opt_state`` built here.
    """
    wandb.init(
        project=project_name,
        config={
            "learning_rate": training_config.optimizer.learning_rate,
            "batch_size": training_config.batch_size,
            "num_epochs": training_config.num_epochs,
            "optimizer": training_config.optimizer.optimizer_type,
        }
    )

    # try/finally so the run is closed even when training raises.
    try:
        optimizer = optax.adam(training_config.optimizer.learning_rate)
        opt_state = optimizer.init(nnx.state(model, nnx.Param))
        rng = jax.random.PRNGKey(0)
        step = 0

        for epoch in range(training_config.num_epochs):
            for batch in train_loader(batch_size=training_config.batch_size):
                rng, step_rng = jax.random.split(rng)
                model, opt_state, loss, metrics = train_step(
                    model, opt_state, batch, step_rng
                )

                # step= keeps train and val logs on one x-axis. Array-valued
                # metrics cannot go through float() and are skipped.
                wandb.log(
                    {
                        "train/loss": float(loss),
                        "train/step": step,
                        **{
                            f"train/{k}": float(v)
                            for k, v in metrics.items()
                            if np.ndim(v) == 0
                        },
                    },
                    step=step,
                )

                step += 1

            val_metrics = validate(model, val_loader)
            wandb.log(
                {
                    "val/loss": val_metrics['loss'],
                    "epoch": epoch,
                },
                step=step,
            )

            print(f"Epoch {epoch + 1}: Val Loss = {val_metrics['loss']:.4f}")
    finally:
        wandb.finish()

    return model

# Train with wandb
# Requires a W&B login (or offline mode) before the run starts.
model = train_with_wandb(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    training_config=training_config,
    project_name="vae-experiments",
)

TensorBoard Integration

from torch.utils.tensorboard import SummaryWriter

def train_with_tensorboard(
    model,
    train_loader,
    val_loader,
    num_epochs,
    log_dir="./logs",
):
    """Training with TensorBoard logging.

    Args:
        model: Model to train.
        train_loader: Callable ``loader(batch_size=...)`` yielding batches.
        val_loader: Loader passed to ``validate`` once per epoch.
        num_epochs: Number of passes over ``train_loader``.
        log_dir: Directory for TensorBoard event files.

    Returns:
        The trained model.

    NOTE(review): relies on a ``train_step(model, opt_state, batch, rng)``
    that returns ``(model, opt_state, loss, metrics)``, on
    ``validate(model, loader)`` returning a dict with ``"loss"``, and on a
    ``model.generate(num_samples=..., rngs=...)`` method. None of these are
    defined in this example.
    """
    writer = SummaryWriter(log_dir=log_dir)

    # try/finally so the event file is flushed and closed even on failure.
    try:
        optimizer = optax.adam(1e-3)
        opt_state = optimizer.init(nnx.state(model, nnx.Param))
        rng = jax.random.PRNGKey(0)
        step = 0

        for epoch in range(num_epochs):
            for batch in train_loader(batch_size=128):
                rng, step_rng = jax.random.split(rng)
                model, opt_state, loss, metrics = train_step(
                    model, opt_state, batch, step_rng
                )

                writer.add_scalar("Loss/train", float(loss), step)
                for key, value in metrics.items():
                    # add_scalar needs a single number; skip array outputs.
                    if np.ndim(value) == 0:
                        writer.add_scalar(f"Metrics/{key}", float(value), step)

                step += 1

            val_metrics = validate(model, val_loader)
            writer.add_scalar("Loss/val", val_metrics['loss'], epoch)

            # Log images (fixed seed 0 on every call)
            if epoch % 5 == 0:
                samples = np.asarray(model.generate(num_samples=16, rngs=nnx.Rngs(0)))
                # add_images defaults to NCHW, while models here use
                # channels-last inputs such as (28, 28, 1).
                # NOTE(review): add_images also expects floats in [0, 1];
                # rescale first if the model generates in [-1, 1].
                writer.add_images("Samples", samples, epoch, dataformats="NHWC")
    finally:
        writer.close()

    return model

# Train with tensorboard
# View with: tensorboard --logdir ./tensorboard_logs
model = train_with_tensorboard(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    log_dir="./tensorboard_logs",
)

Common Training Patterns

Progressive Training

Train in stages that progressively lower the learning rate and increase the batch size:

def progressive_training(model, train_loader, stages):
    """Train with progressive stages.

    A single Adam state is kept across stages, so its moment estimates carry
    over; only the learning rate and batch size change between stages.

    Args:
        model: NNX module whose call returns a dict containing ``"loss"``.
        train_loader: Callable ``loader(batch_size=...)`` yielding dict
            batches with an ``"images"`` entry.
        stages: List of (num_epochs, learning_rate, batch_size) tuples.

    Returns:
        The trained model (returned unchanged when ``stages`` is empty).
    """
    if not stages:
        return model

    # inject_hyperparams exposes the learning rate inside the optimizer
    # state, so it can change per stage without resetting Adam's moments.
    optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=stages[0][1])
    opt_state = optimizer.init(nnx.state(model, nnx.Param))
    rng = jax.random.PRNGKey(0)

    @nnx.jit
    def train_step(model, opt_state, batch, rng):
        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            return outputs["loss"]

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))
        return model, opt_state, loss

    for stage_idx, (num_epochs, learning_rate, batch_size) in enumerate(stages):
        print(f"\nStage {stage_idx + 1}: LR={learning_rate}, BS={batch_size}")

        # Keep the stored dtype so the jitted step is not retraced.
        lr_dtype = opt_state.hyperparams["learning_rate"].dtype
        opt_state.hyperparams["learning_rate"] = jnp.asarray(learning_rate, dtype=lr_dtype)

        for epoch in range(num_epochs):
            for batch in train_loader(batch_size=batch_size):
                rng, step_rng = jax.random.split(rng)
                model, opt_state, loss = train_step(
                    model, opt_state, batch, step_rng
                )

            print(f"  Epoch {epoch + 1}: Loss = {loss:.4f}")

    return model

# Define progressive stages
# Each tuple is (num_epochs, learning_rate, batch_size).
stages = [
    (10, 1e-3, 32),   # Stage 1: High LR, small batch
    (20, 5e-4, 64),   # Stage 2: Medium LR, medium batch
    (30, 1e-4, 128),  # Stage 3: Low LR, large batch
]

model = progressive_training(model, train_loader, stages)

Curriculum Learning

Train with increasing data difficulty:

def curriculum_learning(model, data_loader_fn, difficulty_schedule, batch_size=128):
    """Train with curriculum learning.

    A single Adam state is kept across all difficulty levels.

    Args:
        model: NNX module whose call returns a dict containing ``"loss"``.
        data_loader_fn: Function that returns a data loader (a callable
            ``loader(batch_size=...)``) for a difficulty level.
        difficulty_schedule: List of (difficulty_level, num_epochs) tuples.
        batch_size: Batch size requested from every loader.

    Returns:
        The trained model.
    """
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(nnx.state(model, nnx.Param))
    rng = jax.random.PRNGKey(0)

    # Local step bound to `optimizer` above. A train_step from another
    # example would update with a different optimizer than this state's.
    @nnx.jit
    def train_step(model, opt_state, batch, rng):
        def loss_fn(model):
            outputs = model(batch["images"], rngs=nnx.Rngs(rng), training=True)
            return outputs["loss"]

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))
        return model, opt_state, loss

    for difficulty, num_epochs in difficulty_schedule:
        print(f"\nTraining on difficulty level: {difficulty}")

        train_loader = data_loader_fn(difficulty)

        for epoch in range(num_epochs):
            for batch in train_loader(batch_size=batch_size):
                rng, step_rng = jax.random.split(rng)
                model, opt_state, loss = train_step(
                    model, opt_state, batch, step_rng
                )

            print(f"  Epoch {epoch + 1}: Loss = {loss:.4f}")

    return model

# Define curriculum
# Each tuple is (difficulty_level, num_epochs).
difficulty_schedule = [
    ("easy", 10),      # Train on easy examples first
    ("medium", 20),    # Then medium difficulty
    ("hard", 30),      # Finally hard examples
    ("all", 40),       # Train on all data
]

# NOTE(review): `data_loader_fn` is user-supplied. It must map each
# difficulty label to a loader callable like those from create_data_loader.
model = curriculum_learning(model, data_loader_fn, difficulty_schedule)

Multi-Task Training

Train on multiple tasks simultaneously:

def multi_task_training(
    model,
    task_loaders,
    task_weights,
    num_epochs,
    steps_per_epoch=1000,
    batch_size=32,
):
    """Train on multiple tasks.

    Every step draws one batch per task, sums the task losses weighted by
    ``task_weights`` and applies a single optimizer update.

    Args:
        model: NNX module called as
            ``model(batch, task=name, rngs=..., training=True)`` and
            returning a dict containing ``"loss"``.
        task_loaders: Dict of task_name -> data_loader.
        task_weights: Dict of task_name -> weight.
        num_epochs: Number of epochs.
        steps_per_epoch: Optimizer steps per epoch.
        batch_size: Batch size requested from every task loader.

    Returns:
        The trained model.

    Raises:
        ValueError: If a task loader yields no batches at all.
    """
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(nnx.state(model, nnx.Param))
    rng = jax.random.PRNGKey(0)

    @nnx.jit
    def multi_task_step(model, opt_state, batches, rng):
        """Training step with multiple tasks."""
        # One key per task, so tasks do not share the same random draws.
        task_rngs = jax.random.split(rng, len(batches))

        def loss_fn(model):
            total_loss = 0.0
            metrics = {}

            for task_rng, (task_name, batch) in zip(task_rngs, batches.items()):
                outputs = model(
                    batch,
                    task=task_name,
                    rngs=nnx.Rngs(task_rng),
                    training=True
                )

                # Weighted loss
                total_loss += outputs["loss"] * task_weights[task_name]

                # Track unweighted per-task loss
                metrics[f"{task_name}_loss"] = outputs["loss"]

            return total_loss, metrics

        (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

        # NNX has no `apply_updates`; update the Param state and write it back.
        params = nnx.state(model, nnx.Param)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        nnx.update(model, optax.apply_updates(params, updates))

        return model, opt_state, loss, metrics

    def next_batch(task_iters, name):
        """Return the next batch for ``name``, restarting its loader when exhausted."""
        try:
            return next(task_iters[name])
        except StopIteration:
            pass
        # Shorter datasets cycle instead of aborting the epoch.
        task_iters[name] = iter(task_loaders[name](batch_size=batch_size))
        try:
            return next(task_iters[name])
        except StopIteration:
            raise ValueError(f"Task loader {name!r} yielded no batches") from None

    for epoch in range(num_epochs):
        task_iters = {
            name: iter(loader(batch_size=batch_size))
            for name, loader in task_loaders.items()
        }

        for step in range(steps_per_epoch):
            batches = {name: next_batch(task_iters, name) for name in task_loaders}

            rng, step_rng = jax.random.split(rng)
            model, opt_state, loss, metrics = multi_task_step(
                model, opt_state, batches, step_rng
            )

            if step % 100 == 0:
                print(f"Step {step}: Total Loss = {loss:.4f}")
                for task_name, task_loss in metrics.items():
                    print(f"  {task_name}: {task_loss:.4f}")

    return model

# Train on multiple tasks
# NOTE(review): the three loaders are placeholders. Each must be a callable
# `loader(batch_size=...)`, like those from create_data_loader.
task_loaders = {
    "reconstruction": reconstruction_loader,
    "generation": generation_loader,
    "classification": classification_loader,
}

# Each task's loss is multiplied by its weight before summing.
task_weights = {
    "reconstruction": 1.0,
    "generation": 0.5,
    "classification": 0.3,
}

model = multi_task_training(
    model=model,
    task_loaders=task_loaders,
    task_weights=task_weights,
    num_epochs=50,
)

Troubleshooting

NaN Loss

If you encounter NaN loss:

# 1. Add gradient clipping
# advanced_training_loop turns gradient_clip_norm into optax.clip_by_global_norm.
optimizer_config = OptimizerConfiguration(
    name="clipped_adam",
    optimizer_type="adam",
    learning_rate=1e-3,
    gradient_clip_norm=1.0,  # Clip gradients
)

# 2. Lower learning rate
# Alternative to option 1: this assignment replaces the config above. Put both
# settings in one config if you want clipping and a lower LR.
optimizer_config = OptimizerConfiguration(
    name="lower_lr",
    optimizer_type="adam",
    learning_rate=1e-4,  # Lower LR
)

# 3. Check for numerical instability
def check_for_nans(metrics, step, checkpoint_path="./emergency_checkpoint.pkl"):
    """Raise if any metric holds a NaN or infinite value.

    Before raising, an emergency checkpoint is written so the failing run
    can be inspected or resumed.

    Args:
        metrics: Mapping of metric name to a scalar or array-like value.
        step: Current training step; used in the message and the checkpoint.
        checkpoint_path: Destination of the emergency checkpoint.

    Raises:
        ValueError: If any metric value contains NaN or +/-inf.
    """
    for key, value in metrics.items():
        # np.isfinite flags both NaN and +/-inf; np.all makes the check work
        # for array-valued metrics too (a bare `if` on a multi-element array
        # raises "truth value is ambiguous").
        if not np.all(np.isfinite(value)):
            print(f"NaN/Inf detected at step {step} in {key}")
            # Save checkpoint before crash. `model`, `opt_state`, and
            # `save_checkpoint` are looked up in the enclosing module scope,
            # so they must exist wherever this helper is defined.
            save_checkpoint(model, opt_state, step, checkpoint_path)
            raise ValueError(f"NaN/Inf in {key}")

# 4. Use mixed precision with care
# Avoid bfloat16 for loss computation: upcast model outputs to float32
# *before* computing the loss. Casting only the final value, as below,
# keeps the result in float32 but cannot recover precision already lost
# in a bfloat16 reduction.
loss = loss.astype(jnp.float32)  # Keep loss in float32

Slow Training¤

If training is slow:

# 1. Use JIT compilation: decorate the whole step so it is traced once
@nnx.jit
def train_step(model, opt_state, batch, rng):
    """Placeholder for one optimization step; replace with your own logic.

    The profiling example below unpacks the result as
    ``(model, opt_state, loss)``, so a real implementation should return
    those three values.
    """

# 2. Profile your code
with jax.profiler.trace("./tensorboard_logs"):
    for _ in range(100):
        # Use a fresh key per step instead of reusing the same one.
        rng, step_rng = jax.random.split(rng)
        model, opt_state, loss = train_step(model, opt_state, batch, step_rng)
    # JAX dispatches work asynchronously; wait for the last step so the
    # device computation is captured inside the trace.
    jax.block_until_ready(loss)

# 3. Increase batch size (if memory allows)
# A larger batch means fewer optimizer steps per epoch, which often uses the
# accelerator more efficiently. It does not by itself reduce the number of
# epochs needed; you may also need to retune the learning rate.
training_config = TrainingConfiguration(
    name="large_batch",
    batch_size=256,  # Larger batch size
    num_epochs=50,   # Example value; tune for your run
    optimizer=optimizer_config,
)

# 4. Use data prefetching
from collections import deque
from concurrent.futures import ThreadPoolExecutor

def prefetch_data_loader(data_loader, prefetch_size=2, batch_size=128):
    """Yield batches while the next ones are fetched on a background thread.

    Args:
        data_loader: Callable accepting ``batch_size=`` and returning an
            iterable of batches.
        prefetch_size: Number of batches requested ahead of the consumer
            (values below 1 are treated as 1).
        batch_size: Forwarded to ``data_loader``.

    Yields:
        Batches in the order produced by ``data_loader``. The generator
        stops cleanly once the underlying iterator is exhausted.
    """
    end_of_data = object()  # marks exhaustion without raising StopIteration

    with ThreadPoolExecutor(max_workers=1) as executor:
        iterator = iter(data_loader(batch_size=batch_size))

        def fetch_next():
            # Return a sentinel rather than letting StopIteration escape:
            # re-raised inside this generator it would become a
            # RuntimeError (PEP 479).
            return next(iterator, end_of_data)

        # A single worker ensures the source iterator is never advanced
        # from two threads at once.
        pending = deque(
            executor.submit(fetch_next) for _ in range(max(1, prefetch_size))
        )
        while True:
            batch = pending.popleft().result()
            if batch is end_of_data:
                return
            # Queue the next fetch before handing this batch to the consumer.
            pending.append(executor.submit(fetch_next))
            yield batch

Memory Issues¤

If you run out of memory:

# 1. Reduce batch size
# Activation memory grows with the batch, so a smaller batch lowers peak
# memory. A smaller batch also means more optimizer steps per epoch, so it
# does not necessarily require more epochs; retune as needed.
training_config = TrainingConfiguration(
    name="small_batch",
    batch_size=32,  # Smaller batch
    num_epochs=200,  # Example value; tune for your run
    optimizer=optimizer_config,
)

# 2. Use gradient accumulation
# See "Gradient Accumulation" section above

# 3. Clear JAX caches periodically
import jax

# Clear JAX's compilation and tracing caches. This may release memory held
# by cached compiled functions, but it does NOT free device arrays you still
# reference (drop those with `del`). Jitted functions recompile on their
# next call.
jax.clear_caches()

# 4. Use gradient checkpointing (rematerialization) for large models
# Note: `jax.checkpoint` is a function, not a module, so
# `from jax.checkpoint import checkpoint` fails. Plain JAX transforms also
# cannot take an `nnx.Module` argument, so use Flax NNX's lifted `nnx.remat`.
@nnx.remat
def expensive_forward_pass(model, x):
    """Forward pass whose intermediate activations are recomputed on backprop.

    Activations are not kept for the backward pass; they are recomputed when
    gradients are taken, trading extra compute for lower memory use.
    """
    return model(x)

Best Practices¤

DO¤

  • ✅ Use type-safe configuration with validation
  • ✅ JIT-compile training steps for performance
  • ✅ Save checkpoints regularly
  • ✅ Monitor training metrics (loss, gradients)
  • ✅ Use gradient clipping for stability
  • ✅ Start with a small learning rate and increase it gradually
  • ✅ Validate periodically during training
  • ✅ Save best model based on validation metrics
  • ✅ Use warmup for learning rate schedules
  • ✅ Profile code to find bottlenecks

DON'T¤

  • ❌ Skip validation - always validate your model
  • ❌ Use too high a learning rate initially
  • ❌ Forget to shuffle training data
  • ❌ Ignore NaN or infinite losses
  • ❌ Train without gradient clipping
  • ❌ Overwrite checkpoints without backup
  • ❌ Use reduced precision (e.g. bfloat16) for every operation — keep losses and reductions in float32
  • ❌ Forget to split RNG keys properly
  • ❌ Mutate training state in-place
  • ❌ Skip warmup for large learning rates

Summary¤

This guide covered:

  • Basic Training: Quick start and setup
  • Custom Loops: Full control over training
  • Learning Rate Schedules: Warmup, cosine, one-cycle
  • Advanced Techniques: Gradient accumulation, early stopping, mixed precision
  • Checkpointing: Save and load model state
  • Logging: W&B, TensorBoard integration
  • Common Patterns: Progressive training, curriculum learning, multi-task
  • Troubleshooting: NaN loss, slow training, memory issues

Next Steps¤


See the Configuration Guide for detailed configuration options and patterns.