
Flow Trainer¤

Module: artifex.generative_models.training.trainers.flow_trainer

The Flow Trainer provides specialized training utilities for flow matching models, including Conditional Flow Matching (CFM), Optimal Transport CFM (OT-CFM), and various time sampling strategies.

Overview¤

Flow matching enables simulation-free training of continuous normalizing flows. The Flow Trainer provides:

  • Flow Types: Standard CFM, OT-CFM, and Rectified Flow
  • Time Sampling: Uniform, logit-normal, and U-shaped strategies
  • Linear Interpolation: Straight paths from noise to data
  • Minimal Noise: Configurable sigma_min for path endpoints

Quick Start¤

from artifex.generative_models.training.trainers import (
    FlowTrainer,
    FlowTrainingConfig,
)
from flax import nnx
import optax
import jax

# Create model and optimizer
model = create_flow_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure flow matching training
config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
    sigma_min=0.001,
)

trainer = FlowTrainer(config)

# Training loop
key = jax.random.key(0)

for step, batch in enumerate(train_loader):
    key, subkey = jax.random.split(key)
    loss, metrics = trainer.train_step(model, optimizer, batch, subkey)

    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}")

Configuration¤

artifex.generative_models.training.trainers.flow_trainer.FlowTrainingConfig dataclass ¤

FlowTrainingConfig(
    flow_type: Literal[
        "cfm", "ot_cfm", "rectified_flow"
    ] = "cfm",
    time_sampling: Literal[
        "uniform", "logit_normal", "u_shaped"
    ] = "uniform",
    sigma_min: float = 0.001,
    use_ot: bool = False,
    ot_regularization: float = 0.01,
    logit_normal_loc: float = 0.0,
    logit_normal_scale: float = 1.0,
)

Configuration for flow matching training.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `flow_type` | `Literal['cfm', 'ot_cfm', 'rectified_flow']` | Type of flow matching: `"cfm"` (standard Conditional Flow Matching), `"ot_cfm"` (Optimal Transport CFM for straighter paths), `"rectified_flow"` (Rectified Flow for straighter paths). |
| `time_sampling` | `Literal['uniform', 'logit_normal', 'u_shaped']` | How to sample time values during training: `"uniform"` (uniform in [0, 1]), `"logit_normal"` (favors middle times), `"u_shaped"` (favors endpoints; good for rectified flows). |
| `sigma_min` | `float` | Minimum noise level for the Gaussian path. |
| `use_ot` | `bool` | Whether to use optimal transport coupling. |
| `ot_regularization` | `float` | Regularization for OT (Sinkhorn epsilon). |
| `logit_normal_loc` | `float` | Location parameter for logit-normal sampling. |
| `logit_normal_scale` | `float` | Scale parameter for logit-normal sampling. |

flow_type class-attribute instance-attribute ¤

flow_type: Literal["cfm", "ot_cfm", "rectified_flow"] = (
    "cfm"
)

time_sampling class-attribute instance-attribute ¤

time_sampling: Literal[
    "uniform", "logit_normal", "u_shaped"
] = "uniform"

sigma_min class-attribute instance-attribute ¤

sigma_min: float = 0.001

use_ot class-attribute instance-attribute ¤

use_ot: bool = False

ot_regularization class-attribute instance-attribute ¤

ot_regularization: float = 0.01

logit_normal_loc class-attribute instance-attribute ¤

logit_normal_loc: float = 0.0

logit_normal_scale class-attribute instance-attribute ¤

logit_normal_scale: float = 1.0

Configuration Options¤

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `flow_type` | `str` | `"cfm"` | Flow type: `"cfm"`, `"ot_cfm"`, `"rectified_flow"` |
| `time_sampling` | `str` | `"uniform"` | Time distribution: `"uniform"`, `"logit_normal"`, `"u_shaped"` |
| `sigma_min` | `float` | `0.001` | Minimum noise level for paths |
| `use_ot` | `bool` | `False` | Enable optimal transport coupling |
| `ot_regularization` | `float` | `0.01` | Sinkhorn regularization for OT |
| `logit_normal_loc` | `float` | `0.0` | Logit-normal location parameter |
| `logit_normal_scale` | `float` | `1.0` | Logit-normal scale parameter |

Flow Types¤

Conditional Flow Matching (CFM)¤

Standard CFM with linear interpolation paths:

config = FlowTrainingConfig(flow_type="cfm")
# Learns velocity field: v(x_t, t) = x_1 - x_0

The interpolation path is defined as:

\[x_t = (1 - t) x_0 + t x_1\]

where \(x_0\) is noise and \(x_1\) is data.
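
As an illustrative sketch (not the library's internal implementation), the interpolated point and its constant target velocity can be computed directly:

```python
import jax.numpy as jnp

def linear_interpolation(x0, x1, t):
    # Point on the straight path from noise x0 to data x1;
    # t has shape (batch, 1) so it broadcasts over feature dims
    x_t = (1.0 - t) * x0 + t * x1
    # Target velocity is constant along the path
    u_t = x1 - x0
    return x_t, u_t
```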

Optimal Transport CFM (OT-CFM)¤

CFM with optimal transport coupling for straighter paths:

config = FlowTrainingConfig(
    flow_type="ot_cfm",
    use_ot=True,
    ot_regularization=0.01,
)
# Uses minibatch OT to pair noise and data samples
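
The effect of minibatch OT coupling can be sketched as follows; `ot_pair` is a hypothetical helper that uses an exact minimum-cost assignment (Hungarian algorithm) for clarity, rather than the entropic Sinkhorn solver that `ot_regularization` refers to:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_pair(x0, x1):
    # Pairwise squared Euclidean cost between every noise/data pair
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)
    # Exact minimum-cost assignment: re-pairs samples so transport
    # paths are shorter and straighter than random pairing
    row, col = linear_sum_assignment(cost)
    return x0[row], x1[col]
```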

Rectified Flow¤

Straighten paths through reflow iterations:

config = FlowTrainingConfig(flow_type="rectified_flow")
# Single reflow iteration typically sufficient

Time Sampling Strategies¤

Uniform Sampling¤

Standard uniform sampling in [0, 1]:

config = FlowTrainingConfig(time_sampling="uniform")
# Equal probability across all time values

Logit-Normal Sampling¤

Favors middle time values for improved convergence:

config = FlowTrainingConfig(
    time_sampling="logit_normal",
    logit_normal_loc=0.0,
    logit_normal_scale=1.0,
)
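
Logit-normal sampling draws a Gaussian and squashes it through the sigmoid; `logit_normal_times` below is an illustrative sketch of this strategy, not the trainer's internal code:

```python
import jax
import jax.numpy as jnp

def logit_normal_times(key, batch_size, loc=0.0, scale=1.0):
    # z ~ Normal(loc, scale); sigmoid(z) concentrates mass
    # around t = sigmoid(loc), i.e. t = 0.5 for loc = 0
    z = loc + scale * jax.random.normal(key, (batch_size, 1))
    return jax.nn.sigmoid(z)
```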

U-Shaped Sampling¤

Favors endpoints (t=0 and t=1), useful for rectified flows:

config = FlowTrainingConfig(time_sampling="u_shaped")
# More samples near 0 and 1 where endpoint behavior is critical

U-shaped sampling is computed as:

\[t = \sin^2(\pi u / 2)\]

where \(u \sim \text{Uniform}(0, 1)\).
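
This transform can be sketched in a few lines; `u_shaped_times` is an illustrative name, not the trainer's internal function:

```python
import jax
import jax.numpy as jnp

def u_shaped_times(key, batch_size):
    # t = sin^2(pi * u / 2): the density of t diverges at the
    # endpoints t = 0 and t = 1, matching the formula above
    u = jax.random.uniform(key, (batch_size, 1))
    return jnp.sin(jnp.pi * u / 2.0) ** 2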

API Reference¤

artifex.generative_models.training.trainers.flow_trainer.FlowTrainer ¤

FlowTrainer(config: FlowTrainingConfig | None = None)

Flow matching trainer with modern training techniques.

This trainer provides a JIT-compatible interface for training flow matching models. The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.

Features
  • Multiple flow types (CFM, OT-CFM, Rectified Flow)
  • Non-uniform time sampling (logit-normal, u-shaped)
  • Optimal transport coupling support
  • DRY integration with base Trainer via create_loss_fn()

The flow matching objective learns a velocity field v_theta(x_t, t) that transports samples from noise distribution to data distribution along straight paths in probability space.

Example (non-JIT):

from artifex.generative_models.training.trainers import (
    FlowTrainer,
    FlowTrainingConfig,
)

config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
)
trainer = FlowTrainer(config)

# Create model and optimizer separately
model = FlowModel(config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

# Training loop
rng = jax.random.key(0)
for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = trainer.train_step(model, optimizer, batch, step_rng)

Example (JIT-compiled):

trainer = FlowTrainer(config)
jit_step = nnx.jit(trainer.train_step)

for batch in data:
    rng, step_rng = jax.random.split(rng)
    loss, metrics = jit_step(model, optimizer, batch, step_rng)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `config` | `FlowTrainingConfig \| None` | Flow training configuration. | `None` |

config instance-attribute ¤

config = config or FlowTrainingConfig()

sample_time ¤

sample_time(batch_size: int, key: Array) -> Array

Sample time values according to configured strategy.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `batch_size` | `int` | Number of time values to sample. | *required* |
| `key` | `Array` | PRNG key for random sampling. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Array` | Time values array of shape `(batch_size, 1)` in [0, 1]. |

compute_conditional_vector_field ¤

compute_conditional_vector_field(
    x0: Array, x1: Array, t: Array
) -> tuple[Array, Array]

Compute interpolated point and target vector field.

For the linear interpolation path:

    x_t = (1 - t) * x0 + t * x1
    u_t = x1 - x0  (constant velocity)

Parameters:

Name Type Description Default
x0 Array

Source samples (noise), shape (batch, ...).

required
x1 Array

Target samples (data), shape (batch, ...).

required
t Array

Time values, shape (batch, 1).

required

Returns:

Type Description
tuple[Array, Array]

Tuple of (x_t, u_t) where: - x_t: Interpolated points, shape (batch, ...) - u_t: Target velocity field, shape (batch, ...)

compute_loss ¤

compute_loss(
    model: Module, batch: dict[str, Any], key: Array
) -> tuple[Array, dict[str, Any]]

Compute flow matching loss.

The loss is the MSE between the predicted and target velocity:

    L = E_{t, x0, x1} || v_theta(x_t, t) - u_t ||^2

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Module` | Flow model (velocity field) to evaluate. | *required* |
| `batch` | `dict[str, Any]` | Batch dictionary with `"image"` or `"data"` key. | *required* |
| `key` | `Array` | PRNG key for sampling noise and time. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Array, dict[str, Any]]` | Tuple of `(total_loss, metrics_dict)`. |

train_step ¤

train_step(
    model: Module,
    optimizer: Optimizer,
    batch: dict[str, Any],
    key: Array,
) -> tuple[Array, dict[str, Any]]

Execute a single training step.

This method can be wrapped with nnx.jit for performance:

    jit_step = nnx.jit(trainer.train_step)
    loss, metrics = jit_step(model, optimizer, batch, key)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Module` | Flow model to train. | *required* |
| `optimizer` | `Optimizer` | NNX optimizer for parameter updates. | *required* |
| `batch` | `dict[str, Any]` | Batch dictionary with `"image"` or `"data"` key. | *required* |
| `key` | `Array` | PRNG key for sampling. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Array, dict[str, Any]]` | Tuple of `(loss, metrics_dict)`. |

create_loss_fn ¤

create_loss_fn() -> Callable[
    [Module, dict[str, Any], Array],
    tuple[Array, dict[str, Any]],
]

Create loss function compatible with base Trainer.

This enables integration with the base Trainer for callbacks, checkpointing, logging, and other training infrastructure.

Returns:

| Type | Description |
|------|-------------|
| `Callable[[Module, dict[str, Any], Array], tuple[Array, dict[str, Any]]]` | Function with signature `(model, batch, rng) -> (loss, metrics)`. |

Flow Matching Theory¤

Flow matching learns a velocity field \(v_\theta(x_t, t)\) that transports samples from noise distribution to data distribution.

Training Objective¤

The CFM loss is:

\[\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \|v_\theta(x_t, t) - u_t\|^2\]

where:

  • \(x_0 \sim \mathcal{N}(0, I)\) (source noise)
  • \(x_1 \sim p_{\text{data}}\) (target data)
  • \(x_t = (1-t) x_0 + t x_1\) (interpolated point)
  • \(u_t = x_1 - x_0\) (target velocity)
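
Putting these pieces together, a Monte Carlo estimate of the CFM objective for one batch can be sketched as follows (`cfm_loss` is an illustrative function, not the trainer's `compute_loss`; it uses uniform time sampling and omits `sigma_min`):

```python
import jax
import jax.numpy as jnp

def cfm_loss(model, x1, key):
    # Draw source noise and uniform times for the batch
    k_noise, k_time = jax.random.split(key)
    x0 = jax.random.normal(k_noise, x1.shape)
    t = jax.random.uniform(k_time, (x1.shape[0], 1))
    # Interpolate and form the constant target velocity
    x_t = (1.0 - t) * x0 + t * x1
    u_t = x1 - x0
    # Regress the model's velocity prediction onto the target
    return jnp.mean((model(x_t, t) - u_t) ** 2)
```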

Sampling¤

Generate samples by solving the ODE from t=0 to t=1:

\[\frac{dx}{dt} = v_\theta(x, t)\]

Integration with Base Trainer¤

Use create_loss_fn() for integration with callbacks and checkpointing:

from artifex.generative_models.training import Trainer
from artifex.generative_models.training.trainers import FlowTrainer, FlowTrainingConfig
from artifex.generative_models.training.callbacks import (
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
    CheckpointConfig,
)

# Create flow trainer
flow_config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
)
flow_trainer = FlowTrainer(flow_config)

# Get loss function for base Trainer
loss_fn = flow_trainer.create_loss_fn()

# Use with callbacks
callbacks = [
    EarlyStopping(EarlyStoppingConfig(monitor="loss", patience=10)),
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="loss")),
]

Model Requirements¤

The Flow Trainer expects models with the following interface:

class FlowModel(nnx.Module):
    def __call__(
        self,
        x_t: jax.Array,
        t: jax.Array,
    ) -> jax.Array:
        """Predict velocity at (x_t, t).

        Args:
            x_t: Points along flow path, shape (batch, ...).
            t: Time values in [0, 1], shape (batch,).

        Returns:
            Predicted velocity field, shape (batch, ...).
        """
        ...

Training Metrics¤

| Metric | Description |
|--------|-------------|
| `loss` | MSE between predicted and target velocity |

Standard CFM Training¤

config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="uniform",
    sigma_min=0.001,
)

High-Quality Generation¤

config = FlowTrainingConfig(
    flow_type="cfm",
    time_sampling="logit_normal",
    logit_normal_loc=0.0,
    logit_normal_scale=1.0,
)

Rectified Flow¤

config = FlowTrainingConfig(
    flow_type="rectified_flow",
    time_sampling="u_shaped",
)

Sampling from Trained Models¤

After training, generate samples using ODE integration:

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def sample(model, shape, key, num_steps=100):
    """Generate samples from trained flow model."""
    # Start from noise
    x_0 = jax.random.normal(key, shape)

    # Define ODE function
    def velocity_fn(x, t):
        t_batch = jnp.full((x.shape[0],), t)
        return model(x, t_batch)

    # Integrate from t=0 to t=1
    ts = jnp.linspace(0, 1, num_steps)
    trajectory = odeint(velocity_fn, x_0, ts)

    # Return final sample at t=1
    return trajectory[-1]

# Generate samples
samples = sample(model, (batch_size, *data_shape), key)

References¤