TensorBoard Integration¤
Visualize training metrics, generated samples, and model architecture using TensorBoard with Workshop.
Overview¤
TensorBoard provides powerful visualization tools for machine learning experiments. Workshop integrates with TensorBoard to log metrics, visualize generated samples, and track training progress.
TensorBoard Benefits

- Real-time Monitoring: Watch training progress live
- Visualization: Interactive charts and image galleries
- Lightweight: No external services required
- Integration: Works seamlessly with JAX/Flax

This guide covers three topics:

- Metrics Logging: Track scalars, histograms, and custom metrics
- Visualization: Visualize samples, latent spaces, and attention
- Training Integration: Integrate TensorBoard with the Workshop training loop
Prerequisites¤
# Install TensorBoard
pip install tensorboard tensorboardX
# Or using uv
uv pip install tensorboard tensorboardX
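To confirm the installation:

# Print the installed TensorBoard version
python -c "import tensorboard; print(tensorboard.__version__)"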
Logging Patterns¤
Basic Scalar Logging¤
Track training metrics over time.
from torch.utils.tensorboard import SummaryWriter
import jax.numpy as jnp
import numpy as np


class TensorBoardLogger:
    """TensorBoard logging for Workshop models."""

    def __init__(self, log_dir: str = "./runs/experiment"):
        self.writer = SummaryWriter(log_dir)
        self.step = 0

    def log_scalars(self, metrics: dict, step: int | None = None):
        """Log scalar metrics.

        Args:
            metrics: Dictionary of metric names and values
            step: Global step (uses internal counter if None)
        """
        # Remember whether the caller supplied a step before overwriting it;
        # the internal counter should only advance when no step was given.
        use_internal_step = step is None
        step = self.step if use_internal_step else step
        for name, value in metrics.items():
            # Convert JAX arrays to Python scalars
            if isinstance(value, jnp.ndarray):
                value = float(value)
            self.writer.add_scalar(name, value, step)
        if use_internal_step:
            self.step += 1

    def log_training_step(self, loss: float, learning_rate: float, step: int):
        """Log training step metrics."""
        self.log_scalars({
            "train/loss": loss,
            "train/learning_rate": learning_rate,
        }, step=step)

    def log_validation(self, val_loss: float, metrics: dict, epoch: int):
        """Log validation metrics."""
        self.log_scalars({
            "val/loss": val_loss,
            **{f"val/{k}": v for k, v in metrics.items()},
        }, step=epoch)

    def close(self):
        """Close the writer."""
        self.writer.close()
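A minimal usage sketch (paths and values are illustrative):

logger = TensorBoardLogger(log_dir="./runs/demo")
logger.log_training_step(loss=0.42, learning_rate=1e-4, step=100)
logger.log_validation(val_loss=0.38, metrics={"accuracy": 0.91}, epoch=1)
logger.close()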
Image Logging¤
Visualize generated samples.
import jax


class ImageLogger:
    """Log images to TensorBoard."""

    def __init__(self, writer: SummaryWriter):
        self.writer = writer

    def log_images(
        self,
        images: jax.Array,
        tag: str,
        step: int,
        max_images: int = 16,
    ):
        """Log an image batch.

        Args:
            images: Image batch (B, H, W, C) or (B, C, H, W), in [-1, 1]
            tag: Tag for the images
            step: Global step
            max_images: Maximum number of images to log
        """
        # Convert to numpy and limit the number of images
        images_np = np.array(images[:max_images])
        # Denormalize from [-1, 1] to [0, 1]
        images_np = (images_np + 1) / 2
        images_np = np.clip(images_np, 0, 1)
        # add_images expects channel-first (B, C, H, W) by default
        if images_np.shape[-1] in (1, 3):  # Channel-last input
            images_np = np.transpose(images_np, (0, 3, 1, 2))
        # Log as an image grid
        self.writer.add_images(tag, images_np, step)

    def log_image_comparison(
        self,
        real_images: jax.Array,
        generated_images: jax.Array,
        step: int,
    ):
        """Log real vs. generated comparison."""
        self.log_images(real_images, "comparison/real", step)
        self.log_images(generated_images, "comparison/generated", step)
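For example, with a batch of samples in [-1, 1] (shapes and names below are illustrative):

writer = SummaryWriter("./runs/demo")
image_logger = ImageLogger(writer)

# Fake batch of 16 channel-last RGB images in [-1, 1]
fake_batch = np.random.uniform(-1, 1, size=(16, 32, 32, 3)).astype(np.float32)
image_logger.log_images(fake_batch, tag="samples/debug", step=0)
writer.flush()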
Histogram Logging¤
Track parameter distributions.
from flax import nnx
import jax


class HistogramLogger:
    """Log parameter histograms."""

    def __init__(self, writer: SummaryWriter):
        self.writer = writer

    def log_model_parameters(self, model, step: int):
        """Log histograms for all model parameters.

        Args:
            model: Flax NNX model
            step: Global step
        """
        # nnx.state returns a nested State pytree, so flatten it to
        # (path, array) pairs instead of iterating only the top level.
        state = nnx.state(model, nnx.Param)
        for path, param in jax.tree_util.tree_leaves_with_path(state):
            name = jax.tree_util.keystr(path)  # render the path as a string
            self.writer.add_histogram(
                f"parameters/{name}",
                np.asarray(param),
                step,
            )

    def log_gradients(self, grads, step: int):
        """Log gradient histograms (grads may be a dict or an nnx State)."""
        for path, grad in jax.tree_util.tree_leaves_with_path(grads):
            name = jax.tree_util.keystr(path)
            self.writer.add_histogram(
                f"gradients/{name}",
                np.asarray(grad),
                step,
            )
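A sketch of wiring these into a training step; loss_fn, model, and batch are hypothetical placeholders, and nnx.grad returns a gradient pytree that log_gradients can flatten:

def loss_fn(model, batch):
    # Hypothetical reconstruction loss; substitute your model's objective
    recon = model(batch)
    return jnp.mean((recon - batch) ** 2)

writer = SummaryWriter("./runs/demo")
hist_logger = HistogramLogger(writer)

grads = nnx.grad(loss_fn)(model, batch)  # gradients w.r.t. model parameters
hist_logger.log_gradients(grads, step=0)
hist_logger.log_model_parameters(model, step=0)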
Visualization¤
Training Curves¤
Monitor loss and metrics over time.
class TrainingVisualizer:
    """Visualize training progress."""

    def __init__(self, log_dir: str):
        self.writer = SummaryWriter(log_dir)

    def log_loss_components(
        self,
        total_loss: float,
        reconstruction_loss: float,
        kl_loss: float,
        step: int,
    ):
        """Log VAE loss components."""
        self.writer.add_scalars("loss_components", {
            "total": total_loss,
            "reconstruction": reconstruction_loss,
            "kl_divergence": kl_loss,
        }, step)

    def log_gan_losses(
        self,
        g_loss: float,
        d_loss: float,
        d_real: float,
        d_fake: float,
        step: int,
    ):
        """Log GAN training metrics."""
        self.writer.add_scalars("gan/losses", {
            "generator": g_loss,
            "discriminator": d_loss,
        }, step)
        self.writer.add_scalars("gan/discriminator", {
            "real_score": d_real,
            "fake_score": d_fake,
        }, step)
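For example, inside a VAE training loop (the loss values here are placeholders):

viz = TrainingVisualizer("./runs/vae_demo")
for step in range(1000):
    # ... compute the batch losses here ...
    recon_loss, kl_loss = 0.30, 0.05  # placeholders
    viz.log_loss_components(
        total_loss=recon_loss + kl_loss,
        reconstruction_loss=recon_loss,
        kl_loss=kl_loss,
        step=step,
    )
viz.writer.close()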
Sample Galleries¤
Create grids of generated samples.
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg


def create_sample_grid(images: np.ndarray, nrow: int = 8) -> np.ndarray:
    """Create an image grid for visualization.

    Args:
        images: Batch of images (B, H, W, C)
        nrow: Number of images per row

    Returns:
        Grid image as a (height, width, 3) uint8 numpy array
    """
    batch_size, h, w, c = images.shape
    nrow = min(nrow, batch_size)
    ncol = (batch_size + nrow - 1) // nrow  # number of rows in the grid

    # Create figure
    fig, axes = plt.subplots(ncol, nrow, figsize=(nrow * 2, ncol * 2))
    axes = axes.flatten() if batch_size > 1 else [axes]

    for ax, img in zip(axes, images):
        if c == 1:  # Grayscale
            ax.imshow(img.squeeze(), cmap='gray')
        else:  # RGB
            ax.imshow(img)
        ax.axis('off')

    # Hide extra subplots
    for idx in range(batch_size, len(axes)):
        axes[idx].axis('off')
    plt.tight_layout()

    # Render the figure and convert it to a numpy array.
    # buffer_rgba() replaces the deprecated tostring_rgb().
    canvas = FigureCanvasAgg(fig)
    canvas.draw()
    grid = np.asarray(canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return grid
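A quick standalone check of the grid helper with random data:

# 12 grayscale images arranged 4 per row -> a 3x4 grid
demo = (np.random.rand(12, 28, 28, 1) * 255).astype(np.uint8)
grid = create_sample_grid(demo, nrow=4)
print(grid.shape)  # (height, width, 3)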
def log_sample_gallery(
    logger: TensorBoardLogger,
    model,
    num_samples: int,
    step: int,
    rngs,
):
    """Log a gallery of generated samples."""
    # Generate samples
    samples = model.sample(num_samples=num_samples, rngs=rngs)
    # Convert and denormalize from [-1, 1] to uint8 [0, 255]
    samples_np = np.array(samples)
    samples_np = (np.clip((samples_np + 1) / 2, 0, 1) * 255).astype(np.uint8)
    # Create grid
    grid = create_sample_grid(samples_np)
    # Log to TensorBoard (the grid is channel-last, hence dataformats='HWC')
    logger.writer.add_image(
        "samples/generated",
        grid,
        step,
        dataformats='HWC',
    )
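To make galleries comparable across training steps, one option is to rebuild the Rngs with a fixed seed before each call so the model decodes the same latents every time (a sketch; the sample signature follows the Workshop usage above):

# Re-seeding before each gallery keeps the latents identical across steps
for step in (0, 1000, 2000):
    log_sample_gallery(logger, model, num_samples=16, step=step, rngs=nnx.Rngs(0))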
Latent Space Visualization¤
Visualize learned latent representations.
import torch


def log_latent_space(
    logger: TensorBoardLogger,
    model,
    test_data: jax.Array,
    labels: jax.Array,
    step: int,
):
    """Log a latent-space embedding.

    Args:
        logger: TensorBoard logger
        model: Trained model with an encoder
        test_data: Test images (B, H, W, C) in [-1, 1]
        labels: Image labels
        step: Global step
    """
    # Encode to latent space
    latents, _ = model.encode(test_data)
    latents_np = np.array(latents)
    labels_np = np.array(labels)
    # add_embedding expects label_img as a torch tensor in (B, C, H, W),
    # with values in [0, 1] for display
    images = (np.array(test_data) + 1) / 2
    label_img = torch.from_numpy(images).permute(0, 3, 1, 2)
    # Log embedding
    logger.writer.add_embedding(
        latents_np,
        metadata=labels_np.tolist(),
        label_img=label_img,
        global_step=step,
        tag="latent_space",
    )
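A hypothetical call with a small labeled test batch:

# test_images: (512, 28, 28, 1) in [-1, 1]; test_labels: (512,) integer classes
log_latent_space(logger, model, test_images, test_labels, step=5000)
# Open the "Projector" tab in TensorBoard to explore the embedding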
Integration with Training¤
TensorBoard Callback¤
Integrate with Workshop training loop.
from workshop.generative_models.training import Trainer


class TensorBoardTrainer(Trainer):
    """Trainer with TensorBoard logging."""

    def __init__(
        self,
        model,
        config: dict,
        log_dir: str = "./runs/experiment",
        **kwargs,
    ):
        super().__init__(model, config, **kwargs)
        # Initialize TensorBoard
        self.tb_logger = TensorBoardLogger(log_dir)
        self.log_frequency = config.get("tb_log_frequency", 100)

    def on_train_step_end(self, step: int, loss: float, metrics: dict):
        """Log after each training step."""
        if step % self.log_frequency == 0:
            self.tb_logger.log_scalars({
                "train/loss": loss,
                **{f"train/{k}": v for k, v in metrics.items()},
            }, step=step)

    def on_validation_end(self, epoch: int, val_metrics: dict):
        """Log after validation."""
        self.tb_logger.log_scalars(
            {f"val/{k}": v for k, v in val_metrics.items()},
            step=epoch,
        )
        # Log generated samples
        samples = self.model.sample(num_samples=16, rngs=self.rngs)
        image_logger = ImageLogger(self.tb_logger.writer)
        image_logger.log_images(samples, "samples/generated", epoch)

    def on_training_end(self):
        """Close TensorBoard when training ends."""
        self.tb_logger.close()
Complete Example¤
Full training example with TensorBoard.
from flax import nnx
from workshop.generative_models.models.vae import VAE

# Create model
model = VAE(
    latent_dim=128,
    image_shape=(28, 28, 1),
    rngs=nnx.Rngs(0),
)

# Training configuration
config = {
    "learning_rate": 1e-4,
    "batch_size": 128,
    "num_epochs": 50,
    "tb_log_frequency": 100,
}

# Create trainer with TensorBoard
trainer = TensorBoardTrainer(
    model=model,
    config=config,
    log_dir="./runs/vae_experiment",
)

# Train (train_data and val_data are your dataset iterators)
trainer.train(train_data, val_data)

# View results
print("To view TensorBoard, run:")
print("tensorboard --logdir=./runs")
Launching TensorBoard¤
Basic Launch¤
# Launch TensorBoard
tensorboard --logdir=./runs

# Custom port
tensorboard --logdir=./runs --port=6007

# Poll the log directory for new data every 5 seconds
tensorboard --logdir=./runs --reload_interval=5
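In Jupyter or Colab notebooks, TensorBoard can also be embedded inline via its notebook extension:

%load_ext tensorboard
%tensorboard --logdir ./runs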
Comparing Experiments¤
# Compare multiple experiments
tensorboard --logdir_spec=baseline:./runs/baseline,improved:./runs/improved
Best Practices¤
DO¤
Recommended Practices
✅ Organize logs by experiment in separate directories
✅ Log periodically (every 100-1000 steps)
✅ Use meaningful tags for metrics and images
✅ Log validation samples to track generation quality
✅ Close writer when training completes
DON'T¤
Avoid These Mistakes
❌ Don't log every step (creates huge files)
❌ Don't log high-res images frequently (use max_images)
❌ Don't forget to flush the writer periodically (see the sketch after this list)
❌ Don't reuse log directories without clearing
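A minimal sketch of the periodic-flush pattern referenced above; writer is any SummaryWriter, and FLUSH_EVERY and num_steps are illustrative:

# Flush buffered events every N steps so the TensorBoard UI stays current
FLUSH_EVERY = 500
for step in range(num_steps):
    # ... training and logging ...
    if step % FLUSH_EVERY == 0:
        writer.flush()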
Troubleshooting¤
| Issue | Cause | Solution |
|---|---|---|
| TensorBoard not showing data | Data not flushed | Call writer.flush() or close writer |
| Large log files | Logging too frequently | Reduce logging frequency |
| Images not appearing | Wrong format | Ensure channel-first format (C, H, W) |
| Port already in use | TensorBoard running | Use different port with --port |
| Slow performance | Too many logs | Reduce log frequency or clear old runs |
Summary¤
TensorBoard provides essential visualization:
- Real-time Monitoring: Track training progress live
- Scalar Metrics: Loss curves and validation metrics
- Image Galleries: Visualize generated samples
- Histograms: Monitor parameter distributions
- Embeddings: Explore latent spaces
Start visualizing your training today!
Next Steps¤
- Weights & Biases: Advanced experiment tracking and sweeps
- HuggingFace Hub: Share models with the community
- Training Guide: Master the training system
- Benchmarking: Evaluate models comprehensively