Early Stopping Callbacks¤
Module: generative_models.training.callbacks.early_stopping
Source: generative_models/training/callbacks/early_stopping.py
Overview¤
Early stopping callback that monitors a metric and stops training when it stops improving. Designed for minimal overhead with simple comparisons and no external dependencies.
Classes¤
EarlyStoppingConfig¤
```python
from dataclasses import dataclass
from typing import Literal

@dataclass(slots=True)
class EarlyStoppingConfig:
    """Configuration for early stopping."""

    monitor: str = "val_loss"
    min_delta: float = 0.0
    patience: int = 10
    mode: Literal["min", "max"] = "min"
    check_finite: bool = True
    stopping_threshold: float | None = None
    divergence_threshold: float | None = None
```
Attributes:

| Attribute | Type | Default | Description |
|---|---|---|---|
| `monitor` | `str` | `"val_loss"` | Metric name to monitor |
| `min_delta` | `float` | `0.0` | Minimum change to qualify as improvement |
| `patience` | `int` | `10` | Epochs without improvement before stopping |
| `mode` | `Literal["min", "max"]` | `"min"` | Whether lower or higher is better |
| `check_finite` | `bool` | `True` | Stop when the metric becomes NaN or Inf |
| `stopping_threshold` | `float \| None` | `None` | Stop immediately when the metric reaches this value |
| `divergence_threshold` | `float \| None` | `None` | Stop if the metric exceeds this value (min mode only) |
EarlyStopping¤
```python
class EarlyStopping(BaseCallback):
    """Stop training when a monitored metric stops improving."""

    def __init__(self, config: EarlyStoppingConfig): ...
```
Callback that tracks a metric and signals when training should stop due to lack of improvement or other conditions.
Key Properties:

| Property | Type | Description |
|---|---|---|
| `should_stop` | `bool` | Whether training should stop |
| `best_score` | `float \| None` | Best metric value seen |
| `wait_count` | `int` | Number of epochs without improvement |
| `stopped_epoch` | `int \| None` | Epoch when stopping was triggered |
Usage¤
Basic Early Stopping¤
```python
from artifex.generative_models.training.callbacks import (
    EarlyStopping,
    EarlyStoppingConfig,
)

# Stop if validation loss doesn't improve for 10 epochs
early_stop = EarlyStopping(EarlyStoppingConfig(
    monitor="val_loss",
    patience=10,
    mode="min",
))

trainer.fit(callbacks=[early_stop])

# Check if training stopped early
if early_stop.should_stop:
    print(f"Stopped at epoch {early_stop.stopped_epoch}")
```
Monitor Accuracy (Higher is Better)¤
```python
early_stop = EarlyStopping(EarlyStoppingConfig(
    monitor="val_accuracy",
    patience=15,
    mode="max",  # Higher accuracy is better
    min_delta=0.001,  # Require at least 0.1% improvement
))
```
Stop at Target Performance¤
```python
# Stop when validation loss reaches 0.01 (goal achieved)
early_stop = EarlyStopping(EarlyStoppingConfig(
    monitor="val_loss",
    stopping_threshold=0.01,
    mode="min",
))
```
Detect Training Divergence¤
```python
# Stop if loss explodes (divergence detection)
early_stop = EarlyStopping(EarlyStoppingConfig(
    monitor="val_loss",
    patience=10,
    mode="min",
    divergence_threshold=10.0,  # Stop if loss > 10
    check_finite=True,  # Also stop on NaN/Inf
))
```
Combined with Checkpointing¤
```python
from artifex.generative_models.training.callbacks import (
    CallbackList,
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
    CheckpointConfig,
)

callbacks = CallbackList([
    ModelCheckpoint(CheckpointConfig(
        dirpath="./checkpoints",
        monitor="val_loss",
        save_top_k=3,
    )),
    EarlyStopping(EarlyStoppingConfig(
        monitor="val_loss",
        patience=10,
        min_delta=1e-4,
    )),
])

trainer.fit(callbacks=callbacks)
```
How It Works¤
- Metric Tracking: At each epoch end, reads the monitored metric from the logs
- Finite Check: Optionally stops immediately if the metric is NaN or Inf
- Threshold Check: Optionally stops once the target performance is reached
- Divergence Check: Optionally stops if the metric exceeds a threshold (min mode)
- Improvement Check: Compares the current value to the best seen, with `min_delta` tolerance
- Patience Tracking: Counts epochs without improvement and stops when `patience` is exceeded
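The sequence of checks above can be sketched as a single standalone function. This is a simplified illustration only; `check_early_stop` and its signature are hypothetical, not the module's actual API:

```python
import math

def check_early_stop(value, best, wait, *, mode="min", min_delta=0.0,
                     patience=10, check_finite=True,
                     stopping_threshold=None, divergence_threshold=None):
    """Return (should_stop, new_best, new_wait) for one epoch-end check."""
    # 1. Finite check: stop immediately on NaN/Inf
    if check_finite and not math.isfinite(value):
        return True, best, wait
    # 2. Threshold check: stop once target performance is reached
    if stopping_threshold is not None:
        if (mode == "min" and value <= stopping_threshold) or \
           (mode == "max" and value >= stopping_threshold):
            return True, value, 0
    # 3. Divergence check (min mode only)
    if mode == "min" and divergence_threshold is not None and value > divergence_threshold:
        return True, best, wait
    # 4. Improvement check with min_delta tolerance
    improved = (best is None
                or (mode == "min" and value < best - min_delta)
                or (mode == "max" and value > best + min_delta))
    if improved:
        return False, value, 0
    # 5. Patience tracking: stop once patience is exhausted
    wait += 1
    return wait >= patience, best, wait
```

The real callback holds `best_score` and `wait_count` as state; here they are threaded through explicitly to keep the sketch self-contained.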
Improvement Criteria¤
For a value to be considered an improvement:
- Min mode: `current < best - min_delta`
- Max mode: `current > best + min_delta`
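For instance, with `min_delta=0.001` in min mode, a drop from 0.500 to 0.4995 does not count as improvement because it fails to beat the tolerance. A standalone illustration (`is_improvement` is a hypothetical helper, not part of the module):

```python
def is_improvement(current, best, mode="min", min_delta=0.0):
    # Improvement must beat the best value by more than min_delta
    if mode == "min":
        return current < best - min_delta
    return current > best + min_delta

print(is_improvement(0.4995, 0.500, min_delta=0.001))  # False: within tolerance
print(is_improvement(0.498, 0.500, min_delta=0.001))   # True: beats tolerance
```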
Stopping Conditions¤
Training stops when ANY of these conditions are met:
- Non-finite value: `check_finite=True` and the metric is NaN or Inf
- Goal reached: `stopping_threshold` is set and the metric reaches the target
- Divergence: `divergence_threshold` is set and the metric exceeds it (min mode)
- No improvement: `patience` epochs pass without improvement
Integration with Training Loop¤
The trainer checks `callback.should_stop` after each epoch:

```python
# Inside trainer
for epoch in range(num_epochs):
    # Training step...
    logs = {"val_loss": val_loss, "val_accuracy": val_acc}

    # Call callbacks
    for callback in callbacks:
        callback.on_epoch_end(self, epoch, logs)

        # Check for early stopping
        if hasattr(callback, "should_stop") and callback.should_stop:
            print(f"Early stopping triggered at epoch {epoch}")
            return
```
Module Statistics¤
- Classes: 2 (`EarlyStoppingConfig`, `EarlyStopping`)
- Dependencies: None (uses only the Python standard library)
- Slots: Uses `__slots__` for memory efficiency