VAE Trainer¤
Module: artifex.generative_models.training.trainers.vae_trainer
The VAE Trainer provides specialized training utilities for Variational Autoencoders, including KL divergence annealing schedules, beta-VAE weighting for disentanglement, and free bits constraints to prevent posterior collapse.
Overview¤
Training VAEs requires balancing reconstruction quality against latent space regularization. The VAE Trainer handles this balance through:
- KL Annealing: Gradual increase of KL weight to prevent posterior collapse
- Beta-VAE Weighting: Control disentanglement vs reconstruction trade-off
- Free Bits Constraint: Minimum KL per dimension to ensure information flow
Quick Start¤
from artifex.generative_models.training.trainers import (
    VAETrainer,
    VAETrainingConfig,
)
from flax import nnx
import optax

# Create model and optimizer
model = create_vae_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure VAE-specific training
config = VAETrainingConfig(
    kl_annealing="cyclical",
    kl_warmup_steps=5000,
    beta=4.0,
    free_bits=0.5,
)
trainer = VAETrainer(config)

# Training loop
for step, batch in enumerate(train_loader):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step)
    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}, "
              f"recon={metrics['recon_loss']:.4f}, kl={metrics['kl_loss']:.4f}")
Configuration¤
artifex.generative_models.training.trainers.vae_trainer.VAETrainingConfig
dataclass
¤
VAETrainingConfig(
    kl_annealing: Literal[
        "none", "linear", "sigmoid", "cyclical"
    ] = "linear",
    kl_warmup_steps: int = 10000,
    beta: float = 1.0,
    free_bits: float = 0.0,
    cyclical_period: int = 10000,
)
Configuration for VAE-specific training.
Attributes:

| Name | Type | Description |
|---|---|---|
| kl_annealing | Literal['none', 'linear', 'sigmoid', 'cyclical'] | Type of KL annealing schedule. "none": no annealing, use full beta from the start. "linear": linear warmup from 0 to beta. "sigmoid": sigmoid-shaped warmup. "cyclical": cyclical annealing with periodic resets. |
| kl_warmup_steps | int | Number of steps to reach full KL weight. |
| beta | float | Final beta weight for the KL term (beta-VAE). Higher values encourage disentanglement but may hurt reconstruction. |
| free_bits | float | Minimum KL per latent dimension (0 = disabled). Prevents posterior collapse by ensuring minimum information flow. |
| cyclical_period | int | Period for cyclical annealing (if used). |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| kl_annealing | str | "linear" | KL weight schedule: "none", "linear", "sigmoid", "cyclical" |
| kl_warmup_steps | int | 10000 | Steps to reach full KL weight |
| beta | float | 1.0 | Final KL weight (beta-VAE parameter) |
| free_bits | float | 0.0 | Minimum KL per latent dimension |
| cyclical_period | int | 10000 | Period for cyclical annealing |
KL Annealing Schedules¤
None (Constant)¤
No annealing - use full beta weight from the start:
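config = VAETrainingConfig(
    kl_annealing="none",
    beta=1.0,
)

# KL weight = beta at every step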
Linear Warmup¤
Linearly increase KL weight from 0 to beta:
config = VAETrainingConfig(
    kl_annealing="linear",
    kl_warmup_steps=10000,
    beta=1.0,
)

# KL weight = beta * min(1.0, step / warmup_steps)
Sigmoid Warmup¤
S-shaped warmup curve centered at half the warmup steps:
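config = VAETrainingConfig(
    kl_annealing="sigmoid",
    kl_warmup_steps=10000,
    beta=1.0,
)

# KL weight rises along an S-shaped curve from ~0 to beta,
# passing beta / 2 near step kl_warmup_steps / 2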
Cyclical Annealing¤
Periodically reset KL weight to encourage information flow:
config = VAETrainingConfig(
    kl_annealing="cyclical",
    cyclical_period=5000,
    beta=4.0,
)

# KL weight cycles: 0 -> beta -> 0 -> beta -> ...
Cyclical annealing helps prevent posterior collapse by periodically "reopening" information pathways.
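One common way to realize such a cycle ramps the weight linearly during the first part of each period and then holds it at beta until the next reset. The sketch below is illustrative only and may differ from the exact curve used by get_kl_weight:

import jax.numpy as jnp

def cyclical_kl_weight(step, beta=4.0, period=5000, ramp_fraction=0.5):
    # Illustrative cyclical schedule: linear ramp over the first half of each
    # period, then hold at beta until the cycle restarts.
    pos = (step % period) / (period * ramp_fraction)
    return beta * jnp.minimum(1.0, pos)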
Beta-VAE Training¤
Higher beta values encourage disentangled representations at the cost of reconstruction quality:
# Standard VAE (beta=1)
standard_config = VAETrainingConfig(beta=1.0)
# Beta-VAE for disentanglement (beta=4)
disentangled_config = VAETrainingConfig(beta=4.0)
# Strong regularization (beta=10)
strong_reg_config = VAETrainingConfig(beta=10.0)
Free Bits Constraint¤
Prevent posterior collapse by ensuring minimum KL per latent dimension:
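config = VAETrainingConfig(
    kl_annealing="linear",
    free_bits=0.5,  # each latent dimension must carry at least 0.5 nats of KL
)

# A common free-bits formulation clamps the per-dimension KL from below,
# e.g. kl_dim = max(kl_dim, free_bits); apply_free_bits is the trainer hook
# that enforces this floor (the exact form used internally may differ).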
The free bits constraint ensures each latent dimension carries at least the specified amount of information.
API Reference¤
artifex.generative_models.training.trainers.vae_trainer.VAETrainer
¤
VAETrainer(config: VAETrainingConfig | None = None)
VAE-specific trainer with KL annealing and beta-VAE support.
This trainer provides a JIT-compatible interface for training VAEs with:

- KL annealing schedules (linear, sigmoid, cyclical)
- Beta-VAE weighting for disentanglement
- Free bits constraint to prevent posterior collapse
The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.
The trainer computes the ELBO loss with configurable KL weighting:

L = reconstruction_loss + kl_weight(step) * kl_loss

where kl_weight(step) ramps from 0 to beta according to the annealing schedule.
Example (non-JIT):
from artifex.generative_models.training.trainers import (
    VAETrainer,
    VAETrainingConfig,
)

config = VAETrainingConfig(
    kl_annealing="cyclical",
    beta=4.0,
    free_bits=0.5,
)
trainer = VAETrainer(config)

# Create model and optimizer separately
model = VAEModel(config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

# Training loop
for step, batch in enumerate(data):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step)
Example (JIT-compiled):
trainer = VAETrainer(config)
jit_step = nnx.jit(trainer.train_step)
for step, batch in enumerate(data):
    loss, metrics = jit_step(model, optimizer, batch, step=step)
Note
The model is expected to return (reconstruction, mean, logvar) from its forward pass. The trainer handles loss computation and KL annealing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | VAETrainingConfig \| None | VAE training configuration. Uses defaults if not provided. | None |
get_kl_weight
¤
Compute KL weight based on annealing schedule.
This method is JIT-compatible - uses JAX operations instead of Python builtins.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step | int \| Array | Current training step (can be a traced array for JIT). | required |

Returns:

| Type | Description |
|---|---|
| Array | KL weight multiplier (0.0 to beta). |
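For example, the schedule can be probed directly outside of a training step (a quick sketch; step may be a Python int or a traced array under jit):

weight = trainer.get_kl_weight(step=2500)
print(float(weight))  # between 0.0 and beta, per the configured schedule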
apply_free_bits
¤
Apply the free bits constraint, enforcing the configured minimum KL per latent dimension (see free_bits in VAETrainingConfig).
compute_kl_loss
¤
Compute KL divergence from standard normal.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mean | Array | Latent mean, shape (batch, latent_dim). | required |
| logvar | Array | Latent log-variance, shape (batch, latent_dim). | required |

Returns:

| Type | Description |
|---|---|
| tuple[Array, Array] | Tuple of (total_kl_loss, kl_per_sample): total_kl_loss is the scalar mean KL loss; kl_per_sample is the KL loss per sample, shape (batch,). |
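For reference, the standard closed-form KL between N(mean, exp(logvar)) and a standard normal, which these values correspond to, can be written as follows (a sketch with placeholder shapes; the exact reduction used internally may differ):

import jax.numpy as jnp

mean = jnp.zeros((8, 16))    # example latent mean, shape (batch, latent_dim)
logvar = jnp.zeros((8, 16))  # example latent log-variance

# KL(N(mean, exp(logvar)) || N(0, I)) per latent dimension
kl_per_dim = 0.5 * (jnp.exp(logvar) + jnp.square(mean) - 1.0 - logvar)
kl_per_sample = jnp.sum(kl_per_dim, axis=-1)  # shape (batch,)
total_kl_loss = jnp.mean(kl_per_sample)       # scalar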
compute_reconstruction_loss
¤
compute_reconstruction_loss(
    x: Array,
    recon_x: Array,
    loss_type: Literal["mse", "bce"] = "mse",
) -> Array
Compute reconstruction loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Original input, shape (batch, ...). | required |
| recon_x | Array | Reconstructed output, shape (batch, ...). | required |
| loss_type | Literal['mse', 'bce'] | Type of reconstruction loss ("mse" or "bce"). | 'mse' |

Returns:

| Type | Description |
|---|---|
| Array | Scalar reconstruction loss. |
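A minimal usage sketch, assuming x and recon_x are arrays of matching shape (for "bce", inputs are expected to be normalized to [0, 1]):

recon_loss = trainer.compute_reconstruction_loss(x, recon_x, loss_type="bce")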
compute_loss
¤
compute_loss(
    model: Module,
    batch: dict[str, Any],
    step: int,
    loss_type: Literal["mse", "bce"] = "mse",
) -> tuple[Array, dict[str, Any]]
Compute VAE loss with KL annealing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | VAE model to evaluate. | required |
| batch | dict[str, Any] | Batch dictionary with "image" or "data" key. | required |
| step | int | Current training step for annealing. | required |
| loss_type | Literal['mse', 'bce'] | Type of reconstruction loss. | 'mse' |

Returns:

| Type | Description |
|---|---|
| tuple[Array, dict[str, Any]] | Tuple of (total_loss, metrics_dict). |
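Because it does not touch the optimizer, compute_loss is also convenient for evaluation; a sketch assuming val_batch is a dict with an "image" key:

val_loss, val_metrics = trainer.compute_loss(model, val_batch, step=step)
print(f"val_loss={float(val_loss):.4f}")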
train_step
¤
train_step(
    model: Module,
    optimizer: Optimizer,
    batch: dict[str, Any],
    step: int = 0,
    loss_type: Literal["mse", "bce"] = "mse",
) -> tuple[Array, dict[str, Any]]
Execute a single training step.
This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, step=step)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | VAE model to train. | required |
| optimizer | Optimizer | NNX optimizer for parameter updates. | required |
| batch | dict[str, Any] | Batch dictionary with "image" or "data" key. | required |
| step | int | Current training step for annealing. | 0 |
| loss_type | Literal['mse', 'bce'] | Type of reconstruction loss. | 'mse' |

Returns:

| Type | Description |
|---|---|
| tuple[Array, dict[str, Any]] | Tuple of (loss, metrics_dict). |
create_loss_fn
¤
create_loss_fn(
    step: int, loss_type: Literal["mse", "bce"] = "mse"
) -> Callable[
    [Module, dict[str, Any], Array],
    tuple[Array, dict[str, Any]],
]
Create loss function compatible with base Trainer.
This enables DRY integration - VAE-specific training logic can be used with the base Trainer infrastructure.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step | int | Current training step for KL annealing. | required |
| loss_type | Literal['mse', 'bce'] | Type of reconstruction loss. | 'mse' |

Returns:

| Type | Description |
|---|---|
| Callable[[Module, dict[str, Any], Array], tuple[Array, dict[str, Any]]] | Function with signature (model, batch, rng) -> (loss, metrics). |

Note: rng is accepted for API compatibility but is not used by the VAE loss.
Integration with Base Trainer¤
The VAE Trainer provides a create_loss_fn() method for seamless integration with the base Trainer's callbacks, checkpointing, and logging infrastructure:
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.trainers import VAETrainer, VAETrainingConfig
from artifex.generative_models.training.callbacks import (
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
    CheckpointConfig,
)

# Create VAE-specific trainer
vae_config = VAETrainingConfig(kl_annealing="cyclical", beta=4.0)
vae_trainer = VAETrainer(vae_config)

# Create loss function for a specific training step
# Note: step is required for KL annealing
def make_loss_fn(step: int):
    return vae_trainer.create_loss_fn(step=step)

# Use with base Trainer for callbacks
callbacks = [
    EarlyStopping(EarlyStoppingConfig(monitor="val_loss", patience=10)),
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="val_loss")),
]
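The function returned by create_loss_fn follows the documented (model, batch, rng) -> (loss, metrics) signature, so it can also be called directly; a quick sketch (the rng key is ignored by the VAE loss, as noted above):

import jax

loss_fn = make_loss_fn(step=0)
loss, metrics = loss_fn(model, batch, jax.random.PRNGKey(0))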
Model Requirements¤
The VAE Trainer expects models with the following interface:
class VAEModel(nnx.Module):
    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Forward pass returning (reconstruction, mean, logvar).

        Args:
            x: Input data, shape (batch, ...).

        Returns:
            Tuple of:
            - recon_x: Reconstructed data, shape (batch, ...)
            - mean: Latent mean, shape (batch, latent_dim)
            - logvar: Latent log-variance, shape (batch, latent_dim)
        """
        ...
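For concreteness, a minimal toy model satisfying this interface could look like the sketch below. TinyVAE is a hypothetical example (not part of artifex) and assumes flat inputs of shape (batch, in_dim):

import jax
import jax.numpy as jnp
from flax import nnx


class TinyVAE(nnx.Module):
    def __init__(self, in_dim: int, hidden_dim: int, latent_dim: int, *, rngs: nnx.Rngs):
        self.encoder = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
        self.mean_head = nnx.Linear(hidden_dim, latent_dim, rngs=rngs)
        self.logvar_head = nnx.Linear(hidden_dim, latent_dim, rngs=rngs)
        self.decoder = nnx.Linear(latent_dim, in_dim, rngs=rngs)
        self.rngs = rngs  # keep a stream for the reparameterization noise

    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        h = jax.nn.relu(self.encoder(x))
        mean = self.mean_head(h)
        logvar = self.logvar_head(h)
        # Reparameterization trick: z = mean + sigma * eps
        eps = jax.random.normal(self.rngs.noise(), mean.shape)
        z = mean + jnp.exp(0.5 * logvar) * eps
        recon_x = self.decoder(z)
        return recon_x, mean, logvar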
Reconstruction Loss Types¤
The trainer supports MSE and BCE reconstruction losses:
# Mean Squared Error (default, for continuous data)
loss, metrics = trainer.train_step(model, optimizer, batch, step=100, loss_type="mse")

# Binary Cross-Entropy (for images normalized to [0, 1])
loss, metrics = trainer.train_step(model, optimizer, batch, step=100, loss_type="bce")
Training Metrics¤
The trainer returns detailed metrics for monitoring:
| Metric | Description |
|---|---|
| loss | Total ELBO loss |
| recon_loss | Reconstruction loss |
| kl_loss | KL divergence (unweighted) |
| kl_weight | Current KL weight from annealing |
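Logging kl_weight alongside kl_loss is a simple way to confirm the annealing schedule is behaving as configured; a minimal sketch reusing the Quick Start setup:

for step, batch in enumerate(train_loader):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step)
    if step % 500 == 0:
        print(f"step {step}: kl_weight={metrics['kl_weight']:.3f}, kl_loss={metrics['kl_loss']:.4f}")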