Callbacks Base

Module: generative_models.training.callbacks.base

Source: generative_models/training/callbacks/base.py

Overview

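Base classes and protocols for the training callback system. The module provides a lightweight, Protocol-based callback interface for training loops. Callbacks are typed against the TrainerLike protocol rather than the concrete trainer class, which avoids circular imports between the callback and trainer modules.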

Protocols

TrainerLike

class TrainerLike(Protocol):
    """Protocol for trainer-like objects that callbacks can interact with.

    This protocol defines the minimal interface that a trainer must implement
    to work with callbacks. Using a protocol avoids circular imports between
    callbacks and trainer modules.
    """

    @property
    def model(self) -> nnx.Module:
        """The NNX model being trained."""
        ...

Defines the minimal interface for trainers to work with callbacks without causing circular imports. The actual Trainer class satisfies this protocol.
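Because TrainerLike is a Protocol, conformance is structural: a trainer satisfies it by exposing a `model` property, without importing or subclassing anything from the callbacks module. The following is a minimal sketch of that idea, not the library's code. It uses a local stand-in for the protocol, with `model` typed as `Any` instead of `nnx.Module` so it runs without Flax, and `@runtime_checkable` added only to make the check visible. `ToyTrainer` and `describe_model` are hypothetical names.

```python
from typing import Any, Protocol, runtime_checkable


# Local stand-in for the library's TrainerLike. `model` is typed as Any here
# (instead of nnx.Module) so the sketch runs without Flax installed, and
# @runtime_checkable is an assumption added purely for demonstration.
@runtime_checkable
class TrainerLike(Protocol):
    @property
    def model(self) -> Any: ...


class ToyTrainer:
    """Hypothetical trainer: it never imports or subclasses TrainerLike."""

    def __init__(self, model: Any) -> None:
        self._model = model

    @property
    def model(self) -> Any:
        return self._model


def describe_model(trainer: TrainerLike) -> str:
    """Callback-side helper that depends only on the protocol."""
    return type(trainer.model).__name__


toy = ToyTrainer(model={"weights": [0.1, 0.2]})
is_trainer_like = isinstance(toy, TrainerLike)  # structural check, no inheritance
model_kind = describe_model(toy)
```

The callback side only ever names the protocol, so the module that defines the concrete trainer can import callbacks freely without creating a cycle.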


TrainingCallback

class TrainingCallback(Protocol):
    """Protocol defining the callback interface."""

    def on_train_begin(self, trainer: TrainerLike) -> None: ...
    def on_train_end(self, trainer: TrainerLike) -> None: ...
    def on_epoch_begin(self, trainer: TrainerLike, epoch: int) -> None: ...
    def on_epoch_end(self, trainer: TrainerLike, epoch: int, logs: dict[str, Any]) -> None: ...
    def on_batch_begin(self, trainer: TrainerLike, batch: int) -> None: ...
    def on_batch_end(self, trainer: TrainerLike, batch: int, logs: dict[str, Any]) -> None: ...
    def on_validation_begin(self, trainer: TrainerLike) -> None: ...
    def on_validation_end(self, trainer: TrainerLike, logs: dict[str, Any]) -> None: ...

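Defines the eight lifecycle hooks a callback can implement: begin/end pairs for training, epochs, batches, and validation. The end hooks for epochs, batches, and validation also receive a logs dictionary.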
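To make the hook lifecycle concrete, here is a sketch of how a training loop might fire the eight hooks. The loop `run_toy_loop`, the `RecordingCallback` class, the placeholder log values, and in particular the ordering of validation relative to `on_epoch_end` are assumptions for illustration; the actual trainer may order or skip hooks differently.

```python
from typing import Any


class RecordingCallback:
    """Hypothetical callback that records the name of every hook it receives."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def on_train_begin(self, trainer: Any) -> None:
        self.events.append("train_begin")

    def on_train_end(self, trainer: Any) -> None:
        self.events.append("train_end")

    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, trainer: Any, epoch: int, logs: dict[str, Any]) -> None:
        self.events.append(f"epoch_end:{epoch}")

    def on_batch_begin(self, trainer: Any, batch: int) -> None:
        self.events.append(f"batch_begin:{batch}")

    def on_batch_end(self, trainer: Any, batch: int, logs: dict[str, Any]) -> None:
        self.events.append(f"batch_end:{batch}")

    def on_validation_begin(self, trainer: Any) -> None:
        self.events.append("validation_begin")

    def on_validation_end(self, trainer: Any, logs: dict[str, Any]) -> None:
        self.events.append("validation_end")


def run_toy_loop(callback: Any, trainer: Any, num_epochs: int, num_batches: int) -> None:
    """Assumed hook ordering for a simple epoch/batch loop with validation."""
    callback.on_train_begin(trainer)
    for epoch in range(num_epochs):
        callback.on_epoch_begin(trainer, epoch)
        for batch in range(num_batches):
            callback.on_batch_begin(trainer, batch)
            callback.on_batch_end(trainer, batch, {"loss": 0.0})  # placeholder logs
        callback.on_validation_begin(trainer)
        callback.on_validation_end(trainer, {"val_loss": 0.0})  # placeholder logs
        callback.on_epoch_end(trainer, epoch, {"loss": 0.0})
    callback.on_train_end(trainer)


recorder = RecordingCallback()
run_toy_loop(recorder, trainer=None, num_epochs=1, num_batches=2)
```

After the call, `recorder.events` holds the hook names in the order the toy loop fired them.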

Classes

BaseCallback

class BaseCallback:
    """Base callback class with no-op implementations of all hooks."""

Base class providing no-op implementations of all callback hooks. Extend this class and override only the methods you need.

Example:

from artifex.generative_models.training.callbacks import BaseCallback

class MyCallback(BaseCallback):
    def on_epoch_end(self, trainer, epoch, logs):
        print(f"Epoch {epoch} completed with loss: {logs.get('loss', 'N/A')}")
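The "override only what you need" pattern works because every hook already exists on the base class as a no-op. The sketch below illustrates that with a local stand-in, `NoOpCallback`, written to mirror the hook signatures of the TrainingCallback protocol. It is not the library's BaseCallback, and `BestLossTracker` is a hypothetical subclass.

```python
from typing import Any


class NoOpCallback:
    """Stand-in for BaseCallback: every hook exists and does nothing."""

    def on_train_begin(self, trainer: Any) -> None:
        pass

    def on_train_end(self, trainer: Any) -> None:
        pass

    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
        pass

    def on_epoch_end(self, trainer: Any, epoch: int, logs: dict[str, Any]) -> None:
        pass

    def on_batch_begin(self, trainer: Any, batch: int) -> None:
        pass

    def on_batch_end(self, trainer: Any, batch: int, logs: dict[str, Any]) -> None:
        pass

    def on_validation_begin(self, trainer: Any) -> None:
        pass

    def on_validation_end(self, trainer: Any, logs: dict[str, Any]) -> None:
        pass


class BestLossTracker(NoOpCallback):
    """Hypothetical subclass that overrides a single hook."""

    def __init__(self) -> None:
        self.best_loss = float("inf")

    def on_epoch_end(self, trainer: Any, epoch: int, logs: dict[str, Any]) -> None:
        loss = logs.get("loss")
        if loss is not None and loss < self.best_loss:
            self.best_loss = loss


tracker = BestLossTracker()
tracker.on_train_begin(None)  # inherited no-op, safe to call
for epoch, loss in enumerate([0.9, 0.4, 0.6]):
    tracker.on_epoch_end(None, epoch, {"loss": loss})
tracker.on_train_end(None)  # inherited no-op, safe to call
```

A training loop can call all eight hooks on `tracker` unconditionally; only `on_epoch_end` has any effect.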

CallbackList

class CallbackList:
    """Container for multiple callbacks that dispatches to all."""

    def __init__(self, callbacks: list[TrainingCallback] | None = None): ...
    def append(self, callback: TrainingCallback) -> None: ...

Container that holds multiple callbacks and dispatches lifecycle events to all of them.

Example:

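# The config classes are assumed to be exported from the same package.
from artifex.generative_models.training.callbacks import (
    CallbackList,
    CheckpointConfig,
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
)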

callbacks = CallbackList([
    EarlyStopping(EarlyStoppingConfig(patience=10)),
    ModelCheckpoint(CheckpointConfig(dirpath="./checkpoints")),
])

# All callbacks receive the event; `trainer` is an existing trainer instance
callbacks.on_epoch_end(trainer, epoch=5, logs={"loss": 0.5})

Usage

from artifex.generative_models.training.callbacks import (
    TrainerLike,
    TrainingCallback,
    BaseCallback,
    CallbackList,
)

# Implement a custom callback by overriding only the hook you need
class LoggingCallback(BaseCallback):
    def on_batch_end(self, trainer: TrainerLike, batch: int, logs: dict) -> None:
        # Print the logs every 100th batch
        if batch % 100 == 0:
            print(f"Batch {batch}: {logs}")

# Use with trainer
callbacks = CallbackList([LoggingCallback()])

Module Statistics

  • Classes: 2 (BaseCallback, CallbackList)
  • Protocols: 2 (TrainerLike, TrainingCallback)
  • Imports: 4