Training Systems¤

Comprehensive training infrastructure for generative models, including model-specific trainers, distributed training, callbacks, and optimization utilities.

Overview¤

  • Model Trainers: Specialized trainers for VAE, GAN, Diffusion, Flow, EBM, and Autoregressive models
  • Distributed Training: Data parallel and model parallel training across multiple GPUs/TPUs
  • Callbacks: Checkpointing, early stopping, logging, and custom callbacks
  • Optimizers: AdamW, Lion, and Adafactor, with learning rate schedulers

Quick Start¤

Basic Training¤

from artifex.generative_models.training import VAETrainer
from artifex.generative_models.core.configuration import TrainingConfig

# Create training configuration
training_config = TrainingConfig(
    batch_size=128,
    num_epochs=100,
    optimizer={"type": "adam", "learning_rate": 1e-3},
    scheduler={"type": "cosine", "warmup_steps": 1000},
)

# Create trainer
trainer = VAETrainer(
    model=model,
    config=training_config,
    train_dataset=train_data,
    val_dataset=val_data,
)

# Train
trainer.train()

Model-Specific Trainers¤

Each model family has a specialized trainer that handles its unique training requirements.

VAE Trainer¤

Handles ELBO loss, KL annealing, and reconstruction metrics.

from artifex.generative_models.training import VAETrainer

trainer = VAETrainer(
    model=vae_model,
    config=training_config,
    train_dataset=train_data,
    kl_annealing=True,
    kl_warmup_epochs=10,
)
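
KL annealing down-weights the KL term of the ELBO early in training and ramps it back up over kl_warmup_epochs. A minimal sketch of a linear annealing schedule (illustrative only, not necessarily the trainer's exact schedule):

def kl_weight(epoch, warmup_epochs=10):
    # Linear KL annealing: the KL term is down-weighted at the start so the
    # decoder first learns to reconstruct, then the full ELBO is optimized.
    return min(1.0, epoch / warmup_epochs)

# loss = reconstruction_loss + kl_weight(epoch) * kl_divergence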

VAE Trainer Reference

GAN Trainer¤

Manages generator/discriminator alternating updates.

from artifex.generative_models.training import GANTrainer

trainer = GANTrainer(
    model=gan_model,
    config=training_config,
    train_dataset=train_data,
    d_steps=5,  # Discriminator steps per generator step
    gp_weight=10.0,  # Gradient penalty weight
)
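
The gp_weight term scales a gradient penalty of the kind used in WGAN-GP, computed on points interpolated between real and generated samples. A minimal JAX sketch, assuming the discriminator maps a single example to a scalar score (illustrative, not the library's exact implementation):

import jax
import jax.numpy as jnp

def gradient_penalty(discriminator, real, fake, key):
    # Interpolate between real and fake samples, then push the norm of the
    # discriminator's gradient at those points towards 1.
    eps = jax.random.uniform(key, (real.shape[0],) + (1,) * (real.ndim - 1))
    interp = eps * real + (1.0 - eps) * fake
    grads = jax.vmap(jax.grad(discriminator))(interp)
    norms = jnp.sqrt(jnp.sum(grads ** 2, axis=tuple(range(1, grads.ndim))) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)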

GAN Trainer Reference

Diffusion Trainer¤

Handles noise scheduling and denoising score matching.

from artifex.generative_models.training import DiffusionTrainer

trainer = DiffusionTrainer(
    model=diffusion_model,
    config=training_config,
    train_dataset=train_data,
    ema_decay=0.9999,  # Exponential moving average
)
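
The ema_decay argument maintains an exponential moving average of the model parameters, which is commonly used for sampling and evaluation in diffusion models. A minimal sketch of the update rule (illustrative):

import jax

def ema_update(ema_params, params, decay=0.9999):
    # ema <- decay * ema + (1 - decay) * params, applied leaf-wise over the
    # parameter pytree after each optimizer step.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )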

Diffusion Trainer Reference

Flow Trainer¤

Manages exact likelihood training for normalizing flows.

from artifex.generative_models.training import FlowTrainer

trainer = FlowTrainer(
    model=flow_model,
    config=training_config,
    train_dataset=train_data,
)
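
Flows are trained on the exact negative log-likelihood given by the change-of-variables formula. An illustrative objective with a standard-normal base distribution (not necessarily the library's exact loss):

import jax.numpy as jnp

def flow_nll(z, log_det_jacobian):
    # log p(x) = log p(z) + log|det dz/dx|, where z = f(x) is the flow's output.
    log_pz = -0.5 * jnp.sum(z ** 2 + jnp.log(2.0 * jnp.pi), axis=-1)
    return -jnp.mean(log_pz + log_det_jacobian)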

Flow Trainer Reference

Energy Trainer¤

Handles contrastive divergence and MCMC sampling.

from artifex.generative_models.training import EnergyTrainer

trainer = EnergyTrainer(
    model=ebm_model,
    config=training_config,
    train_dataset=train_data,
    mcmc_steps=10,
)
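
Contrastive divergence requires negative samples drawn from the model, which is what mcmc_steps controls. A minimal Langevin-dynamics sketch of that sampling step, assuming energy_fn returns one energy value per example (illustrative):

import jax
import jax.numpy as jnp

def langevin_sample(energy_fn, x, key, steps=10, step_size=0.01):
    # Follow the negative energy gradient with added Gaussian noise to draw
    # approximate model samples for the contrastive-divergence loss.
    for _ in range(steps):
        key, subkey = jax.random.split(key)
        grads = jax.grad(lambda v: jnp.sum(energy_fn(v)))(x)
        noise = jax.random.normal(subkey, x.shape)
        x = x - step_size * grads + jnp.sqrt(2.0 * step_size) * noise
    return x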

Energy Trainer Reference

Autoregressive Trainer¤

Manages sequential likelihood training.

from artifex.generative_models.training import AutoregressiveTrainer

trainer = AutoregressiveTrainer(
    model=ar_model,
    config=training_config,
    train_dataset=train_data,
)
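
Sequential likelihood training is the standard teacher-forced, next-token objective: each position is predicted from the tokens before it. An illustrative loss (not necessarily the trainer's exact implementation):

import jax
import jax.numpy as jnp

def next_token_nll(logits, targets):
    # logits: (batch, seq_len, vocab_size); targets: (batch, seq_len) token ids.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -jnp.mean(token_ll)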

Autoregressive Trainer Reference

Callbacks¤

Callbacks hook into the training loop at epoch and batch boundaries, letting you customize behavior without modifying the trainer itself.

Built-in Callbacks¤

| Callback | Description |
|---|---|
| CheckpointCallback | Save model checkpoints |
| EarlyStoppingCallback | Stop training when validation plateaus |
| LoggingCallback | Log metrics to console/file |
| ProfilingCallback | Profile training performance |
| VisualizationCallback | Generate sample visualizations |

Using Callbacks¤

from artifex.generative_models.training.callbacks import (
    CheckpointCallback,
    EarlyStoppingCallback,
    LoggingCallback,
)

callbacks = [
    CheckpointCallback(
        save_dir="checkpoints/",
        save_every_n_epochs=10,
        save_best=True,
        metric="val_loss",
    ),
    EarlyStoppingCallback(
        patience=20,
        metric="val_loss",
        mode="min",
    ),
    LoggingCallback(
        log_every_n_steps=100,
        use_wandb=True,
    ),
]

trainer = VAETrainer(
    model=model,
    config=config,
    train_dataset=train_data,
    callbacks=callbacks,
)

Custom Callbacks¤

from artifex.generative_models.training.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    def on_epoch_start(self, trainer, epoch):
        print(f"Starting epoch {epoch}")

    def on_epoch_end(self, trainer, epoch, metrics):
        print(f"Epoch {epoch} completed: {metrics}")

    def on_train_batch_end(self, trainer, batch_idx, loss):
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx}: loss={loss:.4f}")

Distributed Training¤

Data Parallel¤

from artifex.generative_models.training.distributed import DataParallelTrainer

trainer = DataParallelTrainer(
    model=model,
    config=config,
    train_dataset=train_data,
    num_devices=4,  # Use 4 GPUs
)

Data Parallel Reference

Model Parallel¤

from artifex.generative_models.training.distributed import ModelParallelTrainer

trainer = ModelParallelTrainer(
    model=large_model,
    config=config,
    train_dataset=train_data,
    mesh_shape=(2, 4),  # 2x4 device mesh
)

Model Parallel Reference

Device Mesh¤

from artifex.generative_models.training.distributed import DeviceMesh

mesh = DeviceMesh(
    shape=(2, 2),  # 2x2 mesh
    axis_names=("data", "model"),
)
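
For orientation, this mirrors the device mesh abstraction in JAX itself; a minimal sketch of the equivalent construction in plain JAX, assuming a JAX backend:

from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Arrange the available accelerators into a 2x2 grid; the "data" axis shards
# batches while the "model" axis shards parameters.
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=("data", "model"))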

Device Mesh Reference

Optimizers¤

Available Optimizers¤

| Optimizer | Description | Best For |
|---|---|---|
| AdamW | Adam with weight decay | General use |
| Lion | Memory-efficient optimizer | Large models |
| Adafactor | Low-memory optimizer | Very large models |

Using Optimizers¤

from artifex.generative_models.training.optimizers import create_optimizer

optimizer = create_optimizer(
    optimizer_type="adamw",
    learning_rate=1e-3,
    weight_decay=0.01,
    beta1=0.9,
    beta2=0.999,
)

Learning Rate Schedulers¤

Available Schedulers¤

| Scheduler | Description |
|---|---|
| Cosine | Cosine annealing with warmup |
| Linear | Linear warmup and decay |
| Exponential | Exponential decay |

Using Schedulers¤

from artifex.generative_models.training.schedulers import create_scheduler

scheduler = create_scheduler(
    scheduler_type="cosine",
    warmup_steps=1000,
    total_steps=100000,
    min_lr=1e-6,
)
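
For reference, the learning-rate shape these parameters describe is a linear warmup followed by cosine decay. A minimal sketch (the library's exact parameterization may differ):

import math

def cosine_with_warmup(step, base_lr=1e-3, warmup_steps=1000,
                       total_steps=100_000, min_lr=1e-6):
    # Linear warmup to base_lr, then cosine decay from base_lr down to min_lr.
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))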

Scheduler Reference

RL Training¤

Reinforcement learning trainers for fine-tuning generative models.

| Trainer | Description |
|---|---|
| REINFORCE | Policy gradient training |
| PPO | Proximal Policy Optimization |
| DPO | Direct Preference Optimization |
| GRPO | Group Relative Policy Optimization |

RL Training Guide
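
For orientation, the DPO trainer builds on an objective that compares the policy's log-probabilities of a preferred and a rejected response against a frozen reference model. A minimal sketch of that loss (not the library's implementation):

import jax
import jax.numpy as jnp

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # Widen the margin between chosen and rejected responses, measured relative
    # to the reference model; beta controls the implicit KL regularization.
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    return -jnp.mean(jax.nn.log_sigmoid(beta * (chosen_margin - rejected_margin)))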

Advanced Features¤

Gradient Accumulation¤

trainer = VAETrainer(
    model=model,
    config=config,
    train_dataset=train_data,
    gradient_accumulation_steps=4,  # Accumulate 4 micro-batches per optimizer step (effective batch = 4 * batch_size)
)

Gradient Accumulation

Mixed Precision Training¤

trainer = VAETrainer(
    model=model,
    config=config,
    train_dataset=train_data,
    mixed_precision=True,  # Use bfloat16
)

Mixed Precision

Experiment Tracking¤

trainer = VAETrainer(
    model=model,
    config=config,
    train_dataset=train_data,
    tracking={
        "wandb": {"project": "my-project"},
        "mlflow": {"experiment": "vae-experiments"},
    },
)

Experiment Tracking

Module Reference¤

| Category | Modules |
|---|---|
| Trainers | vae_trainer, gan_trainer, diffusion_trainer, flow_trainer, energy_trainer, autoregressive_trainer |
| Callbacks | base, checkpoint, early_stopping, logging, profiling, visualization |
| Distributed | data_parallel, model_parallel, mesh, device_placement, distributed_metrics |
| Optimizers | adamw, lion, adafactor, optax_wrappers |
| Schedulers | cosine, linear, exponential, factory, scheduler |
| RL | reinforce, ppo, dpo, grpo |
| Utilities | base, gradient_accumulation, mixed_precision, tracking, trainer, utils |