Diffusion Trainer¤
Module: artifex.generative_models.training.trainers.diffusion_trainer
The Diffusion Trainer provides state-of-the-art training utilities for diffusion models, including multiple prediction types, advanced timestep sampling strategies, loss weighting schemes, and EMA model updates.
Overview¤
Modern diffusion model training benefits from several advanced techniques. The Diffusion Trainer provides:
- Prediction Types: Epsilon, v-prediction, and x-prediction
- Timestep Sampling: Uniform, logit-normal, and mode-seeking strategies
- Loss Weighting: Uniform, SNR, min-SNR, and EDM-style weighting
- EMA Updates: Exponential moving average for stable inference
Quick Start¤
```python
from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)
from artifex.generative_models.core.noise_schedules import CosineNoiseSchedule
from flax import nnx
import optax
import jax

# Create model and optimizer
model = create_diffusion_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)
noise_schedule = CosineNoiseSchedule(num_timesteps=1000)

# Configure diffusion training with SOTA techniques
config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    timestep_sampling="logit_normal",
    loss_weighting="min_snr",
    snr_gamma=5.0,
)
trainer = DiffusionTrainer(noise_schedule, config)

# Training loop
key = jax.random.key(0)
for step, batch in enumerate(train_loader):
    key, subkey = jax.random.split(key)
    loss, metrics = trainer.train_step(model, optimizer, batch, subkey)
    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}")
```
Configuration¤
artifex.generative_models.training.trainers.diffusion_trainer.DiffusionTrainingConfig
dataclass
¤
```python
DiffusionTrainingConfig(
    prediction_type: Literal[
        "epsilon", "v_prediction", "x_start"
    ] = "epsilon",
    timestep_sampling: Literal[
        "uniform", "logit_normal", "mode"
    ] = "uniform",
    loss_weighting: Literal[
        "uniform", "snr", "min_snr", "edm"
    ] = "uniform",
    snr_gamma: float = 5.0,
    logit_normal_loc: float = -0.5,
    logit_normal_scale: float = 1.0,
    ema_decay: float = 0.9999,
    ema_update_every: int = 10,
)
```
Configuration for diffusion model training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `prediction_type` | `Literal['epsilon', 'v_prediction', 'x_start']` | What the model predicts. `"epsilon"`: the added noise. `"v_prediction"`: `v = sqrt(alpha)*noise - sqrt(1-alpha)*x0`. `"x_start"`: the original clean data. |
| `timestep_sampling` | `Literal['uniform', 'logit_normal', 'mode']` | How to sample timesteps during training. `"uniform"`: uniform random sampling. `"logit_normal"`: logit-normal distribution (favors middle timesteps). `"mode"`: mode-seeking (favors high-noise timesteps). |
| `loss_weighting` | `Literal['uniform', 'snr', 'min_snr', 'edm']` | How to weight the loss across timesteps. `"uniform"`: equal weighting. `"snr"`: weight by signal-to-noise ratio. `"min_snr"`: min-SNR-gamma weighting (3.4x faster convergence). `"edm"`: EDM-style weighting. |
| `snr_gamma` | `float` | Gamma parameter for min-SNR weighting (5.0 typical). |
| `logit_normal_loc` | `float` | Location parameter for logit-normal sampling. |
| `logit_normal_scale` | `float` | Scale parameter for logit-normal sampling. |
| `ema_decay` | `float` | EMA decay rate for model weights. |
| `ema_update_every` | `int` | Update EMA every N training steps. |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `prediction_type` | `str` | `"epsilon"` | Model prediction: `"epsilon"`, `"v_prediction"`, `"x_start"` |
| `timestep_sampling` | `str` | `"uniform"` | Timestep distribution: `"uniform"`, `"logit_normal"`, `"mode"` |
| `loss_weighting` | `str` | `"uniform"` | Loss weighting: `"uniform"`, `"snr"`, `"min_snr"`, `"edm"` |
| `snr_gamma` | `float` | `5.0` | Gamma for min-SNR weighting |
| `logit_normal_loc` | `float` | `-0.5` | Logit-normal location parameter |
| `logit_normal_scale` | `float` | `1.0` | Logit-normal scale parameter |
| `ema_decay` | `float` | `0.9999` | EMA decay rate |
| `ema_update_every` | `int` | `10` | Steps between EMA updates |
Prediction Types¤
Epsilon Prediction (DDPM)¤
The classic DDPM approach: the model predicts the noise that was added.
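Epsilon prediction is the config default; selecting it explicitly:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(prediction_type="epsilon")
# Classic DDPM objective: MSE between the model output and the added noise
```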
V-Prediction¤
Model predicts v = sqrt(alpha)*noise - sqrt(1-alpha)*x_0:
```python
config = DiffusionTrainingConfig(prediction_type="v_prediction")
# Provides more consistent gradients across timesteps
# Used in Stable Diffusion 3 and Imagen Video
```
V-prediction often leads to faster convergence and better sample quality.
X-Start Prediction¤
Model directly predicts the clean data:
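Selected via the documented `"x_start"` option:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(prediction_type="x_start")
# The model outputs an estimate of the clean data x_0 directly
```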
Timestep Sampling Strategies¤
Uniform Sampling¤
Standard uniform sampling over all timesteps:
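This is the config default; stated explicitly:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(timestep_sampling="uniform")
# t ~ Uniform{0, ..., num_timesteps - 1}
```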
Logit-Normal Sampling¤
Favors middle timesteps where learning is most effective:
```python
config = DiffusionTrainingConfig(
    timestep_sampling="logit_normal",
    logit_normal_loc=-0.5,
    logit_normal_scale=1.0,
)
# Used in Stable Diffusion 3 for improved convergence
```
Mode-Seeking Sampling¤
Favors high-noise timesteps for improved generation quality:
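Enabled via the documented `"mode"` option:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(timestep_sampling="mode")
# Biases sampling toward high-noise timesteps
```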
Loss Weighting Schemes¤
Uniform Weighting¤
No weighting - all timesteps contribute equally:
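This is the config default; stated explicitly:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(loss_weighting="uniform")
# Every timestep contributes equally to the loss
```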
SNR Weighting¤
Weight by signal-to-noise ratio:
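Enabled via the `"snr"` option:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(loss_weighting="snr")
# Each timestep's loss is scaled by its signal-to-noise ratio
```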
Min-SNR-Gamma Weighting¤
Clips high SNR weights for 3.4x faster convergence:
```python
config = DiffusionTrainingConfig(
    loss_weighting="min_snr",
    snr_gamma=5.0,
)
# weight = min(SNR, gamma) / SNR
# Down-weights low-noise timesteps where SNR is high
```
Min-SNR-gamma is the recommended weighting scheme for most use cases.
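The clipping rule is easy to check in isolation. A standalone sketch in plain Python (not the trainer's internal code), assuming the usual definition SNR_t = alpha_bar_t / (1 - alpha_bar_t):

```python
# Min-SNR-gamma: weight(t) = min(SNR_t, gamma) / SNR_t
def min_snr_weight(alpha_bar: float, gamma: float = 5.0) -> float:
    snr = alpha_bar / (1.0 - alpha_bar)
    return min(snr, gamma) / snr

# Low-noise step (alpha_bar near 1, SNR >> gamma): strongly down-weighted
print(min_snr_weight(0.9999))  # ~0.0005
# High-noise step (SNR <= gamma): weight stays at 1.0
print(min_snr_weight(0.5))
```

Down-weighting the nearly-clean timesteps is what keeps them from dominating the gradient signal.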
EDM Weighting¤
EDM-style weighting based on sigma:
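Enabled via the `"edm"` option:

```python
from artifex.generative_models.training.trainers import DiffusionTrainingConfig

config = DiffusionTrainingConfig(loss_weighting="edm")
# Sigma-based weighting in the EDM style
```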
EMA (Exponential Moving Average)¤
Maintain an EMA of model parameters for stable inference:
```python
config = DiffusionTrainingConfig(
    ema_decay=0.9999,
    ema_update_every=10,
)

# After training, get EMA parameters
ema_params = trainer.get_ema_params()

# Apply EMA params to model for inference
from flax import nnx

nnx.update(model, ema_params)
```
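The rule behind `ema_decay` is the standard exponential moving average, theta_ema ← decay * theta_ema + (1 - decay) * theta, applied every `ema_update_every` steps. A standalone sketch of that update (plain Python on a scalar, not the trainer's internal code):

```python
def ema_update(ema: float, param: float, decay: float = 0.9999) -> float:
    # theta_ema <- decay * theta_ema + (1 - decay) * theta
    return decay * ema + (1.0 - decay) * param

# With decay=0.9999 the EMA tracks the parameters very slowly:
ema = 0.0
for _ in range(10_000):
    ema = ema_update(ema, 1.0)
print(ema)  # ~0.63 (= 1 - 0.9999**10000): still far from the raw value
```

The high decay is why EMA weights change smoothly across training and typically sample better than the raw weights.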
API Reference¤
artifex.generative_models.training.trainers.diffusion_trainer.DiffusionTrainer
¤
```python
DiffusionTrainer(
    noise_schedule: NoiseScheduleProtocol,
    config: DiffusionTrainingConfig | None = None,
)
```
Diffusion model trainer with modern training techniques.
This trainer provides a JIT-compatible interface for training diffusion models with state-of-the-art techniques. The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.
Features
- Multiple prediction types (epsilon, v, x0)
- Non-uniform timestep sampling (logit-normal, mode-seeking)
- Loss weighting (SNR, min-SNR, EDM)
- EMA model updates (call update_ema separately, outside JIT)
Example (non-JIT):
```python
from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)

# Create trainer with noise schedule and config
config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    timestep_sampling="logit_normal",
    loss_weighting="min_snr",
)
trainer = DiffusionTrainer(noise_schedule, config)

# Create model and optimizer separately
model = DDPMModel(model_config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=nnx.Param)

# Training loop
for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = trainer.train_step(model, optimizer, batch, step_rng)
    trainer.update_ema(model)  # EMA updates outside train_step
```
Example (JIT-compiled):
```python
trainer = DiffusionTrainer(noise_schedule, config)
jit_step = nnx.jit(trainer.train_step)

for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = jit_step(model, optimizer, batch, step_rng)
    trainer.update_ema(model)  # Outside JIT
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `noise_schedule` | `NoiseScheduleProtocol` | Noise schedule with `alphas_cumprod` and `add_noise`. | *required* |
| `config` | `DiffusionTrainingConfig \| None` | Diffusion training configuration. | `None` |
sample_timesteps
¤
get_loss_weight
¤
compute_target
¤
compute_loss
¤
Compute diffusion training loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Diffusion model to evaluate. | *required* |
| `batch` | `dict[str, Any]` | Batch dictionary with `"image"` or `"data"` key. | *required* |
| `key` | `Array` | PRNG key for sampling noise and timesteps. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of `(total_loss, metrics_dict)`. |
train_step
¤
```python
train_step(
    model: Module,
    optimizer: Optimizer,
    batch: dict[str, Any],
    key: Array,
) -> tuple[Array, dict[str, Any]]
```
Execute a single training step.
This method can be wrapped with `nnx.jit` for performance:

```python
jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, key)
```
Note: Call update_ema() separately after train_step for EMA updates.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Diffusion model to train. | *required* |
| `optimizer` | `Optimizer` | NNX optimizer for parameter updates. | *required* |
| `batch` | `dict[str, Any]` | Batch dictionary with `"image"` or `"data"` key. | *required* |
| `key` | `Array` | PRNG key for sampling. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of `(loss, metrics_dict)`. |
update_ema
¤
```python
update_ema(model: Module) -> None
```
Update EMA parameters.
Call this method separately after train_step, outside of JIT.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model whose parameters to use for the EMA update. | *required* |
get_ema_params
¤
create_loss_fn
¤
Create loss function compatible with base Trainer.
This enables integration with the base Trainer for callbacks, checkpointing, logging, and other training infrastructure.
Returns:

| Type | Description |
|---|---|
| `Callable[[Module, dict[str, Any], Array], tuple[Array, dict[str, Any]]]` | Function with signature `(model, batch, rng) -> (loss, metrics)`. |
Noise Schedule Protocol¤
The trainer works with any noise schedule implementing the NoiseScheduleProtocol:
```python
from typing import Protocol

import jax


class NoiseScheduleProtocol(Protocol):
    """Protocol for noise schedules used by diffusion trainers."""

    num_timesteps: int
    alphas_cumprod: jax.Array

    def add_noise(
        self,
        x: jax.Array,
        noise: jax.Array,
        t: jax.Array,
    ) -> jax.Array:
        """Add noise to data at given timesteps."""
        ...
```
Artifex provides several noise schedule implementations:
```python
from artifex.generative_models.core.noise_schedules import (
    LinearNoiseSchedule,
    CosineNoiseSchedule,
    SquaredCosineNoiseSchedule,
)

# Linear schedule (DDPM default)
schedule = LinearNoiseSchedule(num_timesteps=1000)

# Cosine schedule (improved for images)
schedule = CosineNoiseSchedule(num_timesteps=1000)
```

Artifex provides several noise schedule implementations, shown above.
Integration with Base Trainer¤
Use create_loss_fn() for integration with callbacks and checkpointing:
```python
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)
from artifex.generative_models.training.callbacks import (
    ModelCheckpoint,
    CheckpointConfig,
)

# Create diffusion trainer
diff_config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    loss_weighting="min_snr",
)
diff_trainer = DiffusionTrainer(noise_schedule, diff_config)

# Get loss function for base Trainer
loss_fn = diff_trainer.create_loss_fn()

# Use with base Trainer for callbacks
callbacks = [
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="loss")),
]
```
Model Requirements¤
The Diffusion Trainer expects models with the following interface:
```python
class DiffusionModel(nnx.Module):
    def __call__(
        self,
        x_noisy: jax.Array,
        t: jax.Array,
    ) -> jax.Array:
        """Predict noise/v/x_0 from noisy input.

        Args:
            x_noisy: Noisy data, shape (batch, ...).
            t: Integer timesteps, shape (batch,).

        Returns:
            Prediction matching prediction_type, shape (batch, ...).
        """
        ...
```
Training Metrics¤
| Metric | Description |
|---|---|
| `loss` | Weighted loss (affected by `loss_weighting`) |
| `loss_unweighted` | Raw MSE loss without weighting |
Recommended Configurations¤
High-Quality Image Generation¤
```python
config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    timestep_sampling="logit_normal",
    loss_weighting="min_snr",
    snr_gamma=5.0,
    ema_decay=0.9999,
)
```
Fast Training¤
```python
config = DiffusionTrainingConfig(
    prediction_type="epsilon",
    timestep_sampling="uniform",
    loss_weighting="min_snr",
    snr_gamma=5.0,
)
```
Large Models (EDM-style)¤
```python
config = DiffusionTrainingConfig(
    prediction_type="epsilon",
    timestep_sampling="logit_normal",
    loss_weighting="edm",
    ema_decay=0.9999,
    ema_update_every=1,
)
```