Checkpointing Callbacks¤
Module: generative_models.training.callbacks.checkpoint
Source: generative_models/training/callbacks/checkpoint.py
Overview¤
Model checkpointing callback that monitors metrics and saves checkpoints when they improve. Uses Orbax checkpointing under the hood with minimal overhead when not saving.
Classes¤
CheckpointConfig¤
```python
@dataclass(slots=True)
class CheckpointConfig:
    """Configuration for model checkpointing."""

    dirpath: str | Path = "checkpoints"
    filename: str = "model-{epoch:02d}-{val_loss:.4f}"
    monitor: str = "val_loss"
    mode: Literal["min", "max"] = "min"
    save_top_k: int = 3
    save_last: bool = True
    every_n_epochs: int = 1
    save_weights_only: bool = False
```
Attributes:
| Attribute | Type | Default | Description |
|---|---|---|---|
| `dirpath` | `str \| Path` | `"checkpoints"` | Directory to save checkpoints |
| `filename` | `str` | `"model-{epoch:02d}-{val_loss:.4f}"` | Filename template with `{epoch}` and metric placeholders |
| `monitor` | `str` | `"val_loss"` | Metric name to monitor |
| `mode` | `Literal["min", "max"]` | `"min"` | Whether lower or higher is better |
| `save_top_k` | `int` | `3` | Number of best checkpoints to keep (-1 = all, 0 = none) |
| `save_last` | `bool` | `True` | Whether to additionally save the latest checkpoint, regardless of metric |
| `every_n_epochs` | `int` | `1` | Save a checkpoint every n epochs |
| `save_weights_only` | `bool` | `False` | Only save model weights (not optimizer state) |
ModelCheckpoint¤
```python
class ModelCheckpoint(BaseCallback):
    """Save model checkpoints based on monitored metrics."""

    def __init__(self, config: CheckpointConfig): ...
```
Callback that saves model checkpoints when monitored metrics improve. Uses Orbax checkpointing infrastructure with automatic cleanup of old checkpoints.
Key Properties:
| Property | Type | Description |
|---|---|---|
| `best_score` | `float \| None` | Best metric value seen so far |
| `best_checkpoint_path` | `Path \| None` | Path to the best checkpoint |
| `saved_checkpoints` | `list[tuple[float, Path]]` | List of (score, path) pairs for saved checkpoints |
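For example, `saved_checkpoints` makes it easy to audit everything the callback kept. A minimal sketch, assuming a `checkpoint` instance constructed as in the usage examples below:

```python
# Print every retained checkpoint alongside its monitored score.
for score, path in checkpoint.saved_checkpoints:
    print(f"{path.name}: {score:.4f}")
```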
Usage¤
Basic Checkpointing¤
```python
from artifex.generative_models.training.callbacks import (
    ModelCheckpoint,
    CheckpointConfig,
)

# Save the best 3 checkpoints based on validation loss
checkpoint = ModelCheckpoint(CheckpointConfig(
    dirpath="./checkpoints",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
))

trainer.fit(callbacks=[checkpoint])

# Access the best checkpoint after training
print(f"Best checkpoint: {checkpoint.best_checkpoint_path}")
print(f"Best score: {checkpoint.best_score}")
```
Monitor Accuracy (Higher is Better)¤
```python
checkpoint = ModelCheckpoint(CheckpointConfig(
    dirpath="./checkpoints",
    monitor="val_accuracy",
    mode="max",  # Higher accuracy is better
    save_top_k=1,  # Keep only the best
    filename="best-model-{epoch:02d}-{val_accuracy:.4f}",
))
```
Save All Checkpoints¤
```python
checkpoint = ModelCheckpoint(CheckpointConfig(
    dirpath="./checkpoints",
    save_top_k=-1,  # Keep all checkpoints
    every_n_epochs=5,  # Save every 5 epochs
))
```
Custom Filename Template¤
```python
# The filename template supports {epoch} and any metric in logs
checkpoint = ModelCheckpoint(CheckpointConfig(
    filename="model-epoch{epoch:03d}-loss{val_loss:.6f}-acc{val_accuracy:.4f}",
    monitor="val_loss",
))
```
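The placeholders use Python `str.format` syntax, so the template above would plausibly render like this (illustrative values; the exact substitution call inside the callback is an assumption):

```python
# Manually rendering the template shows the expected filename shape.
template = "model-epoch{epoch:03d}-loss{val_loss:.6f}-acc{val_accuracy:.4f}"
print(template.format(epoch=12, val_loss=0.123456, val_accuracy=0.9812))
# -> model-epoch012-loss0.123456-acc0.9812
```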
Combined with Other Callbacks¤
```python
from artifex.generative_models.training.callbacks import (
    CallbackList,
    ModelCheckpoint,
    CheckpointConfig,
    EarlyStopping,
    EarlyStoppingConfig,
    ProgressBarCallback,
    ProgressBarConfig,
)

callbacks = CallbackList([
    ModelCheckpoint(CheckpointConfig(
        dirpath="./checkpoints",
        monitor="val_loss",
        save_top_k=3,
    )),
    EarlyStopping(EarlyStoppingConfig(
        monitor="val_loss",
        patience=10,
    )),
    ProgressBarCallback(ProgressBarConfig()),
])

trainer.fit(callbacks=callbacks)
```
How It Works¤
- **Metric Monitoring**: Tracks the specified metric (`monitor`) at the end of each epoch
- **Improvement Check**: Compares the current value against the best using `mode` (min/max), as sketched below
- **Checkpoint Saving**: Uses the Orbax infrastructure to save model state
- **Automatic Cleanup**: Removes old checkpoints beyond the `save_top_k` limit
- **Best Tracking**: Maintains a reference to the best checkpoint path
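The improvement check reduces to a single comparison. A minimal sketch, assuming the callback compares the monitored value against the best seen so far (`_is_improvement` is an illustrative name, not the module's actual API):

```python
def _is_improvement(current: float, best: float | None, mode: str) -> bool:
    # The first observed value always counts as an improvement.
    if best is None:
        return True
    # "min": lower is better (e.g. val_loss); "max": higher is better.
    return current < best if mode == "min" else current > best
```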
Checkpoint Cleanup Strategy¤
When `save_top_k > 0`, checkpoints are sorted by their metric value:

- **Min mode**: keeps the checkpoints with the lowest scores, removes the highest
- **Max mode**: keeps the checkpoints with the highest scores, removes the lowest
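A hedged sketch of that cleanup policy (the function name and the deletion call are illustrative; the actual implementation delegates to Orbax):

```python
import shutil
from pathlib import Path

def _prune_checkpoints(
    saved: list[tuple[float, Path]], k: int, mode: str
) -> list[tuple[float, Path]]:
    # Order best-first: ascending scores for "min", descending for "max".
    saved.sort(key=lambda sp: sp[0], reverse=(mode == "max"))
    keep, evict = saved[:k], saved[k:]
    for _, path in evict:
        shutil.rmtree(path, ignore_errors=True)  # Orbax writes checkpoint directories
    return keep
```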
Integration with Orbax¤
`ModelCheckpoint` uses the existing Orbax-based checkpointing infrastructure:

```python
from artifex.generative_models.core.checkpointing import (
    save_checkpoint,
    load_checkpoint,
    setup_checkpoint_manager,
)

# Checkpoints are saved using Orbax under the hood
# You can load them manually:
checkpoint_manager, _ = setup_checkpoint_manager("./checkpoints")
model = load_checkpoint(checkpoint_manager, model, step=10)
```
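Because the callback writes through the same `setup_checkpoint_manager`/`save_checkpoint` path as the rest of the library, checkpoints it produces should remain loadable by any code that already consumes Orbax checkpoints from this stack.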
See Checkpointing Guide for advanced checkpointing features including optimizer state and corruption recovery.
Module Statistics¤
- Classes: 2 (CheckpointConfig, ModelCheckpoint)
- Dependencies: Orbax checkpointing infrastructure
- Slots: Uses `__slots__` (via `@dataclass(slots=True)`) for memory efficiency