Flow Trainer¤
Module: artifex.generative_models.training.trainers.flow_trainer
The Flow Trainer provides specialized training utilities for flow matching models, including Conditional Flow Matching (CFM), Optimal Transport CFM (OT-CFM), and various time sampling strategies.
Overview¤
Flow matching enables simulation-free training of continuous normalizing flows. The Flow Trainer provides:
- Flow Types: Standard CFM, OT-CFM, and Rectified Flow
- Time Sampling: Uniform, logit-normal, and U-shaped strategies
- Linear Interpolation: Straight paths from noise to data
- Minimal Noise: Configurable sigma_min for path endpoints
Quick Start¤
```python
from artifex.generative_models.training.trainers import (
    FlowTrainer,
    FlowTrainingConfig,
)
from flax import nnx
import optax
import jax

# Create model and optimizer
model = create_flow_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure flow matching training
config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
    sigma_min=0.001,
)
trainer = FlowTrainer(config)

# Training loop
key = jax.random.key(0)
for step, batch in enumerate(train_loader):
    key, subkey = jax.random.split(key)
    loss, metrics = trainer.train_step(model, optimizer, batch, subkey)
    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}")
```
Configuration¤
artifex.generative_models.training.trainers.flow_trainer.FlowTrainingConfig
dataclass
¤
```python
FlowTrainingConfig(
    flow_type: Literal["cfm", "ot_cfm", "rectified_flow"] = "cfm",
    time_sampling: Literal["uniform", "logit_normal", "u_shaped"] = "uniform",
    sigma_min: float = 0.001,
    use_ot: bool = False,
    ot_regularization: float = 0.01,
    logit_normal_loc: float = 0.0,
    logit_normal_scale: float = 1.0,
)
```
Configuration for flow matching training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `flow_type` | `Literal['cfm', 'ot_cfm', 'rectified_flow']` | Type of flow matching. `"cfm"`: standard Conditional Flow Matching; `"ot_cfm"`: Optimal Transport CFM for straighter paths; `"rectified_flow"`: Rectified Flow for straighter paths. |
| `time_sampling` | `Literal['uniform', 'logit_normal', 'u_shaped']` | How to sample time values during training. `"uniform"`: uniform sampling in [0, 1]; `"logit_normal"`: logit-normal (favors middle times); `"u_shaped"`: U-shaped (favors endpoints, good for rectified flows). |
| `sigma_min` | `float` | Minimum noise level for the Gaussian path. |
| `use_ot` | `bool` | Whether to use optimal transport coupling. |
| `ot_regularization` | `float` | Regularization for OT (Sinkhorn epsilon). |
| `logit_normal_loc` | `float` | Location parameter for logit-normal sampling. |
| `logit_normal_scale` | `float` | Scale parameter for logit-normal sampling. |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `flow_type` | `str` | `"cfm"` | Flow type: `"cfm"`, `"ot_cfm"`, `"rectified_flow"` |
| `time_sampling` | `str` | `"uniform"` | Time distribution: `"uniform"`, `"logit_normal"`, `"u_shaped"` |
| `sigma_min` | `float` | `0.001` | Minimum noise level for paths |
| `use_ot` | `bool` | `False` | Enable optimal transport coupling |
| `ot_regularization` | `float` | `0.01` | Sinkhorn regularization for OT |
| `logit_normal_loc` | `float` | `0.0` | Logit-normal location parameter |
| `logit_normal_scale` | `float` | `1.0` | Logit-normal scale parameter |
Flow Types¤
Conditional Flow Matching (CFM)¤
Standard CFM uses linear interpolation paths from noise to data. The interpolation path is defined as:

\[
x_t = (1 - t)\, x_0 + t\, x_1
\]

where \(x_0\) is noise and \(x_1\) is data.
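As a sketch of the linear path above, the interpolated point and its constant target velocity can be computed directly in JAX:

```python
import jax
import jax.numpy as jnp

def linear_interpolation(x0, x1, t):
    """Compute x_t = (1 - t) * x0 + t * x1 and u_t = x1 - x0."""
    x_t = (1.0 - t) * x0 + t * x1
    u_t = x1 - x0  # constant velocity along the straight path
    return x_t, u_t

key = jax.random.key(0)
k0, k1 = jax.random.split(key)
x0 = jax.random.normal(k0, (4, 8))          # noise samples
x1 = jax.random.normal(k1, (4, 8))          # stand-in "data" samples
t = jnp.full((4, 1), 0.5)                   # midpoint time
x_t, u_t = linear_interpolation(x0, x1, t)  # x_t is the midpoint of each pair
```

At `t = 0.5` each `x_t` is exactly the midpoint of its noise/data pair, and `u_t` does not depend on `t` at all, which is what makes the straight-path objective simple to train.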
Optimal Transport CFM (OT-CFM)¤
CFM with optimal transport coupling for straighter paths:
```python
config = FlowTrainingConfig(
    flow_type="ot_cfm",
    use_ot=True,
    ot_regularization=0.01,
)
# Uses minibatch OT to pair noise and data samples
```
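The coupling idea can be illustrated with a minibatch assignment. This sketch uses an exact linear assignment via `scipy.optimize.linear_sum_assignment` for clarity, not the entropy-regularized Sinkhorn coupling that `ot_regularization` configures:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_pairing(x0, x1):
    """Re-pair noise and data samples to minimize total squared transport cost."""
    # Pairwise squared Euclidean costs between noise and data samples
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
    rows, cols = linear_sum_assignment(cost)
    return x0[rows], x1[cols]  # yields straighter average paths than random pairing

rng = np.random.default_rng(0)
x0 = rng.normal(size=(8, 2))
x1 = rng.normal(size=(8, 2))
x0_p, x1_p = ot_pairing(x0, x1)
```

By construction, the total squared distance of the re-paired batch is never worse than that of the original random pairing, which is why OT coupling straightens the marginal flow.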
Rectified Flow¤
Straighten paths through reflow iterations:
```python
config = FlowTrainingConfig(flow_type="rectified_flow")
# Single reflow iteration typically sufficient
```
Time Sampling Strategies¤
Uniform Sampling¤
Standard uniform sampling in [0, 1], the default:

\[
t \sim \mathcal{U}(0, 1)
\]
Logit-Normal Sampling¤
Favors middle time values for improved convergence:
```python
config = FlowTrainingConfig(
    time_sampling="logit_normal",
    logit_normal_loc=0.0,
    logit_normal_scale=1.0,
)
```
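A logit-normal draw can be sketched as a sigmoid-squashed Gaussian using the `loc` and `scale` parameters above; this mirrors the standard definition, though the trainer's internal implementation may differ in details:

```python
import jax
import jax.numpy as jnp

def sample_logit_normal(key, batch, loc=0.0, scale=1.0):
    """Sample t in (0, 1) whose logit is Gaussian; mass concentrates mid-interval."""
    z = loc + scale * jax.random.normal(key, (batch,))
    return jax.nn.sigmoid(z)

t = sample_logit_normal(jax.random.key(0), 1024)
```

With `loc=0.0` and `scale=1.0`, roughly 70% of the samples fall in the middle half of the interval, concentrating training signal at the times where the velocity field is hardest to learn.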
U-Shaped Sampling¤
Favors endpoints (t=0 and t=1), useful for rectified flows:
```python
config = FlowTrainingConfig(time_sampling="u_shaped")
# More samples near 0 and 1 where endpoint behavior is critical
```
U-shaped sampling is computed as:

\[
t = \tfrac{1}{2}\left(1 - \cos(\pi u)\right)
\]

where \(u \sim \text{Uniform}(0, 1)\).
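One common U-shaped sampler applies an arcsine-distribution transform to a uniform draw. This is an illustrative sketch; the trainer's exact transform may differ:

```python
import jax
import jax.numpy as jnp

def sample_u_shaped(key, batch):
    """Arcsine-distributed t in [0, 1]: density peaks near both endpoints."""
    u = jax.random.uniform(key, (batch,))
    return 0.5 * (1.0 - jnp.cos(jnp.pi * u))

t = sample_u_shaped(jax.random.key(0), 4096)
```

Compared with uniform sampling, noticeably more mass lands near \(t = 0\) and \(t = 1\), which matches the rectified-flow use case described above.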
API Reference¤
artifex.generative_models.training.trainers.flow_trainer.FlowTrainer
¤
```python
FlowTrainer(config: FlowTrainingConfig | None = None)
```
Flow matching trainer with modern training techniques.
This trainer provides a JIT-compatible interface for training flow matching models. The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.
Features
- Multiple flow types (CFM, OT-CFM, Rectified Flow)
- Non-uniform time sampling (logit-normal, u-shaped)
- Optimal transport coupling support
- DRY integration with base Trainer via create_loss_fn()
The flow matching objective learns a velocity field \(v_\theta(x_t, t)\) that transports samples from the noise distribution to the data distribution along straight paths in probability space.
Example (non-JIT):

```python
from artifex.generative_models.training.trainers import (
    FlowTrainer,
    FlowTrainingConfig,
)

config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
)
trainer = FlowTrainer(config)

# Create model and optimizer separately
model = FlowModel(config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

# Training loop
rng = jax.random.key(0)
for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = trainer.train_step(model, optimizer, batch, step_rng)
```
Example (JIT-compiled):

```python
trainer = FlowTrainer(config)
jit_step = nnx.jit(trainer.train_step)

for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = jit_step(model, optimizer, batch, step_rng)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `FlowTrainingConfig \| None` | Flow training configuration. | `None` |
sample_time
¤
Sample time values in [0, 1] according to the configured `time_sampling` strategy.
compute_conditional_vector_field
¤
Compute interpolated point and target vector field.

For the linear interpolation path:

\[
x_t = (1 - t)\, x_0 + t\, x_1, \qquad u_t = x_1 - x_0 \quad \text{(constant velocity)}
\]
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x0` | `Array` | Source samples (noise), shape `(batch, ...)`. | required |
| `x1` | `Array` | Target samples (data), shape `(batch, ...)`. | required |
| `t` | `Array` | Time values, shape `(batch, 1)`. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, Array]` | Tuple of `(x_t, u_t)`: `x_t` are the interpolated points, shape `(batch, ...)`; `u_t` is the target velocity field, shape `(batch, ...)`. |
compute_loss
¤
Compute flow matching loss.
The loss is the MSE between predicted and target velocity:

\[
\mathcal{L} = \mathbb{E}_{t, x_0, x_1}\left[\left\| v_\theta(x_t, t) - u_t \right\|^2\right]
\]
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Flow model (velocity field) to evaluate. | required |
| `batch` | `dict[str, Any]` | Batch dictionary with `"image"` or `"data"` key. | required |
| `key` | `Array` | PRNG key for sampling noise and time. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of `(total_loss, metrics_dict)`. |
train_step
¤
```python
train_step(
    model: Module,
    optimizer: Optimizer,
    batch: dict[str, Any],
    key: Array,
) -> tuple[Array, dict[str, Any]]
```

Execute a single training step.

This method can be wrapped with nnx.jit for performance:

```python
jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, key)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Flow model to train. | required |
| `optimizer` | `Optimizer` | NNX optimizer for parameter updates. | required |
| `batch` | `dict[str, Any]` | Batch dictionary with `"image"` or `"data"` key. | required |
| `key` | `Array` | PRNG key for sampling. | required |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of `(loss, metrics_dict)`. |
create_loss_fn
¤
Create loss function compatible with base Trainer.
This enables integration with the base Trainer for callbacks, checkpointing, logging, and other training infrastructure.
Returns:

| Type | Description |
|---|---|
| `Callable[[Module, dict[str, Any], Array], tuple[Array, dict[str, Any]]]` | Function with signature `(model, batch, rng) -> (loss, metrics)`. |
Flow Matching Theory¤
Flow matching learns a velocity field \(v_\theta(x_t, t)\) that transports samples from noise distribution to data distribution.
Training Objective¤
The CFM loss is:

\[
\mathcal{L}_{\text{CFM}} = \mathbb{E}_{t, x_0, x_1}\left[\left\| v_\theta(x_t, t) - (x_1 - x_0) \right\|^2\right]
\]
where:
- \(x_0 \sim \mathcal{N}(0, I)\) (source noise)
- \(x_1 \sim p_{\text{data}}\) (target data)
- \(x_t = (1-t) x_0 + t x_1\) (interpolated point)
- \(u_t = x_1 - x_0\) (target velocity)
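Putting these pieces together, the objective can be sketched end to end for a toy linear velocity model; the `params["W"]` matrix here is a hypothetical stand-in for any \(v_\theta(x_t, t)\) network:

```python
import jax
import jax.numpy as jnp

def cfm_loss(params, x1, key):
    """Monte Carlo CFM loss for a toy linear velocity model v(x, t) = x @ W."""
    k_noise, k_time = jax.random.split(key)
    x0 = jax.random.normal(k_noise, x1.shape)         # source noise
    t = jax.random.uniform(k_time, (x1.shape[0], 1))  # uniform time sampling
    x_t = (1.0 - t) * x0 + t * x1                     # interpolated point
    u_t = x1 - x0                                     # target velocity
    v_pred = x_t @ params["W"]                        # toy velocity field (ignores t)
    return jnp.mean((v_pred - u_t) ** 2)

params = {"W": jnp.zeros((8, 8))}
x1 = jax.random.normal(jax.random.key(1), (16, 8))
loss, grads = jax.value_and_grad(cfm_loss)(params, x1, jax.random.key(2))
```

Because the whole objective is a single differentiable expression with no ODE simulation, one `jax.value_and_grad` call yields both the loss and the parameter gradients, which is the "simulation-free" property flow matching is known for.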
Sampling¤
Generate samples by solving the ODE from \(t = 0\) to \(t = 1\):

\[
\frac{dx_t}{dt} = v_\theta(x_t, t), \qquad x_{t=0} \sim \mathcal{N}(0, I)
\]
Integration with Base Trainer¤
Use create_loss_fn() for integration with callbacks and checkpointing:
```python
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.trainers import FlowTrainer, FlowTrainingConfig
from artifex.generative_models.training.callbacks import (
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
    CheckpointConfig,
)

# Create flow trainer
flow_config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
)
flow_trainer = FlowTrainer(flow_config)

# Get loss function for base Trainer
loss_fn = flow_trainer.create_loss_fn()

# Use with callbacks
callbacks = [
    EarlyStopping(EarlyStoppingConfig(monitor="loss", patience=10)),
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="loss")),
]
```
Model Requirements¤
The Flow Trainer expects models with the following interface:
```python
class FlowModel(nnx.Module):
    def __call__(
        self,
        x_t: jax.Array,
        t: jax.Array,
    ) -> jax.Array:
        """Predict velocity at (x_t, t).

        Args:
            x_t: Points along flow path, shape (batch, ...).
            t: Time values in [0, 1], shape (batch,).

        Returns:
            Predicted velocity field, shape (batch, ...).
        """
        ...
```
Training Metrics¤
| Metric | Description |
|---|---|
| `loss` | MSE between predicted and target velocity |
Recommended Configurations¤
Standard CFM Training¤
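The original snippet for this configuration is not shown; a plausible baseline, using only defaults documented above, would be:

```python
config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="uniform",
    sigma_min=0.001,
)
```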
High-Quality Generation¤
```python
config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
    logit_normal_loc=0.0,
    logit_normal_scale=1.0,
)
```
Rectified Flow¤
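This subsection's snippet is likewise missing; since U-shaped time sampling is recommended for rectified flows above, a reasonable sketch is:

```python
config = FlowTrainingConfig(
    flow_type="rectified_flow",
    time_sampling="u_shaped",
)
```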
Sampling from Trained Models¤
After training, generate samples using ODE integration:
```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def sample(model, shape, key, num_steps=100):
    """Generate samples from trained flow model."""
    # Start from noise
    x_0 = jax.random.normal(key, shape)

    # Define ODE function
    def velocity_fn(x, t):
        t_batch = jnp.full((x.shape[0],), t)
        return model(x, t_batch)

    # Integrate from t=0 to t=1
    ts = jnp.linspace(0.0, 1.0, num_steps)
    trajectory = odeint(velocity_fn, x_0, ts)

    # Return final sample at t=1
    return trajectory[-1]

# Generate samples
samples = sample(model, (batch_size, *data_shape), key)
```