Diffusion Trainer¤

Module: artifex.generative_models.training.trainers.diffusion_trainer

The Diffusion Trainer provides state-of-the-art training utilities for diffusion models, including multiple prediction types, advanced timestep sampling strategies, loss weighting schemes, and EMA model updates.

Overview¤

Modern diffusion model training benefits from several advanced techniques. The Diffusion Trainer provides:

  • Prediction Types: Epsilon, v-prediction, and x-prediction
  • Timestep Sampling: Uniform, logit-normal, and mode-seeking strategies
  • Loss Weighting: Uniform, SNR, min-SNR, and EDM-style weighting
  • EMA Updates: Exponential moving average for stable inference

Quick Start¤

from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)
from artifex.generative_models.core.noise_schedules import CosineNoiseSchedule
from flax import nnx
import optax
import jax

# Create model and optimizer
model = create_diffusion_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)
noise_schedule = CosineNoiseSchedule(num_timesteps=1000)

# Configure diffusion training with SOTA techniques
config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    timestep_sampling="logit_normal",
    loss_weighting="min_snr",
    snr_gamma=5.0,
)

trainer = DiffusionTrainer(noise_schedule, config)

# Training loop
key = jax.random.key(0)

for step, batch in enumerate(train_loader):
    key, subkey = jax.random.split(key)
    loss, metrics = trainer.train_step(model, optimizer, batch, subkey)

    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}")

Configuration¤

artifex.generative_models.training.trainers.diffusion_trainer.DiffusionTrainingConfig dataclass ¤

DiffusionTrainingConfig(
    prediction_type: Literal[
        "epsilon", "v_prediction", "x_start"
    ] = "epsilon",
    timestep_sampling: Literal[
        "uniform", "logit_normal", "mode"
    ] = "uniform",
    loss_weighting: Literal[
        "uniform", "snr", "min_snr", "edm"
    ] = "uniform",
    snr_gamma: float = 5.0,
    logit_normal_loc: float = -0.5,
    logit_normal_scale: float = 1.0,
    ema_decay: float = 0.9999,
    ema_update_every: int = 10,
)

Configuration for diffusion model training.

Attributes:

Name Type Description
prediction_type Literal['epsilon', 'v_prediction', 'x_start']

What the model predicts:
  • "epsilon": Predicts the added noise
  • "v_prediction": Predicts v = sqrt(alpha)*noise - sqrt(1-alpha)*x0
  • "x_start": Predicts the original clean data

timestep_sampling Literal['uniform', 'logit_normal', 'mode']

How to sample timesteps during training:
  • "uniform": Uniform random sampling
  • "logit_normal": Logit-normal distribution (favors middle timesteps)
  • "mode": Mode-seeking (favors high-noise timesteps)

loss_weighting Literal['uniform', 'snr', 'min_snr', 'edm']

How to weight the loss across timesteps:
  • "uniform": Equal weighting
  • "snr": Weight by signal-to-noise ratio
  • "min_snr": Min-SNR-gamma weighting (3.4x faster convergence)
  • "edm": EDM-style weighting

snr_gamma float

Gamma parameter for min-SNR weighting (5.0 typical).

logit_normal_loc float

Location parameter for logit-normal sampling.

logit_normal_scale float

Scale parameter for logit-normal sampling.

ema_decay float

EMA decay rate for model weights.

ema_update_every int

Update EMA every N training steps.

prediction_type class-attribute instance-attribute ¤

prediction_type: Literal[
    "epsilon", "v_prediction", "x_start"
] = "epsilon"

timestep_sampling class-attribute instance-attribute ¤

timestep_sampling: Literal[
    "uniform", "logit_normal", "mode"
] = "uniform"

loss_weighting class-attribute instance-attribute ¤

loss_weighting: Literal[
    "uniform", "snr", "min_snr", "edm"
] = "uniform"

snr_gamma class-attribute instance-attribute ¤

snr_gamma: float = 5.0

logit_normal_loc class-attribute instance-attribute ¤

logit_normal_loc: float = -0.5

logit_normal_scale class-attribute instance-attribute ¤

logit_normal_scale: float = 1.0

ema_decay class-attribute instance-attribute ¤

ema_decay: float = 0.9999

ema_update_every class-attribute instance-attribute ¤

ema_update_every: int = 10

Configuration Options¤

Parameter Type Default Description
prediction_type str "epsilon" Model prediction: "epsilon", "v_prediction", "x_start"
timestep_sampling str "uniform" Timestep distribution: "uniform", "logit_normal", "mode"
loss_weighting str "uniform" Loss weighting: "uniform", "snr", "min_snr", "edm"
snr_gamma float 5.0 Gamma for min-SNR weighting
logit_normal_loc float -0.5 Logit-normal location parameter
logit_normal_scale float 1.0 Logit-normal scale parameter
ema_decay float 0.9999 EMA decay rate
ema_update_every int 10 Steps between EMA updates

Prediction Types¤

Epsilon Prediction (DDPM)¤

The classic DDPM formulation, in which the model predicts the noise that was added:

config = DiffusionTrainingConfig(prediction_type="epsilon")
# Target: noise that was added to x_0

V-Prediction¤

Model predicts v = sqrt(alpha)*noise - sqrt(1-alpha)*x_0:

config = DiffusionTrainingConfig(prediction_type="v_prediction")
# Provides more consistent gradients across timesteps
# Used in Stable Diffusion 3 and Imagen Video

V-prediction often leads to faster convergence and better sample quality.

X-Start Prediction¤

Model directly predicts the clean data:

config = DiffusionTrainingConfig(prediction_type="x_start")
# Target: original clean data x_0
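The three targets above can be sketched as a single function. This is an illustration mirroring the documented behavior of compute_target; the alphas_cumprod values here are made up for the example, and Artifex's internal broadcasting may differ:

```python
import jax
import jax.numpy as jnp

# Illustrative alphas_cumprod; real values come from the noise schedule.
alphas_cumprod = jnp.linspace(0.9999, 0.0001, 1000)

def compute_target(x0, noise, t, prediction_type="epsilon"):
    # Broadcast per-sample alpha_bar over the trailing data dimensions.
    alpha_bar = alphas_cumprod[t].reshape(-1, *([1] * (x0.ndim - 1)))
    if prediction_type == "epsilon":
        return noise
    if prediction_type == "v_prediction":
        # v = sqrt(alpha_bar)*noise - sqrt(1 - alpha_bar)*x0
        return jnp.sqrt(alpha_bar) * noise - jnp.sqrt(1.0 - alpha_bar) * x0
    if prediction_type == "x_start":
        return x0
    raise ValueError(f"unknown prediction_type: {prediction_type}")

x0 = jax.random.normal(jax.random.key(0), (4, 8))
noise = jax.random.normal(jax.random.key(1), (4, 8))
t = jnp.array([10, 200, 500, 900])
v_target = compute_target(x0, noise, t, "v_prediction")
```

Note that the epsilon and x_start targets are timestep-independent, while the v target interpolates between them as alpha_bar decays.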

Timestep Sampling Strategies¤

Uniform Sampling¤

Standard uniform sampling over all timesteps:

config = DiffusionTrainingConfig(timestep_sampling="uniform")
# Equal probability for all timesteps

Logit-Normal Sampling¤

Favors middle timesteps where learning is most effective:

config = DiffusionTrainingConfig(
    timestep_sampling="logit_normal",
    logit_normal_loc=-0.5,
    logit_normal_scale=1.0,
)
# Used in Stable Diffusion 3 for improved convergence

Mode-Seeking Sampling¤

Favors high-noise timesteps for improved generation quality:

config = DiffusionTrainingConfig(timestep_sampling="mode")
# Quadratic bias toward larger timesteps
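The three strategies can be sketched as follows. This is an illustration, not Artifex's implementation: the cube-root transform shown for "mode" is one simple way to get a quadratic bias toward high-noise timesteps, and may not match the library's exact formula:

```python
import jax
import jax.numpy as jnp

NUM_TIMESTEPS = 1000

def sample_timesteps(key, batch_size, strategy="uniform", loc=-0.5, scale=1.0):
    if strategy == "uniform":
        return jax.random.randint(key, (batch_size,), 0, NUM_TIMESTEPS)
    if strategy == "logit_normal":
        # Sigmoid of a Gaussian concentrates mass on middle timesteps.
        u = jax.nn.sigmoid(loc + scale * jax.random.normal(key, (batch_size,)))
    elif strategy == "mode":
        # Cube-root transform gives density proportional to t^2,
        # biasing samples toward high-noise (large) timesteps.
        u = jax.random.uniform(key, (batch_size,)) ** (1.0 / 3.0)
    else:
        raise ValueError(strategy)
    return jnp.clip((u * NUM_TIMESTEPS).astype(jnp.int32), 0, NUM_TIMESTEPS - 1)

ts = sample_timesteps(jax.random.key(0), 4096, strategy="logit_normal")
```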

Loss Weighting Schemes¤

Uniform Weighting¤

No weighting: all timesteps contribute equally:

config = DiffusionTrainingConfig(loss_weighting="uniform")

SNR Weighting¤

Weight by signal-to-noise ratio:

config = DiffusionTrainingConfig(loss_weighting="snr")
# weight = alpha / (1 - alpha)

Min-SNR-Gamma Weighting¤

Clips high-SNR weights, yielding up to 3.4x faster convergence:

config = DiffusionTrainingConfig(
    loss_weighting="min_snr",
    snr_gamma=5.0,
)
# weight = min(SNR, gamma) / SNR
# Down-weights low-noise timesteps where SNR is high

Min-SNR-gamma is the recommended weighting scheme for most use cases.

EDM Weighting¤

EDM-style weighting based on sigma:

config = DiffusionTrainingConfig(loss_weighting="edm")
# weight = 1 / (sigma^2 + 1)
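The four weighting formulas above can be sketched in one function. This is an illustration built from the documented formulas; the alphas_cumprod values are made up, and real schedules supply their own:

```python
import jax.numpy as jnp

# Illustrative alphas_cumprod; real values come from the noise schedule.
alphas_cumprod = jnp.linspace(0.9999, 0.0001, 1000)

def loss_weight(t, scheme="uniform", snr_gamma=5.0):
    alpha_bar = alphas_cumprod[t]
    snr = alpha_bar / (1.0 - alpha_bar)
    if scheme == "uniform":
        return jnp.ones_like(snr)
    if scheme == "snr":
        return snr
    if scheme == "min_snr":
        # min(SNR, gamma) / SNR clips low-noise timesteps, where SNR is huge.
        return jnp.minimum(snr, snr_gamma) / snr
    if scheme == "edm":
        # sigma^2 = (1 - alpha_bar) / alpha_bar
        sigma2 = (1.0 - alpha_bar) / alpha_bar
        return 1.0 / (sigma2 + 1.0)
    raise ValueError(scheme)

t = jnp.array([0, 500, 999])
w = loss_weight(t, scheme="min_snr")
```

With min-SNR, early (low-noise) timesteps get weights well below 1, while late (high-noise) timesteps with SNR below gamma keep full weight.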

EMA (Exponential Moving Average)¤

Maintain an EMA of model parameters for stable inference:

config = DiffusionTrainingConfig(
    ema_decay=0.9999,
    ema_update_every=10,
)

# After training, get EMA parameters
ema_params = trainer.get_ema_params(model)

# Apply EMA params to model for inference
from flax import nnx
nnx.update(model, ema_params)
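Under the hood, each EMA step applies the standard rule ema <- decay * ema + (1 - decay) * param to every parameter leaf. A minimal sketch of that rule (not the Artifex implementation, which manages the EMA state for you via update_ema):

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay=0.9999):
    # Standard rule: ema <- decay * ema + (1 - decay) * param.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )

params = {"w": jnp.ones((2,)), "b": jnp.full((2,), 2.0)}
ema = {"w": jnp.zeros((2,)), "b": jnp.zeros((2,))}
ema = ema_update(ema, params, decay=0.9)
```

With decay close to 1 (e.g. 0.9999), the EMA tracks a long-horizon average of the weights, which is why it yields more stable samples at inference time.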

API Reference¤

artifex.generative_models.training.trainers.diffusion_trainer.DiffusionTrainer ¤

DiffusionTrainer(
    noise_schedule: NoiseScheduleProtocol,
    config: DiffusionTrainingConfig | None = None,
)

Diffusion model trainer with modern training techniques.

This trainer provides a JIT-compatible interface for training diffusion models with state-of-the-art techniques. The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.

Features
  • Multiple prediction types (epsilon, v, x0)
  • Non-uniform timestep sampling (logit-normal, mode-seeking)
  • Loss weighting (SNR, min-SNR, EDM)
  • EMA model updates (call update_ema separately, outside JIT)

Example (non-JIT):

from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)

# Create trainer with noise schedule and config
config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    timestep_sampling="logit_normal",
    loss_weighting="min_snr",
)
trainer = DiffusionTrainer(noise_schedule, config)

# Create model and optimizer separately
model = DDPMModel(model_config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-4), wrt=nnx.Param)

# Training loop
rng = jax.random.key(0)
for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = trainer.train_step(model, optimizer, batch, step_rng)
    trainer.update_ema(model)  # EMA updates outside train_step

Example (JIT-compiled):

trainer = DiffusionTrainer(noise_schedule, config)
jit_step = nnx.jit(trainer.train_step)

rng = jax.random.key(0)
for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = jit_step(model, optimizer, batch, step_rng)
    trainer.update_ema(model)  # Outside JIT

Parameters:

Name Type Description Default
noise_schedule NoiseScheduleProtocol

Noise schedule with alphas_cumprod and add_noise.

required
config DiffusionTrainingConfig | None

Diffusion training configuration.

None

noise_schedule instance-attribute ¤

noise_schedule = noise_schedule

config instance-attribute ¤

config = config or DiffusionTrainingConfig()

sample_timesteps ¤

sample_timesteps(batch_size: int, key: Array) -> Array

Sample timesteps according to configured strategy.

Parameters:

Name Type Description Default
batch_size int

Number of timesteps to sample.

required
key Array

PRNG key for random sampling.

required

Returns:

Type Description
Array

Integer timesteps array of shape (batch_size,).

get_loss_weight ¤

get_loss_weight(t: Array) -> Array

Compute loss weight for given timesteps.

Parameters:

Name Type Description Default
t Array

Integer timesteps array.

required

Returns:

Type Description
Array

Loss weights for each timestep.

compute_target ¤

compute_target(x0: Array, noise: Array, t: Array) -> Array

Compute prediction target based on prediction type.

Parameters:

Name Type Description Default
x0 Array

Original clean data.

required
noise Array

Added noise.

required
t Array

Timesteps.

required

Returns:

Type Description
Array

Target for the model prediction.

compute_loss ¤

compute_loss(
    model: Module, batch: dict[str, Any], key: Array
) -> tuple[Array, dict[str, Any]]

Compute diffusion training loss.

Parameters:

Name Type Description Default
model Module

Diffusion model to evaluate.

required
batch dict[str, Any]

Batch dictionary with "image" or "data" key.

required
key Array

PRNG key for sampling noise and timesteps.

required

Returns:

Type Description
tuple[Array, dict[str, Any]]

Tuple of (total_loss, metrics_dict).

train_step ¤

train_step(
    model: Module,
    optimizer: Optimizer,
    batch: dict[str, Any],
    key: Array,
) -> tuple[Array, dict[str, Any]]

Execute a single training step.

This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, key)

Note: Call update_ema() separately after train_step for EMA updates.

Parameters:

Name Type Description Default
model Module

Diffusion model to train.

required
optimizer Optimizer

NNX optimizer for parameter updates.

required
batch dict[str, Any]

Batch dictionary with "image" or "data" key.

required
key Array

PRNG key for sampling.

required

Returns:

Type Description
tuple[Array, dict[str, Any]]

Tuple of (loss, metrics_dict).

update_ema ¤

update_ema(model: Module) -> None

Update EMA parameters.

Call this method separately after train_step, outside of JIT.

Parameters:

Name Type Description Default
model Module

The model whose parameters to use for EMA update.

required

get_ema_params ¤

get_ema_params(model: Module) -> Any

Get EMA parameters for inference.

Parameters:

Name Type Description Default
model Module

The model to get fallback state from if EMA not initialized.

required

Returns:

Type Description
Any

EMA parameters, or current model state if EMA not initialized.

create_loss_fn ¤

create_loss_fn() -> Callable[
    [Module, dict[str, Any], Array],
    tuple[Array, dict[str, Any]],
]

Create loss function compatible with base Trainer.

This enables integration with the base Trainer for callbacks, checkpointing, logging, and other training infrastructure.

Returns:

Type Description
Callable[[Module, dict[str, Any], Array], tuple[Array, dict[str, Any]]]

Function with signature: (model, batch, rng) -> (loss, metrics)

Noise Schedule Protocol¤

The trainer works with any noise schedule implementing the NoiseScheduleProtocol:

from typing import Protocol
import jax

class NoiseScheduleProtocol(Protocol):
    """Protocol for noise schedules used by diffusion trainers."""

    num_timesteps: int
    alphas_cumprod: jax.Array

    def add_noise(
        self,
        x: jax.Array,
        noise: jax.Array,
        t: jax.Array,
    ) -> jax.Array:
        """Add noise to data at given timesteps."""
        ...

Artifex provides several noise schedule implementations:

from artifex.generative_models.core.noise_schedules import (
    LinearNoiseSchedule,
    CosineNoiseSchedule,
    SquaredCosineNoiseSchedule,
)

# Linear schedule (DDPM default)
schedule = LinearNoiseSchedule(num_timesteps=1000)

# Cosine schedule (improved for images)
schedule = CosineNoiseSchedule(num_timesteps=1000)
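A custom schedule only needs the two protocol members. A minimal illustrative implementation (not the Artifex LinearNoiseSchedule; the beta range and broadcasting are assumptions for the sketch):

```python
import jax.numpy as jnp

class MyLinearSchedule:
    """Minimal schedule satisfying NoiseScheduleProtocol."""

    def __init__(self, num_timesteps: int = 1000,
                 beta_start: float = 1e-4, beta_end: float = 0.02):
        self.num_timesteps = num_timesteps
        betas = jnp.linspace(beta_start, beta_end, num_timesteps)
        self.alphas_cumprod = jnp.cumprod(1.0 - betas)

    def add_noise(self, x, noise, t):
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        a = self.alphas_cumprod[t].reshape(-1, *([1] * (x.ndim - 1)))
        return jnp.sqrt(a) * x + jnp.sqrt(1.0 - a) * noise

schedule = MyLinearSchedule(num_timesteps=1000)
x0 = jnp.ones((2, 4))
noise = jnp.zeros((2, 4))
xt = schedule.add_noise(x0, noise, jnp.array([0, 999]))
```

At t=0 the data passes through almost unchanged; at t=999 the signal is nearly destroyed, which is exactly the behavior the trainer relies on.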

Integration with Base Trainer¤

Use create_loss_fn() for integration with callbacks and checkpointing:

from artifex.generative_models.training import Trainer
from artifex.generative_models.training.trainers import (
    DiffusionTrainer,
    DiffusionTrainingConfig,
)
from artifex.generative_models.training.callbacks import (
    ModelCheckpoint,
    CheckpointConfig,
)

# Create diffusion trainer
diff_config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    loss_weighting="min_snr",
)
diff_trainer = DiffusionTrainer(noise_schedule, diff_config)

# Get loss function for base Trainer
loss_fn = diff_trainer.create_loss_fn()

# Use with base Trainer for callbacks
callbacks = [
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="loss")),
]

Model Requirements¤

The Diffusion Trainer expects models with the following interface:

class DiffusionModel(nnx.Module):
    def __call__(
        self,
        x_noisy: jax.Array,
        t: jax.Array,
    ) -> jax.Array:
        """Predict noise/v/x_0 from noisy input.

        Args:
            x_noisy: Noisy data, shape (batch, ...).
            t: Integer timesteps, shape (batch,).

        Returns:
            Prediction matching prediction_type, shape (batch, ...).
        """
        ...

Training Metrics¤

Metric Description
loss Weighted loss (affected by loss_weighting)
loss_unweighted Raw MSE loss without weighting

Recommended Configurations¤

High-Quality Image Generation¤

config = DiffusionTrainingConfig(
    prediction_type="v_prediction",
    timestep_sampling="logit_normal",
    loss_weighting="min_snr",
    snr_gamma=5.0,
    ema_decay=0.9999,
)

Fast Training¤

config = DiffusionTrainingConfig(
    prediction_type="epsilon",
    timestep_sampling="uniform",
    loss_weighting="min_snr",
    snr_gamma=5.0,
)

Large Models (EDM-style)¤

config = DiffusionTrainingConfig(
    prediction_type="epsilon",
    timestep_sampling="logit_normal",
    loss_weighting="edm",
    ema_decay=0.9999,
    ema_update_every=1,
)

References¤