Profiling Callbacks¤
Module: generative_models.training.callbacks.profiling
Source: generative_models/training/callbacks/profiling.py
Overview¤
Profiling callbacks for integrating JAX-native performance analysis into training loops. These callbacks provide trace-based profiling for TensorBoard visualization and memory usage tracking with minimal overhead.
Classes¤
ProfilingConfig¤
```python
@dataclass(slots=True)
class ProfilingConfig:
    log_dir: str = "logs/profiles"
    start_step: int = 10
    end_step: int = 20
    trace_memory: bool = True
    trace_python: bool = False
```
Configuration for JAX trace profiling.
Attributes:
- `log_dir`: Directory to save profiling traces
- `start_step`: Step at which to start profiling (skip warmup)
- `end_step`: Step at which to stop profiling
- `trace_memory`: Whether to include memory usage in traces
- `trace_python`: Whether to trace Python execution (slower, but more detailed)
JAXProfiler¤
JAX profiler callback for performance analysis.
Integrates with JAX's built-in profiler to capture traces that can be viewed in TensorBoard or Perfetto. Automatically skips warmup steps to get more representative profiling data.
Features:
- Integration with JAX's built-in profiler
- TensorBoard trace visualization
- Configurable profiling window (start/end steps)
- Automatic cleanup on training end
- No interference with JIT compilation
- Minimal overhead outside profiling window
Example:
```python
from artifex.generative_models.training.callbacks import (
    JAXProfiler,
    ProfilingConfig,
)

config = ProfilingConfig(
    log_dir="logs/profiles",
    start_step=10,  # Skip JIT warmup
    end_step=20,    # Profile 10 steps
)
profiler = JAXProfiler(config)
trainer.fit(callbacks=[profiler])

# View in TensorBoard:
#   tensorboard --logdir logs/profiles
```
Best Practices:
- Set `start_step` after JIT warmup (typically 5-10 steps)
- Keep the profiling window small (10-20 steps) to minimize impact
- Use `trace_python=True` only when debugging Python bottlenecks
- Traces can be viewed in TensorBoard or Perfetto
MemoryProfileConfig¤
```python
@dataclass(slots=True)
class MemoryProfileConfig:
    log_dir: str = "logs/memory"
    profile_every_n_steps: int = 100
    log_device_memory: bool = True
```
Configuration for memory profiling.
Attributes:
- `log_dir`: Directory to save memory profile data
- `profile_every_n_steps`: Collect memory info every N steps
- `log_device_memory`: Whether to log device (GPU/TPU) memory stats
MemoryProfiler¤
Memory usage profiling callback.
Tracks memory usage during training and saves a timeline to JSON. Useful for identifying memory leaks and understanding memory patterns.
Features:
- Track JAX device memory usage (GPU/TPU)
- Log peak memory per step
- Export memory timeline to JSON
- Configurable profiling interval
- Minimal overhead between collection intervals
Example:
```python
from artifex.generative_models.training.callbacks import (
    MemoryProfiler,
    MemoryProfileConfig,
)

config = MemoryProfileConfig(
    log_dir="logs/memory",
    profile_every_n_steps=50,
)
profiler = MemoryProfiler(config)
trainer.fit(callbacks=[profiler])

# Memory profile saved to logs/memory/memory_profile.json
```
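The collection step itself can be sketched as follows. This is not the `MemoryProfiler` implementation: the `FakeDevice` stubs stand in for real JAX devices (which expose `memory_stats()`), and the hook wiring is omitted; the sketch only shows snapshotting, the skip for devices that return `None`, and writing the timeline to JSON:

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory


class FakeDevice:
    """Stand-in for a JAX device. Real devices expose memory_stats(), which
    returns None on platforms (e.g. CPU) that don't support it."""

    def __init__(self, name, stats):
        self.name = name
        self._stats = stats

    def memory_stats(self):
        return self._stats


def snapshot(devices):
    """Collect one memory snapshot, skipping unsupported devices."""
    memory = {}
    for d in devices:
        stats = d.memory_stats()
        if stats is None:
            continue  # e.g. CPU devices typically return None
        memory[d.name] = {
            "bytes_in_use": stats["bytes_in_use"],
            "peak_bytes_in_use": stats["peak_bytes_in_use"],
        }
    return memory


devices = [
    FakeDevice("cuda:0", {"bytes_in_use": 1 << 30, "peak_bytes_in_use": 2 << 30}),
    FakeDevice("cpu:0", None),  # skipped: no memory stats available
]

# Collect every 50 steps, mirroring profile_every_n_steps=50.
timeline = [{"step": step, "memory": snapshot(devices)} for step in range(0, 101, 50)]

with TemporaryDirectory() as tmp:
    out = Path(tmp) / "memory_profile.json"
    out.write_text(json.dumps(timeline, indent=2))
    loaded = json.loads(out.read_text())
```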
Output Format:
The memory profile is saved as JSON with the following structure:
```json
[
  {
    "step": 0,
    "memory": {
      "cuda:0": {
        "bytes_in_use": 1073741824,
        "peak_bytes_in_use": 2147483648
      }
    }
  },
  {
    "step": 100,
    "memory": {
      "cuda:0": {
        "bytes_in_use": 1073741824,
        "peak_bytes_in_use": 2147483648
      }
    }
  }
]
```
Note: Not all devices support `memory_stats()`. CPU devices typically return `None`, in which case those devices are skipped.
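Because the saved timeline is plain JSON, it is easy to post-process. As one example, the sketch below finds the step with the highest peak usage per device; the function name `peak_by_device` and the sample values are illustrative, not part of the module:

```python
def peak_by_device(timeline):
    """Return {device: (step, peak_bytes)} for the max peak_bytes_in_use seen."""
    peaks = {}
    for entry in timeline:
        for device, stats in entry["memory"].items():
            peak = stats["peak_bytes_in_use"]
            if device not in peaks or peak > peaks[device][1]:
                peaks[device] = (entry["step"], peak)
    return peaks


# Sample timeline in the format shown above (values are illustrative).
timeline = [
    {"step": 0, "memory": {"cuda:0": {"bytes_in_use": 1 << 30,
                                      "peak_bytes_in_use": 2 << 30}}},
    {"step": 100, "memory": {"cuda:0": {"bytes_in_use": 1 << 30,
                                        "peak_bytes_in_use": 3 << 30}}},
]

peaks = peak_by_device(timeline)
```

A steadily rising `peak_bytes_in_use` across steps is the typical signature of a memory leak.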
Usage with Multiple Callbacks¤
Profiling callbacks can be combined with other callbacks:
```python
from artifex.generative_models.training.callbacks import (
    CallbackList,
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
    CheckpointConfig,
    JAXProfiler,
    ProfilingConfig,
    MemoryProfiler,
    MemoryProfileConfig,
    ProgressBarCallback,
)

# Configure callbacks
callbacks = CallbackList([
    # Profiling (runs first to capture full training)
    JAXProfiler(ProfilingConfig(
        log_dir="logs/profiles",
        start_step=10,
        end_step=20,
    )),
    MemoryProfiler(MemoryProfileConfig(
        log_dir="logs/memory",
        profile_every_n_steps=100,
    )),
    # Progress display
    ProgressBarCallback(),
    # Training control
    EarlyStopping(EarlyStoppingConfig(
        monitor="val_loss",
        patience=10,
    )),
    # Checkpointing
    ModelCheckpoint(CheckpointConfig(
        dirpath="checkpoints",
        monitor="val_loss",
    )),
])

trainer.fit(callbacks=callbacks)
```
Performance Considerations¤
Both profiling callbacks are designed for minimal overhead:
- JAXProfiler: Zero overhead outside the profiling window (start_step to end_step)
- MemoryProfiler: Only collects data at configured intervals; no overhead between intervals
The callbacks do not interfere with JAX's JIT compilation. JIT-compiled functions produce identical results before, during, and after profiling.
Viewing Traces¤
TensorBoard¤
Navigate to the "Profile" tab to view:
- XLA compilation times
- Device execution times
- Memory allocation patterns
- Kernel execution traces
Perfetto¤
- Go to the Perfetto UI
- Click "Open trace file"
- Select the `.trace` file from your log directory
Perfetto provides more detailed trace analysis capabilities including:
- Timeline visualization
- Flame graphs
- Memory analysis
Module Statistics¤
- Classes: 4
- Functions: 0
- Imports: 4