VAE Trainer¤

Module: artifex.generative_models.training.trainers.vae_trainer

The VAE Trainer provides specialized training utilities for Variational Autoencoders, including KL divergence annealing schedules, beta-VAE weighting for disentanglement, and free bits constraints to prevent posterior collapse.

Overview¤

Training VAEs requires balancing reconstruction quality against latent space regularization. The VAE Trainer handles this balance through:

  • KL Annealing: Gradual increase of KL weight to prevent posterior collapse
  • Beta-VAE Weighting: Control disentanglement vs reconstruction trade-off
  • Free Bits Constraint: Minimum KL per dimension to ensure information flow
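
Conceptually, these three mechanisms combine into one weighted objective. A rough sketch of that combination (not the trainer's internal code; kl_weight is whatever the configured schedule yields at the current step):

import jax.numpy as jnp

def sketch_objective(recon_loss, kl_per_dim, kl_weight, free_bits):
    # Free bits: each latent dimension contributes at least `free_bits` nats of KL.
    kl = jnp.sum(jnp.maximum(kl_per_dim, free_bits), axis=-1).mean()
    # Annealed beta-VAE weighting: kl_weight ramps from 0 up to beta over training.
    return recon_loss + kl_weight * kl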

Quick Start¤

from artifex.generative_models.training.trainers import (
    VAETrainer,
    VAETrainingConfig,
)
from flax import nnx
import optax

# Create model and optimizer
model = create_vae_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure VAE-specific training
config = VAETrainingConfig(
    kl_annealing="cyclical",
    kl_warmup_steps=5000,
    beta=4.0,
    free_bits=0.5,
)

trainer = VAETrainer(config)

# Training loop
for step, batch in enumerate(train_loader):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step)
    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}, "
              f"recon={metrics['recon_loss']:.4f}, kl={metrics['kl_loss']:.4f}")

Configuration¤

artifex.generative_models.training.trainers.vae_trainer.VAETrainingConfig dataclass ¤

VAETrainingConfig(
    kl_annealing: Literal[
        "none", "linear", "sigmoid", "cyclical"
    ] = "linear",
    kl_warmup_steps: int = 10000,
    beta: float = 1.0,
    free_bits: float = 0.0,
    cyclical_period: int = 10000,
)

Configuration for VAE-specific training.

Attributes:

  kl_annealing (Literal['none', 'linear', 'sigmoid', 'cyclical']): Type of KL annealing schedule.
    • "none": No annealing, use full beta from start
    • "linear": Linear warmup from 0 to beta
    • "sigmoid": Sigmoid-shaped warmup
    • "cyclical": Cyclical annealing with periodic resets
  kl_warmup_steps (int): Number of steps to reach full KL weight.
  beta (float): Final beta weight for the KL term (beta-VAE). Higher values encourage disentanglement but may hurt reconstruction.
  free_bits (float): Minimum KL per latent dimension (0 = disabled). Prevents posterior collapse by ensuring minimum information flow.
  cyclical_period (int): Period for cyclical annealing (if used).

kl_annealing class-attribute instance-attribute ¤

kl_annealing: Literal[
    "none", "linear", "sigmoid", "cyclical"
] = "linear"

kl_warmup_steps class-attribute instance-attribute ¤

kl_warmup_steps: int = 10000

beta class-attribute instance-attribute ¤

beta: float = 1.0

free_bits class-attribute instance-attribute ¤

free_bits: float = 0.0

cyclical_period class-attribute instance-attribute ¤

cyclical_period: int = 10000

Configuration Options¤

Parameter Type Default Description
kl_annealing str "linear" KL weight schedule: "none", "linear", "sigmoid", "cyclical"
kl_warmup_steps int 10000 Steps to reach full KL weight
beta float 1.0 Final KL weight (beta-VAE parameter)
free_bits float 0.0 Minimum KL per latent dimension
cyclical_period int 10000 Period for cyclical annealing

KL Annealing Schedules¤

None (Constant)¤

No annealing - use full beta weight from the start:

config = VAETrainingConfig(kl_annealing="none", beta=1.0)
# KL weight = 1.0 at all steps

Linear Warmup¤

Linearly increase KL weight from 0 to beta:

config = VAETrainingConfig(
    kl_annealing="linear",
    kl_warmup_steps=10000,
    beta=1.0,
)
# KL weight = beta * min(1.0, step / warmup_steps)

Sigmoid Warmup¤

S-shaped warmup curve centered at half the warmup steps:

config = VAETrainingConfig(
    kl_annealing="sigmoid",
    kl_warmup_steps=10000,
    beta=1.0,
)
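
The exact sigmoid parameterization is an implementation detail of get_kl_weight; one plausible shape (an assumption, not taken from the source) centers the transition at kl_warmup_steps / 2:

import jax

def sigmoid_weight(step, beta, warmup_steps, sharpness=10.0):
    # S-curve from ~0 to ~beta, steepest at warmup_steps / 2; `sharpness` is a guess.
    t = (step - warmup_steps / 2) / warmup_steps
    return beta * jax.nn.sigmoid(sharpness * t)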

Cyclical Annealing¤

Periodically reset KL weight to encourage information flow:

config = VAETrainingConfig(
    kl_annealing="cyclical",
    cyclical_period=5000,
    beta=4.0,
)
# KL weight cycles: 0 -> beta -> 0 -> beta -> ...

Cyclical annealing helps prevent posterior collapse by periodically "reopening" information pathways.
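
A minimal sketch of a sawtooth cycle matching the comment above; the actual shape used by get_kl_weight may differ (for example, ramp-then-hold within each cycle):

def cyclical_weight(step, beta, period):
    # Ramp from 0 to beta within each cycle, then reset: 0 -> beta -> 0 -> beta -> ...
    phase = (step % period) / period
    return beta * phase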

Beta-VAE Training¤

Higher beta values encourage disentangled representations at the cost of reconstruction quality:

# Standard VAE (beta=1)
standard_config = VAETrainingConfig(beta=1.0)

# Beta-VAE for disentanglement (beta=4)
disentangled_config = VAETrainingConfig(beta=4.0)

# Strong regularization (beta=10)
strong_reg_config = VAETrainingConfig(beta=10.0)

Free Bits Constraint¤

Prevent posterior collapse by ensuring minimum KL per latent dimension:

config = VAETrainingConfig(
    free_bits=0.5,  # Minimum 0.5 nats per dimension
    beta=1.0,
)

The free bits constraint ensures each latent dimension carries at least the specified amount of information.
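
The standard free bits formulation clamps the per-dimension KL from below, which is what apply_free_bits (documented below) implements conceptually; a minimal sketch:

import jax.numpy as jnp

def free_bits_kl(kl_per_dim, free_bits=0.5):
    # Dimensions whose KL already exceeds the floor are unaffected; collapsed
    # dimensions stop receiving gradient pressure toward zero KL.
    return jnp.maximum(kl_per_dim, free_bits)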

API Reference¤

artifex.generative_models.training.trainers.vae_trainer.VAETrainer ¤

VAETrainer(config: VAETrainingConfig | None = None)

VAE-specific trainer with KL annealing and beta-VAE support.

This trainer provides a JIT-compatible interface for training VAEs with:

  • KL annealing schedules (linear, sigmoid, cyclical)
  • Beta-VAE weighting for disentanglement
  • Free bits constraint to prevent posterior collapse

The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.

The trainer computes the ELBO loss with configurable KL weighting:

L = reconstruction_loss + kl_weight(step) * kl_loss

where kl_weight(step) follows the configured annealing schedule and ramps from 0 to beta.

Example (non-JIT):

from artifex.generative_models.training.trainers import (
    VAETrainer,
    VAETrainingConfig,
)

config = VAETrainingConfig(
    kl_annealing="cyclical",
    beta=4.0,
    free_bits=0.5,
)
trainer = VAETrainer(config)

# Create model and optimizer separately
model = create_vae_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

# Training loop
for step, batch in enumerate(data):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step)

Example (JIT-compiled):

trainer = VAETrainer(config)
jit_step = nnx.jit(trainer.train_step)

for step, batch in enumerate(data):
    loss, metrics = jit_step(model, optimizer, batch, step=step)

Note

The model is expected to return (reconstruction, mean, logvar) from its forward pass. The trainer handles loss computation and KL annealing.

Parameters:

  config (VAETrainingConfig | None): VAE training configuration. Uses defaults if not provided. Default: None.

config instance-attribute ¤

config = config or VAETrainingConfig()

get_kl_weight ¤

get_kl_weight(step: int | Array) -> Array

Compute KL weight based on annealing schedule.

This method is JIT-compatible - uses JAX operations instead of Python builtins.

Parameters:

  step (int | Array): Current training step (can be a traced array for JIT). Required.

Returns:

  Array: KL weight multiplier (0.0 to beta).
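
For example, you can probe the schedule directly to sanity-check a configuration (the expected values below follow the linear formula given earlier):

trainer = VAETrainer(VAETrainingConfig(kl_annealing="linear", kl_warmup_steps=10000, beta=1.0))
for step in (0, 2500, 5000, 10000, 20000):
    print(step, float(trainer.get_kl_weight(step)))
# Linear schedule: expect approximately 0.0, 0.25, 0.5, 1.0, 1.0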

apply_free_bits ¤

apply_free_bits(kl_per_dim: Array) -> Array

Apply free bits constraint to KL divergence.

Ensures minimum KL per latent dimension to prevent posterior collapse.

Parameters:

  kl_per_dim (Array): KL divergence per latent dimension. Required.

Returns:

  Array: KL divergence with free bits applied.

compute_kl_loss ¤

compute_kl_loss(
    mean: Array, logvar: Array
) -> tuple[Array, Array]

Compute KL divergence from standard normal.

Parameters:

  mean (Array): Latent mean, shape (batch, latent_dim). Required.
  logvar (Array): Latent log-variance, shape (batch, latent_dim). Required.

Returns:

  tuple[Array, Array]: Tuple of (total_kl_loss, kl_per_sample) where:
    • total_kl_loss: Scalar mean KL loss
    • kl_per_sample: KL loss per sample, shape (batch,)
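
For a diagonal Gaussian posterior against the standard normal prior, the per-dimension KL has the usual closed form; a reference sketch (the trainer's exact reductions may differ):

import jax.numpy as jnp

def kl_per_dim(mean, logvar):
    # KL( N(mean, exp(logvar)) || N(0, 1) ) per latent dimension.
    return 0.5 * (jnp.exp(logvar) + mean**2 - 1.0 - logvar)

# Summing over latent_dim gives KL per sample; the batch mean gives a scalar loss.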

compute_reconstruction_loss ¤

compute_reconstruction_loss(
    x: Array,
    recon_x: Array,
    loss_type: Literal["mse", "bce"] = "mse",
) -> Array

Compute reconstruction loss.

Parameters:

  x (Array): Original input, shape (batch, ...). Required.
  recon_x (Array): Reconstructed output, shape (batch, ...). Required.
  loss_type (Literal['mse', 'bce']): Type of reconstruction loss ("mse" or "bce"). Default: 'mse'.

Returns:

  Array: Scalar reconstruction loss.

compute_loss ¤

compute_loss(
    model: Module,
    batch: dict[str, Any],
    step: int,
    loss_type: Literal["mse", "bce"] = "mse",
) -> tuple[Array, dict[str, Any]]

Compute VAE loss with KL annealing.

Parameters:

  model (Module): VAE model to evaluate. Required.
  batch (dict[str, Any]): Batch dictionary with "image" or "data" key. Required.
  step (int): Current training step for annealing. Required.
  loss_type (Literal['mse', 'bce']): Type of reconstruction loss. Default: 'mse'.

Returns:

  tuple[Array, dict[str, Any]]: Tuple of (total_loss, metrics_dict).
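
This can be called directly, for example during validation, without stepping the optimizer. A small example assuming a batch dict with an "image" key and an images array defined elsewhere:

batch = {"image": images}
val_loss, val_metrics = trainer.compute_loss(model, batch, step=1000, loss_type="bce")
print(val_metrics["kl_weight"])  # annealed KL weight in effect at step 1000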

train_step ¤

train_step(
    model: Module,
    optimizer: Optimizer,
    batch: dict[str, Any],
    step: int = 0,
    loss_type: Literal["mse", "bce"] = "mse",
) -> tuple[Array, dict[str, Any]]

Execute a single training step.

This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, step=step)

Parameters:

  model (Module): VAE model to train. Required.
  optimizer (Optimizer): NNX optimizer for parameter updates. Required.
  batch (dict[str, Any]): Batch dictionary with "image" or "data" key. Required.
  step (int): Current training step for annealing. Default: 0.
  loss_type (Literal['mse', 'bce']): Type of reconstruction loss. Default: 'mse'.

Returns:

  tuple[Array, dict[str, Any]]: Tuple of (loss, metrics_dict).

create_loss_fn ¤

create_loss_fn(
    step: int, loss_type: Literal["mse", "bce"] = "mse"
) -> Callable[
    [Module, dict[str, Any], Array],
    tuple[Array, dict[str, Any]],
]

Create loss function compatible with base Trainer.

This enables DRY integration - VAE-specific training logic can be used with the base Trainer infrastructure.

Parameters:

  step (int): Current training step for KL annealing. Required.
  loss_type (Literal['mse', 'bce']): Type of reconstruction loss. Default: 'mse'.

Returns:

  Callable[[Module, dict[str, Any], Array], tuple[Array, dict[str, Any]]]:
    Function with signature (model, batch, rng) -> (loss, metrics).

Note

rng is accepted for API compatibility but is not used by the VAE loss.

Integration with Base Trainer¤

The VAE Trainer provides a create_loss_fn() method for seamless integration with the base Trainer's callbacks, checkpointing, and logging infrastructure:

from artifex.generative_models.training import Trainer
from artifex.generative_models.training.trainers import VAETrainer, VAETrainingConfig
from artifex.generative_models.training.callbacks import (
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
    CheckpointConfig,
)

# Create VAE-specific trainer
vae_config = VAETrainingConfig(kl_annealing="cyclical", beta=4.0)
vae_trainer = VAETrainer(vae_config)

# Create loss function for a specific training step
# Note: step is required for KL annealing
def make_loss_fn(step: int):
    return vae_trainer.create_loss_fn(step=step)

# Use with base Trainer for callbacks
callbacks = [
    EarlyStopping(EarlyStoppingConfig(monitor="val_loss", patience=10)),
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="val_loss")),
]
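
The returned function follows the (model, batch, rng) -> (loss, metrics) signature documented above, so it can also be invoked directly (assuming a model and a batch as in the earlier examples; rng is accepted but unused):

import jax

loss_fn = vae_trainer.create_loss_fn(step=0, loss_type="mse")
loss, metrics = loss_fn(model, batch, jax.random.PRNGKey(0))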

Model Requirements¤

The VAE Trainer expects models with the following interface:

class VAEModel(nnx.Module):
    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Forward pass returning (reconstruction, mean, logvar).

        Args:
            x: Input data, shape (batch, ...).

        Returns:
            Tuple of:
                - recon_x: Reconstructed data, shape (batch, ...)
                - mean: Latent mean, shape (batch, latent_dim)
                - logvar: Latent log-variance, shape (batch, latent_dim)
        """
        ...
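
A minimal concrete example of this interface (a hypothetical single-layer encoder/decoder, purely illustrative; real models will differ):

import jax
import jax.numpy as jnp
from flax import nnx

class TinyVAE(nnx.Module):
    def __init__(self, in_dim: int, latent_dim: int, *, rngs: nnx.Rngs):
        self.encoder = nnx.Linear(in_dim, 2 * latent_dim, rngs=rngs)
        self.decoder = nnx.Linear(latent_dim, in_dim, rngs=rngs)
        self.rngs = rngs

    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        h = self.encoder(x.reshape(x.shape[0], -1))
        mean, logvar = jnp.split(h, 2, axis=-1)
        # Reparameterization trick: z = mean + sigma * eps.
        eps = jax.random.normal(self.rngs.sample(), mean.shape)
        z = mean + jnp.exp(0.5 * logvar) * eps
        recon_x = self.decoder(z)
        return recon_x, mean, logvar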

Reconstruction Loss Types¤

The trainer supports MSE and BCE reconstruction losses:

# Mean Squared Error (default, for continuous data)
loss, metrics = trainer.train_step(model, optimizer, batch, step=100, loss_type="mse")

# Binary Cross-Entropy (for images normalized to [0, 1])
loss, metrics = trainer.train_step(model, optimizer, batch, step=100, loss_type="bce")

Training Metrics¤

The trainer returns detailed metrics for monitoring:

Metric Description
loss Total ELBO loss
recon_loss Reconstruction loss
kl_loss KL divergence (unweighted)
kl_weight Current KL weight from annealing
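
These keys can be read straight off the metrics dict returned by train_step inside the training loop shown earlier, for example when logging:

loss, metrics = trainer.train_step(model, optimizer, batch, step=step)
if step % 100 == 0:
    print({k: float(metrics[k]) for k in ("loss", "recon_loss", "kl_loss", "kl_weight")})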

References¤