Training Systems¤
Comprehensive training infrastructure for generative models, including model-specific trainers, distributed training, callbacks, and optimization utilities.
Overview¤
- Model Trainers - Specialized trainers for VAE, GAN, Diffusion, Flow, EBM, and Autoregressive models
- Distributed Training - Data parallel and model parallel training across multiple GPUs/TPUs
- Callbacks - Checkpointing, early stopping, logging, and custom callbacks
- Optimizers - AdamW, Lion, and Adafactor, with learning rate schedulers
Quick Start¤
Basic Training¤
from artifex.generative_models.training import VAETrainer
from artifex.generative_models.core.configuration import TrainingConfig
# Create training configuration
training_config = TrainingConfig(
batch_size=128,
num_epochs=100,
optimizer={"type": "adam", "learning_rate": 1e-3},
scheduler={"type": "cosine", "warmup_steps": 1000},
)
# Create trainer
trainer = VAETrainer(
model=model,
config=training_config,
train_dataset=train_data,
val_dataset=val_data,
)
# Train
trainer.train()
Model-Specific Trainers¤
Each model family has a specialized trainer that handles its unique training requirements.
VAE Trainer¤
Handles ELBO loss, KL annealing, and reconstruction metrics.
from artifex.generative_models.training import VAETrainer
trainer = VAETrainer(
model=vae_model,
config=training_config,
train_dataset=train_data,
kl_annealing=True,
kl_warmup_epochs=10,
)
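For intuition, kl_warmup_epochs implies a schedule that ramps the KL term's weight up over the first epochs. A minimal sketch of a linear annealing schedule, purely illustrative and not necessarily the exact schedule VAETrainer uses:
def kl_weight(epoch: int, warmup_epochs: int = 10) -> float:
    # Linearly anneal the KL weight from 0 to 1 over the warmup epochs.
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, epoch / warmup_epochs)
# The annealed ELBO loss is then: reconstruction + kl_weight(epoch) * KL.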
GAN Trainer¤
Manages generator/discriminator alternating updates.
from artifex.generative_models.training import GANTrainer
trainer = GANTrainer(
model=gan_model,
config=training_config,
train_dataset=train_data,
d_steps=5, # Discriminator steps per generator step
gp_weight=10.0, # Gradient penalty weight
)
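The gp_weight argument corresponds to a WGAN-GP style gradient penalty, which keeps the critic's gradient norm near 1 on interpolates between real and fake samples. A self-contained sketch of that term, assuming disc_fn maps a batch of inputs to per-example critic scores (an illustration, not the trainer's internals):
import jax
import jax.numpy as jnp

def gradient_penalty(disc_fn, real, fake, key, weight=10.0):
    # Interpolate between real and fake samples with a per-example mixing factor.
    eps = jax.random.uniform(key, (real.shape[0],) + (1,) * (real.ndim - 1))
    interp = eps * real + (1.0 - eps) * fake
    # Per-example gradient of the critic score with respect to its input.
    per_example_grad = jax.vmap(jax.grad(lambda x: jnp.squeeze(disc_fn(x[None]))))
    grads = per_example_grad(interp)
    norms = jnp.sqrt(jnp.sum(grads.reshape(grads.shape[0], -1) ** 2, axis=1) + 1e-12)
    # Penalise deviation of the gradient norm from 1.
    return weight * jnp.mean((norms - 1.0) ** 2)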
Diffusion Trainer¤
Handles noise scheduling and denoising score matching.
from artifex.generative_models.training import DiffusionTrainer
trainer = DiffusionTrainer(
model=diffusion_model,
config=training_config,
train_dataset=train_data,
ema_decay=0.9999, # Exponential moving average
)
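The ema_decay argument refers to an exponential moving average of the model parameters, commonly used to stabilize diffusion sampling. The update rule itself is simple; a sketch of the idea (not the trainer's internal bookkeeping):
import jax

def ema_update(ema_params, params, decay=0.9999):
    # ema <- decay * ema + (1 - decay) * current parameters, applied leaf-by-leaf.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )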
Flow Trainer¤
Manages exact likelihood training for normalizing flows.
from artifex.generative_models.training import FlowTrainer
trainer = FlowTrainer(
model=flow_model,
config=training_config,
train_dataset=train_data,
)
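Exact likelihood training uses the change-of-variables formula, log p(x) = log p(z) + log|det dz/dx|. A toy elementwise affine flow with a standard normal base density shows the shape of the objective (illustrative only):
import jax.numpy as jnp
from jax.scipy.stats import norm

def affine_flow_log_prob(x, scale, shift):
    # scale and shift have the same shape as x; z = scale * x + shift.
    z = scale * x + shift
    # For an elementwise map, the log-determinant is the sum of log|scale|.
    log_det = jnp.sum(jnp.log(jnp.abs(scale)))
    return jnp.sum(norm.logpdf(z)) + log_det
# The training loss is the negative log-likelihood averaged over a batch.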
Energy Trainer¤
Handles contrastive divergence and MCMC sampling.
from artifex.generative_models.training import EnergyTrainer
trainer = EnergyTrainer(
model=ebm_model,
config=training_config,
train_dataset=train_data,
mcmc_steps=10,
)
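Contrastive divergence compares the energy of data samples against negative samples drawn by MCMC, typically Langevin dynamics. A compact sketch assuming an energy function energy_fn that returns per-example energies (not the trainer's actual sampler):
import jax
import jax.numpy as jnp

def langevin_step(energy_fn, x, key, step_size=0.01):
    # One Langevin step: move downhill in energy, plus Gaussian noise.
    grad_e = jax.grad(lambda v: jnp.sum(energy_fn(v)))(x)
    noise = jax.random.normal(key, x.shape)
    return x - 0.5 * step_size * grad_e + jnp.sqrt(step_size) * noise

def cd_loss(energy_fn, x_data, x_neg):
    # Lower the energy of data, raise the energy of negative samples.
    return jnp.mean(energy_fn(x_data)) - jnp.mean(energy_fn(x_neg))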
Autoregressive Trainer¤
Manages sequential likelihood training.
from artifex.generative_models.training import AutoregressiveTrainer
trainer = AutoregressiveTrainer(
model=ar_model,
config=training_config,
train_dataset=train_data,
)
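Sequential likelihood training maximizes the sum of log p(x_t | x_<t), usually with teacher forcing. A minimal token-level negative log-likelihood, assuming the model has already produced per-position logits (a conceptual sketch, not the trainer's loss function):
import jax.numpy as jnp
from jax.nn import log_softmax

def sequence_nll(logits, targets):
    # logits: (seq_len, vocab_size); targets: (seq_len,) integer token ids.
    log_probs = log_softmax(logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -jnp.mean(token_log_probs)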
Callbacks¤
Callbacks allow customization of the training loop.
Built-in Callbacks¤
| Callback | Description |
|---|---|
| CheckpointCallback | Save model checkpoints |
| EarlyStoppingCallback | Stop training when validation plateaus |
| LoggingCallback | Log metrics to console/file |
| ProfilingCallback | Profile training performance |
| VisualizationCallback | Generate sample visualizations |
Using Callbacks¤
from artifex.generative_models.training.callbacks import (
CheckpointCallback,
EarlyStoppingCallback,
LoggingCallback,
)
callbacks = [
CheckpointCallback(
save_dir="checkpoints/",
save_every_n_epochs=10,
save_best=True,
metric="val_loss",
),
EarlyStoppingCallback(
patience=20,
metric="val_loss",
mode="min",
),
LoggingCallback(
log_every_n_steps=100,
use_wandb=True,
),
]
trainer = VAETrainer(
model=model,
config=config,
train_dataset=train_data,
callbacks=callbacks,
)
Custom Callbacks¤
from artifex.generative_models.training.callbacks import BaseCallback
class CustomCallback(BaseCallback):
def on_epoch_start(self, trainer, epoch):
print(f"Starting epoch {epoch}")
def on_epoch_end(self, trainer, epoch, metrics):
print(f"Epoch {epoch} completed: {metrics}")
def on_train_batch_end(self, trainer, batch_idx, loss):
if batch_idx % 100 == 0:
print(f"Batch {batch_idx}: loss={loss:.4f}")
Distributed Training¤
Data Parallel¤
from artifex.generative_models.training.distributed import DataParallelTrainer
trainer = DataParallelTrainer(
model=model,
config=config,
train_dataset=train_data,
num_devices=4, # Use 4 GPUs
)
Model Parallel¤
from artifex.generative_models.training.distributed import ModelParallelTrainer
trainer = ModelParallelTrainer(
model=large_model,
config=config,
train_dataset=train_data,
mesh_shape=(2, 4), # 2x4 device mesh
)
Device Mesh¤
from artifex.generative_models.training.distributed import DeviceMesh
mesh = DeviceMesh(
shape=(2, 2), # 2x2 mesh
axis_names=("data", "model"),
)
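Given the library's JAX conventions, DeviceMesh presumably corresponds to a JAX device mesh with named axes; the equivalent construction in raw JAX, shown only for orientation (an assumption about the wrapper, not its documented behaviour):
import numpy as np
import jax
from jax.sharding import Mesh

# Requires at least four visible devices; shape (2, 2) mirrors the example above.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("data", "model"))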
Optimizers¤
Available Optimizers¤
| Optimizer | Description | Best For |
|---|---|---|
| AdamW | Adam with weight decay | General use |
| Lion | Memory-efficient optimizer | Large models |
| Adafactor | Low memory optimizer | Very large models |
Using Optimizers¤
from artifex.generative_models.training.optimizers import create_optimizer
optimizer = create_optimizer(
optimizer_type="adamw",
learning_rate=1e-3,
weight_decay=0.01,
beta1=0.9,
beta2=0.999,
)
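The module reference lists optax_wrappers, so create_optimizer presumably returns an Optax optimizer under the hood. For comparison, the direct Optax equivalent of the call above (an assumption about the wrapper, not its documented behaviour):
import optax

# Roughly equivalent AdamW optimizer built directly with Optax.
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=0.01, b1=0.9, b2=0.999)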
Learning Rate Schedulers¤
Available Schedulers¤
| Scheduler | Description |
|---|---|
| Cosine | Cosine annealing with warmup |
| Linear | Linear warmup and decay |
| Exponential | Exponential decay |
Using Schedulers¤
from artifex.generative_models.training.schedulers import create_scheduler
scheduler = create_scheduler(
scheduler_type="cosine",
warmup_steps=1000,
total_steps=100000,
min_lr=1e-6,
)
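For intuition, a cosine schedule with linear warmup can be written as below; the library's implementation may differ in details such as how the minimum learning rate is applied:
import math

def cosine_lr(step, base_lr=1e-3, warmup_steps=1000, total_steps=100_000, min_lr=1e-6):
    # Linear warmup to base_lr, then cosine decay down to min_lr.
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))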
RL Training¤
Reinforcement learning trainers for fine-tuning generative models.
| Trainer | Description |
|---|---|
| REINFORCE | Policy gradient training |
| PPO | Proximal Policy Optimization |
| DPO | Direct Preference Optimization |
| GRPO | Group Relative Policy Optimization |
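As a reminder of what the simplest of these does, a REINFORCE-style update scales the log-probability gradient of each sampled action by its reward. A toy surrogate loss, conceptual only and unrelated to the trainers' actual APIs:
import jax
import jax.numpy as jnp

def reinforce_loss(log_prob_fn, params, actions, rewards):
    # Policy-gradient surrogate: -E[reward * log pi(action)].
    log_probs = log_prob_fn(params, actions)
    return -jnp.mean(jax.lax.stop_gradient(rewards) * log_probs)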
Advanced Features¤
Gradient Accumulation¤
trainer = VAETrainer(
model=model,
config=config,
train_dataset=train_data,
gradient_accumulation_steps=4,
)
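Conceptually, gradient accumulation averages gradients over several micro-batches before applying a single optimizer update, so the effective batch size is batch_size * gradient_accumulation_steps. A sketch of the idea (not the trainer's implementation):
import jax

def accumulate_grads(grad_fn, params, micro_batches):
    # Average per-micro-batch gradients so the update matches a larger batch.
    grads = [grad_fn(params, batch) for batch in micro_batches]
    n = len(grads)
    return jax.tree_util.tree_map(lambda *g: sum(g) / n, *grads)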
Mixed Precision Training¤
trainer = VAETrainer(
model=model,
config=config,
train_dataset=train_data,
mixed_precision=True, # Use bfloat16
)
Experiment Tracking¤
trainer = VAETrainer(
model=model,
config=config,
train_dataset=train_data,
tracking={
"wandb": {"project": "my-project"},
"mlflow": {"experiment": "vae-experiments"},
},
)
Module Reference¤
| Category | Modules |
|---|---|
| Trainers | vae_trainer, gan_trainer, diffusion_trainer, flow_trainer, energy_trainer, autoregressive_trainer |
| Callbacks | base, checkpoint, early_stopping, logging, profiling, visualization |
| Distributed | data_parallel, model_parallel, mesh, device_placement, distributed_metrics |
| Optimizers | adamw, lion, adafactor, optax_wrappers |
| Schedulers | cosine, linear, exponential, factory, scheduler |
| RL | reinforce, ppo, dpo, grpo |
| Utilities | base, gradient_accumulation, mixed_precision, tracking, trainer, utils |
Related Documentation¤
- Training Guide - Complete training guide
- Configuration System - Training configuration
- Distributed Training - Multi-GPU/TPU guide