Weights & Biases Integration

Complete guide to experiment tracking, hyperparameter sweeps, and artifact management using Weights & Biases with Workshop.

Overview

Weights & Biases (W&B) is a powerful platform for tracking machine learning experiments. Workshop integrates seamlessly with W&B to log metrics, visualize training progress, run hyperparameter sweeps, and manage model artifacts.

W&B Benefits

  • Experiment Tracking: Log metrics, hyperparameters, and system info
  • Visualizations: Interactive charts and sample galleries
  • Hyperparameter Sweeps: Automated optimization
  • Artifact Management: Version models and datasets
  • Team Collaboration: Share results with teammates

Prerequisites

Install and configure W&B:

# Install wandb
pip install wandb

# Or using uv
uv pip install wandb

# Login
wandb login

Experiment Tracking

Basic Logging

Track metrics during training.

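import jax
import jax.numpy as jnp
import numpy as np
import wandb
from flax import nnx

class WandBTrainer:
    """Trainer with W&B logging."""

    def __init__(
        self,
        model,
        config: dict,
        project_name: str = "workshop-experiments",
    ):
        self.model = model
        self.config = config

        # Initialize W&B; the config dict is stored as the run's hyperparameters
        wandb.init(
            project=project_name,
            config=config,
            name=f"{config.get('model_type', 'model')}-{config.get('experiment_name', 'default')}",
        )

        # wandb.watch() only hooks into PyTorch modules, so for a Flax NNX
        # model record the parameter count in the run summary instead
        params = nnx.state(self.model, nnx.Param)
        wandb.run.summary["num_params"] = sum(
            x.size for x in jax.tree_util.tree_leaves(params)
        )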

    def train_step(self, batch, step: int):
        """Single training step with logging."""
        # compute_loss() and evaluate() are placeholders for model-specific code
        loss, metrics = self.compute_loss(batch)

        # Log metrics
        wandb.log({
            "train/loss": float(loss),
            "train/step": step,
            **{f"train/{k}": float(v) for k, v in metrics.items()}
        }, step=step)

        return loss, metrics

    def validation_step(self, val_data, step: int):
        """Validation with logging."""
        val_loss, val_metrics = self.evaluate(val_data)

        # Log validation metrics
        wandb.log({
            "val/loss": float(val_loss),
            **{f"val/{k}": float(v) for k, v in val_metrics.items()}
        }, step=step)

        return val_loss, val_metrics

    def log_images(self, images, step: int, key: str = "samples"):
        """Log generated images (assumed NHWC with values in [-1, 1])."""
        # Convert to numpy, map [-1, 1] to [0, 255], and clip before casting
        images_np = np.asarray(images)
        images_np = np.clip((images_np + 1) / 2 * 255, 0, 255).astype(np.uint8)

        # Log to W&B as a list of images (shown as a gallery)
        wandb.log({
            key: [wandb.Image(img) for img in images_np]
        }, step=step)

    def finish(self):
        """Finish W&B run."""
        wandb.finish()
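The `**{f"train/{k}": float(v) ...}` pattern in the trainer does two things: it namespaces keys with a `train/` or `val/` prefix, which W&B uses to group charts into sections, and it converts device scalars into plain Python floats. A minimal standalone sketch of that step, using a hypothetical `prefix_metrics` helper (not part of Workshop or wandb):

```python
import numpy as np


def prefix_metrics(metrics: dict, prefix: str) -> dict:
    """Return {'<prefix>/<name>': float} for every scalar metric.

    np.asarray() accepts NumPy scalars, 0-d NumPy arrays and 0-d JAX arrays,
    and float() turns the result into a plain Python number for wandb.log().
    """
    return {f"{prefix}/{name}": float(np.asarray(value)) for name, value in metrics.items()}


train_metrics = prefix_metrics({"loss": np.float32(0.25), "kl": np.array(1.5)}, "train")
print(train_metrics)  # {'train/loss': 0.25, 'train/kl': 1.5}
```

The result can be passed straight to `wandb.log(train_metrics, step=step)`.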

Advanced Metrics Logging

Track custom metrics and visualizations.

class AdvancedWandBLogger:
    """Advanced W&B logging with custom metrics."""

    def __init__(self, project: str, config: dict):
        wandb.init(project=project, config=config)
        self.step = 0

    def log_training_metrics(
        self,
        loss: float,
        reconstruction_loss: float,
        kl_loss: float,
        learning_rate: float,
    ):
        """Log detailed training metrics."""
        wandb.log({
            # Loss components
            "loss/total": loss,
            "loss/reconstruction": reconstruction_loss,
            "loss/kl_divergence": kl_loss,

            # Optimization
            "optimization/learning_rate": learning_rate,
            "optimization/step": self.step,

            # Loss ratios
            "analysis/recon_kl_ratio": reconstruction_loss / (kl_loss + 1e-8),
        }, step=self.step)

        self.step += 1

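    def log_histograms(self, model):
        """Log parameter histograms."""
        # nnx.State is a nested pytree, so flatten it to reach the arrays
        params = nnx.state(model, nnx.Param)
        leaves, _ = jax.tree_util.tree_flatten_with_path(params)

        histograms = {
            f"histograms/{jax.tree_util.keystr(path)}": wandb.Histogram(np.asarray(leaf))
            for path, leaf in leaves
        }
        wandb.log(histograms, step=self.step)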

    def log_latent_space(
        self,
        latent_codes: jax.Array,
        labels: jax.Array = None,
    ):
        """Log latent space visualization."""
        # Convert to numpy
        latent_np = np.array(latent_codes)

        # Log scatter plot
        if labels is not None:
            labels_np = np.array(labels)
            data = [[x, y, label] for (x, y), label in zip(latent_np[:, :2], labels_np)]
            table = wandb.Table(data=data, columns=["x", "y", "label"])
            wandb.log({
                "latent_space": wandb.plot.scatter(
                    table, "x", "y", "label",
                    title="Latent Space Visualization"
                )
            }, step=self.step)
        else:
            wandb.log({
                "latent_space": wandb.Histogram(latent_np)
            }, step=self.step)

    def log_generation_quality(
        self,
        real_images: jax.Array,
        fake_images: jax.Array,
    ):
        """Log generation quality metrics."""
        from workshop.benchmarks.metrics import compute_fid, compute_inception_score

        # Compute metrics
        fid = compute_fid(real_images, fake_images)
        inception_score, _ = compute_inception_score(fake_images)

        wandb.log({
            "quality/fid": fid,
            "quality/inception_score": inception_score,
        }, step=self.step)
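Model state is a nested structure, while each W&B histogram needs its own flat key, which is why parameter histograms require a flattening step. The sketch below shows the idea on a plain nested dict standing in for the model state (an assumption for illustration; with Flax NNX the state is flattened as a pytree instead). `flatten_params` is a hypothetical helper.

```python
import numpy as np


def flatten_params(tree: dict, prefix: str = "") -> dict:
    """Flatten a nested dict of arrays into {'outer/inner/leaf': array}."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into sub-modules, extending the slash-separated path
            flat.update(flatten_params(value, name))
        else:
            flat[name] = np.asarray(value)
    return flat


params = {
    "encoder": {"dense": {"kernel": np.ones((4, 2)), "bias": np.zeros(2)}},
    "decoder": {"kernel": np.ones((2, 4))},
}
for name, array in flatten_params(params).items():
    print(name, array.shape)
# encoder/dense/kernel (4, 2)
# encoder/dense/bias (2,)
# decoder/kernel (2, 4)
```

Each flat entry then maps to one histogram, e.g. `{f"histograms/{k}": wandb.Histogram(v) for k, v in flatten_params(params).items()}`.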

Hyperparameter Sweeps

Sweep Configuration

Define hyperparameter search space.

# sweep_config.yaml or Python dict
sweep_config = {
    "method": "bayes",  # bayes, grid, or random
    "metric": {
        "name": "val/loss",
        "goal": "minimize",
    },
    "parameters": {
        "learning_rate": {
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 1e-2,
        },
        "latent_dim": {
            "values": [64, 128, 256, 512],
        },
        "beta": {
            "distribution": "uniform",
            "min": 0.1,
            "max": 10.0,
        },
        "batch_size": {
            "values": [32, 64, 128],
        },
        "architecture": {
            "values": ["conv", "resnet", "transformer"],
        },
    },
}

# Create sweep
sweep_id = wandb.sweep(
    sweep=sweep_config,
    project="workshop-sweeps",
)
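`log_uniform_values` samples so that the logarithm of the value is uniform between `log(min)` and `log(max)`, which gives each decade of the learning-rate range equal weight; a plain `uniform` over 1e-5 to 1e-2 would put roughly 90% of samples above 1e-3. A small simulation of that difference (`sample_log_uniform` is an illustrative helper, not W&B code):

```python
import math
import random


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw x such that log(x) is uniform on [log(low), log(high)]."""
    x = math.exp(rng.uniform(math.log(low), math.log(high)))
    # Clamp to guard against floating-point round-off at the endpoints
    return min(max(x, low), high)


rng = random.Random(0)
n = 30_000
log_samples = [sample_log_uniform(1e-5, 1e-2, rng) for _ in range(n)]
uniform_samples = [rng.uniform(1e-5, 1e-2) for _ in range(n)]

# Fraction of samples per decade: about 1/3 each for log-uniform
for lo_edge, hi_edge in [(1e-5, 1e-4), (1e-4, 1e-3), (1e-3, 1e-2)]:
    frac_log = sum(lo_edge <= x < hi_edge for x in log_samples) / n
    frac_uni = sum(lo_edge <= x < hi_edge for x in uniform_samples) / n
    print(f"[{lo_edge:g}, {hi_edge:g}): log-uniform {frac_log:.2f}, uniform {frac_uni:.2f}")
```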

Running Sweeps

Execute hyperparameter sweep.

def train_with_sweep():
    """Training function for W&B sweep."""
    # Initialize W&B run
    run = wandb.init()

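    # Get hyperparameters from sweep
    config = wandb.config

    # Build model with sweep config (build_model, train_epoch and validate
    # below are placeholders for your own code)
    model = build_model(
        latent_dim=config.latent_dim,
        architecture=config.architecture,
    )

    # num_epochs is not in the search space, so read it with a fallback
    # (10 is an arbitrary default)
    for epoch in range(config.get("num_epochs", 10)):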
        # Training loop
        train_loss = train_epoch(
            model,
            learning_rate=config.learning_rate,
            batch_size=config.batch_size,
            beta=config.beta,
        )

        # Validation
        val_loss = validate(model)

        # Log metrics
        wandb.log({
            "train/loss": train_loss,
            "val/loss": val_loss,
            "epoch": epoch,
        })

    # Finish run
    wandb.finish()


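# Run sweep agent (pass the same project the sweep was created in)
wandb.agent(
    sweep_id=sweep_id,
    function=train_with_sweep,
    project="workshop-sweeps",
    count=50,  # Maximum number of runs this agent executes
)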

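Conceptually, `wandb.agent` runs a loop: it obtains the next configuration from the sweep controller, calls the training function with it, and the run reports the target metric. The sketch below imitates that loop locally with plain random search. It is a simplified model, not how W&B is implemented (the real controller runs on the W&B server, and `bayes` picks configurations using earlier results). `sample_config` and `run_random_search` are hypothetical names, and only `value`, `values`, `uniform` and `log_uniform_values` specs are handled.

```python
import math
import random


def sample_config(parameters: dict, rng: random.Random) -> dict:
    """Draw one configuration from a simplified sweep 'parameters' dict."""
    config = {}
    for name, spec in parameters.items():
        if "value" in spec:
            config[name] = spec["value"]
        elif "values" in spec:
            config[name] = rng.choice(spec["values"])
        elif spec.get("distribution") == "uniform":
            config[name] = rng.uniform(spec["min"], spec["max"])
        elif spec.get("distribution") == "log_uniform_values":
            log_x = rng.uniform(math.log(spec["min"]), math.log(spec["max"]))
            config[name] = math.exp(log_x)
        else:
            raise ValueError(f"unsupported spec for {name!r}: {spec}")
    return config


def run_random_search(parameters: dict, train_fn, count: int, seed: int = 0):
    """Call train_fn(config) `count` times and return (best_loss, best_config)."""
    rng = random.Random(seed)
    results = []
    for _ in range(count):
        config = sample_config(parameters, rng)
        results.append((train_fn(config), config))
    # The metric is minimized, matching "goal": "minimize" in the sweep config
    return min(results, key=lambda result: result[0])


def fake_train(config: dict) -> float:
    """Stand-in objective: pretend the best learning rate is 1e-3."""
    return abs(math.log10(config["learning_rate"]) + 3) + config["beta"] / 10


parameters = {
    "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-2},
    "beta": {"distribution": "uniform", "min": 0.1, "max": 10.0},
    "latent_dim": {"values": [64, 128, 256, 512]},
    "num_epochs": {"value": 10},  # a constant, fixed for every run
}
best_loss, best_config = run_random_search(parameters, fake_train, count=50)
print(best_loss, best_config)
```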
Multi-Objective Optimization

W&B sweeps optimize a single metric, so multiple objectives are combined into one scalar score.

sweep_config_multi = {
    "method": "bayes",
    "metric": {
        "name": "combined_score",
        "goal": "maximize",
    },
    "parameters": {
        # ... same as before
    },
}

def train_with_multi_objective():
    """Train with multiple objectives."""
    run = wandb.init()
    config = wandb.config

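    # Training... (model, validate, compute_fid and benchmark_inference are placeholders)
    val_loss = validate(model)
    fid_score = compute_fid(model)
    inference_time = benchmark_inference(model)

    # Combine objectives: lower is better for loss, FID and inference time,
    # so each term is negated and the sweep maximizes the sum. The divisor on
    # FID acts as a weight and should match the scale of each metric.
    combined_score = -val_loss - fid_score / 100 - inference_time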

    wandb.log({
        "val/loss": val_loss,
        "quality/fid": fid_score,
        "performance/inference_time": inference_time,
        "combined_score": combined_score,
    })

    wandb.finish()
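Because the sweep sees only one number, the divisor in `fid_score / 100` silently acts as a weight: it decides how many FID points are worth one unit of validation loss. Making that explicit with a reference scale per metric keeps the trade-off readable. A minimal sketch, assuming lower is better for every metric; `combined_score`, the weights and the reference scales are illustrative choices, not values from Workshop:

```python
def combined_score(metrics: dict, weights: dict, scales: dict) -> float:
    """Higher is better: negated weighted sum of metrics normalized by reference scales."""
    return -sum(weights[name] * metrics[name] / scales[name] for name in weights)


weights = {"val_loss": 1.0, "fid": 1.0, "inference_time": 0.5}
scales = {"val_loss": 1.0, "fid": 100.0, "inference_time": 0.1}  # typical magnitudes

run_a = {"val_loss": 0.50, "fid": 30.0, "inference_time": 0.02}
run_b = {"val_loss": 0.45, "fid": 45.0, "inference_time": 0.02}

score_a = combined_score(run_a, weights, scales)
score_b = combined_score(run_b, weights, scales)
print(f"{score_a:.2f} {score_b:.2f}")  # -0.90 -1.00
```

With these weights, run A wins despite its higher validation loss, because its FID advantage outweighs the loss difference.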

Artifact Management

Model Artifacts

Version and track trained models.

class ArtifactManager:
    """Manage W&B artifacts for models and datasets."""

    def __init__(self, project: str):
        self.project = project

    def save_model_artifact(
        self,
        model,
        artifact_name: str,
        metadata: dict = None,
    ):
        """Save model as W&B artifact.

        Args:
            model: Trained model
            artifact_name: Artifact name (e.g., "vae-mnist")
            metadata: Additional metadata
        """
        import os
        import pickle
        import tempfile

        # Create artifact
        artifact = wandb.Artifact(
            name=artifact_name,
            type="model",
            metadata=metadata or {},
        )

        # Save model to a temporary directory
        with tempfile.TemporaryDirectory() as tmpdir:
            model_path = os.path.join(tmpdir, "model")
            os.makedirs(model_path)  # open() below does not create directories

            # Export model state
            state = nnx.state(model)
            with open(os.path.join(model_path, "params.pkl"), "wb") as f:
                pickle.dump(state, f)

            # Add the files and log the artifact while the directory still exists
            artifact.add_dir(model_path)
            wandb.log_artifact(artifact)

        print(f"Model saved as artifact: {artifact_name}")

    def load_model_artifact(
        self,
        model,
        artifact_name: str,
        version: str = "latest",
    ):
        """Load saved state from an artifact into a model.

        Args:
            model: Freshly constructed model with the same architecture
            artifact_name: Artifact name
            version: Artifact version ("latest" or "v0", "v1", etc.)

        Returns:
            The model, updated in place with the saved state
        """
        import os
        import pickle

        # Download artifact (this also records it as an input of the current run)
        artifact = wandb.use_artifact(
            f"{artifact_name}:{version}",
            type="model",
        )
        artifact_dir = artifact.download()

        # Load the pickled state (only unpickle artifacts you trust)
        with open(os.path.join(artifact_dir, "params.pkl"), "rb") as f:
            state = pickle.load(f)

        # Copy the saved state into the freshly built model
        nnx.update(model, state)

        return model

    def save_dataset_artifact(
        self,
        data: jax.Array,
        artifact_name: str,
        description: str = "",
    ):
        """Save dataset as artifact.

        Args:
            data: Dataset array
            artifact_name: Artifact name
            description: Dataset description
        """
        artifact = wandb.Artifact(
            name=artifact_name,
            type="dataset",
            description=description,
        )

        # Save as a NumPy array and log while the temporary file still exists
        import os
        import tempfile
        with tempfile.TemporaryDirectory() as tmpdir:
            data_path = os.path.join(tmpdir, "data.npy")
            np.save(data_path, np.asarray(data))
            artifact.add_file(data_path)
            wandb.log_artifact(artifact)

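    def link_artifacts(
        self,
        model_artifact: wandb.Artifact,
        dataset_artifact: str,
    ):
        """Record that a model was trained on a dataset.

        W&B derives lineage from runs: an artifact consumed with
        use_artifact() is an input of the current run and an artifact
        logged with log_artifact() is an output, so both calls belong in
        the same run. (Artifact.link() serves a different purpose: it
        links a logged artifact into a registry or portfolio collection.)

        Args:
            model_artifact: New, not yet logged, model artifact
            dataset_artifact: Dataset artifact name
        """
        # Mark the dataset as an input of this run
        wandb.use_artifact(f"{dataset_artifact}:latest")

        # Log the model as an output of the same run
        wandb.log_artifact(model_artifact)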

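The model save and load methods depend on one on-disk layout: a directory whose root contains `params.pkl`, since `add_dir()` places the directory's contents at the artifact root and `download()` restores them under the returned path. The standalone sketch below exercises that layout locally, with a dict of NumPy arrays standing in for the model state (an assumption); `write_model_dir` and `read_model_dir` are hypothetical helpers, and no W&B calls are made.

```python
import os
import pickle
import tempfile

import numpy as np


def write_model_dir(state: dict, root: str) -> str:
    """Write state to <root>/model/params.pkl and return the model directory."""
    model_dir = os.path.join(root, "model")
    os.makedirs(model_dir, exist_ok=True)  # open() does not create directories
    with open(os.path.join(model_dir, "params.pkl"), "wb") as f:
        pickle.dump(state, f)
    return model_dir


def read_model_dir(model_dir: str) -> dict:
    """Load the pickled state from a directory with the same layout."""
    with open(os.path.join(model_dir, "params.pkl"), "rb") as f:
        return pickle.load(f)


state = {"encoder": {"kernel": np.arange(6.0).reshape(2, 3)}}
with tempfile.TemporaryDirectory() as tmpdir:
    model_dir = write_model_dir(state, tmpdir)
    print(sorted(os.listdir(model_dir)))  # ['params.pkl']
    restored = read_model_dir(model_dir)

print(np.array_equal(restored["encoder"]["kernel"], state["encoder"]["kernel"]))  # True
```

Unpickling can execute arbitrary code, so only load artifacts from sources you trust.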
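Artifact Lineage

W&B tracks relationships between artifacts through runs: artifacts a run consumes with use_artifact() become its inputs, and artifacts it logs with log_artifact() become its outputs.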

def create_artifact_lineage():
    """Create artifact lineage for full experiment tracking."""
    # Run 1: upload the dataset
    with wandb.init(project="workshop-lineage", job_type="upload-dataset") as run:
        dataset_artifact = wandb.Artifact("mnist-train", type="dataset")
        dataset_artifact.add_file("data/mnist_train.npy")
        run.log_artifact(dataset_artifact)

    # Run 2: consume the dataset, then log a model and evaluation results.
    # Finishing run 1 first lets the dataset upload complete before
    # "mnist-train:latest" is resolved.
    with wandb.init(project="workshop-lineage", job_type="train") as run:
        # use_artifact() records the dataset as an input of this run
        dataset = run.use_artifact("mnist-train:latest")
        dataset_dir = dataset.download()

        # Train model (train_model is a placeholder for your training code)
        model = train_model(dataset_dir)

        # Outputs of the same run are linked to the dataset automatically
        model_artifact = wandb.Artifact("vae-model", type="model")
        model_artifact.add_dir("models/vae")
        run.log_artifact(model_artifact)

        # Log evaluation results
        eval_artifact = wandb.Artifact("vae-evaluation", type="evaluation")
        eval_artifact.add_file("results/metrics.json")
        run.log_artifact(eval_artifact)

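Lineage is a graph built from runs: each run's consumed artifacts are its inputs and its logged artifacts are its outputs, so an artifact's ancestry can be found by walking from an output back through the runs that produced it. The toy model below illustrates the idea for a dataset, a trained model and an evaluation artifact; it is not W&B's implementation, and `RunRecord`, `upstream` and the "finetune" run are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class RunRecord:
    name: str
    inputs: list = field(default_factory=list)   # artifacts consumed via use_artifact()
    outputs: list = field(default_factory=list)  # artifacts logged via log_artifact()


def upstream(artifact: str, runs: list) -> set:
    """Return every artifact that `artifact` was (transitively) derived from."""
    ancestors, frontier = set(), [artifact]
    while frontier:
        current = frontier.pop()
        for run in runs:
            if current in run.outputs:
                # Everything the producing run consumed is an ancestor
                for parent in run.inputs:
                    if parent not in ancestors:
                        ancestors.add(parent)
                        frontier.append(parent)
    return ancestors


runs = [
    RunRecord("upload-dataset", outputs=["mnist-train"]),
    RunRecord("train", inputs=["mnist-train"], outputs=["vae-model", "vae-evaluation"]),
    RunRecord("finetune", inputs=["vae-model"], outputs=["vae-model-finetuned"]),
]
print(sorted(upstream("vae-model-finetuned", runs)))  # ['mnist-train', 'vae-model']
```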
Report Generation

Creating Reports

Generate shareable experiment reports.

class ReportGenerator:
    """Generate W&B reports."""

    @staticmethod
    def create_experiment_report(
        project: str,
        runs: list[str],
        title: str,
    ) -> str:
        """Create comparison report.

        Args:
            project: W&B project name
            runs: List of run IDs to compare
            title: Report title

        Returns:
            Report URL
        """
        import wandb.apis.reports as wr

        # Restrict the report to the requested runs. This is the legacy
        # dict filter syntax; newer wandb releases move the Reports API to
        # the wandb-workspaces package, where the filter syntax differs.
        runset = wr.Runset(
            project=project,
            filters={"$or": [{"name": run_id} for run_id in runs]},
        )

        # Create report
        report = wr.Report(
            project=project,
            title=title,
            description="Experiment comparison report",
        )

        # Panels such as RunComparer and LinePlot go inside a PanelGrid,
        # which also supplies the runsets they draw from
        report.blocks = [
            wr.PanelGrid(
                runsets=[runset],
                panels=[
                    wr.RunComparer(diff_only=False),
                    wr.LinePlot(
                        title="Training Loss Comparison",
                        x="step",
                        y=["train/loss"],
                    ),
                ],
            ),
        ]

        # Save and get URL
        report.save()
        return report.url

Custom Visualizations

Create custom plots for reports.

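def log_custom_visualizations(model, test_data):
    """Log custom visualizations to W&B.

    create_image_grid, encode_dataset and create_interpolation are
    placeholder helpers to be supplied by your own code.
    """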

    # 1. Sample Grid
    samples = model.sample(num_samples=64)
    wandb.log({
        "visualizations/sample_grid": wandb.Image(
            create_image_grid(samples, nrow=8)
        )
    })

    # 2. Reconstruction Comparison
    reconstructions = model.reconstruct(test_data[:8])
    comparison = np.concatenate([test_data[:8], reconstructions], axis=0)
    wandb.log({
        "visualizations/reconstruction": wandb.Image(
            create_image_grid(comparison, nrow=8)
        )
    })

    # 3. Latent Space 2D Projection
    from sklearn.manifold import TSNE

    latents, labels = encode_dataset(model, test_data)
    tsne = TSNE(n_components=2)
    latents_2d = tsne.fit_transform(np.array(latents))

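    # Create scatter plot (wandb.plot.scatter takes table, x, y and title;
    # the label column stays in the logged table)
    data = [[float(x), float(y), int(label)] for (x, y), label in zip(latents_2d, labels)]
    table = wandb.Table(data=data, columns=["x", "y", "label"])
    wandb.log({
        "visualizations/latent_tsne": wandb.plot.scatter(
            table, "x", "y",
            title="Latent Space t-SNE",
        )
    })

    # 4. Interpolation Video
    # wandb.Video expects NumPy input shaped (time, channels, height, width);
    # this assumes create_interpolation returns (time, height, width, channels)
    # uint8 frames
    interpolation_frames = create_interpolation(model, num_frames=60)
    frames = np.transpose(np.asarray(interpolation_frames), (0, 3, 1, 2))
    wandb.log({
        "visualizations/interpolation": wandb.Video(
            frames, fps=30, format="gif"
        )
    })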

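`create_image_grid` and `create_interpolation` are left undefined in the visualization code. One possible `create_image_grid`, assuming NHWC images and torchvision-style `nrow` semantics (number of images per row), plus a helper that converts NHWC frames in [0, 1] to the (time, channels, height, width) uint8 layout that `wandb.Video` expects for NumPy input. Both are hypothetical sketches, not Workshop functions.

```python
import numpy as np


def create_image_grid(images, nrow: int = 8, pad: int = 2) -> np.ndarray:
    """Tile a batch of NHWC images into one (H, W, C) image, `nrow` images per row."""
    images = np.asarray(images)
    n, h, w, c = images.shape
    rows = -(-n // nrow)  # ceiling division
    grid = np.zeros((rows * (h + pad) + pad, nrow * (w + pad) + pad, c), dtype=images.dtype)
    for index, image in enumerate(images):
        row, col = divmod(index, nrow)
        top, left = pad + row * (h + pad), pad + col * (w + pad)
        grid[top:top + h, left:left + w] = image
    return grid


def to_video_array(frames) -> np.ndarray:
    """Convert (T, H, W, C) frames in [0, 1] to (T, C, H, W) uint8 for wandb.Video."""
    frames = np.clip(np.asarray(frames) * 255.0, 0, 255).astype(np.uint8)
    return np.transpose(frames, (0, 3, 1, 2))


# 16 images of 4x4 RGB laid out 8 per row gives 2 rows of tiles
grid = create_image_grid(np.ones((16, 4, 4, 3)), nrow=8)
print(grid.shape)  # (14, 50, 3)

video = to_video_array(np.random.rand(60, 4, 4, 3))
print(video.shape, video.dtype)  # (60, 3, 4, 4) uint8
```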
Integration with Workshop Training

Complete Training Example

Full integration example.

import numpy as np
import wandb
from workshop.generative_models.training import Trainer

class WorkshopWandBTrainer(Trainer):
    """Workshop Trainer with W&B integration."""

    def __init__(
        self,
        model,
        config: dict,
        wandb_project: str = "workshop",
        **kwargs
    ):
        super().__init__(model, config, **kwargs)

        # Initialize W&B
        wandb.init(
            project=wandb_project,
            config=config,
            name=config.get("experiment_name"),
        )

        # W&B's built-in step must increase monotonically, so give training
        # and validation metrics their own x-axes instead of passing step=
        # with two different counters
        wandb.define_metric("train/step")
        wandb.define_metric("train/*", step_metric="train/step")
        wandb.define_metric("val/epoch")
        wandb.define_metric("val/*", step_metric="val/epoch")
        wandb.define_metric("samples", step_metric="val/epoch")

        self.wandb_log_frequency = config.get("wandb_log_frequency", 100)

    def on_train_step_end(self, step: int, loss: float, metrics: dict):
        """Called after each training step."""
        if step % self.wandb_log_frequency == 0:
            wandb.log({
                "train/loss": float(loss),
                "train/step": step,
                **{f"train/{k}": float(v) for k, v in metrics.items()}
            })

    def on_validation_end(self, epoch: int, val_metrics: dict):
        """Called after validation."""
        # Log sample images together with the validation metrics
        samples = self.model.sample(num_samples=16, rngs=self.rngs)
        wandb.log({
            "val/epoch": epoch,
            **{f"val/{k}": float(v) for k, v in val_metrics.items()},
            "samples": [wandb.Image(img) for img in np.asarray(samples)],
        })

    def on_training_end(self):
        """Called when training completes."""
        # Save final model as artifact
        artifact = wandb.Artifact("final_model", type="model")
        artifact.add_dir(self.checkpoint_dir)
        wandb.log_artifact(artifact)

        wandb.finish()


# Usage
config = {
    "model_type": "vae",
    "latent_dim": 128,
    "learning_rate": 1e-4,
    "batch_size": 128,
    "num_epochs": 100,
    "experiment_name": "vae-mnist-baseline",
}

trainer = WorkshopWandBTrainer(
    model=model,
    config=config,
    wandb_project="workshop-experiments",
)

trainer.train(train_data, val_data)
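W&B's built-in step must increase monotonically: a `wandb.log(..., step=s)` call with `s` lower than the current step is ignored with a warning. That is why a per-step counter and a per-epoch counter cannot share the `step=` argument, and why `wandb.define_metric()` can give validation metrics their own x-axis. The toy class below (a local simulation, not part of wandb) mimics that rule so the effect is visible without a W&B account.

```python
class MonotonicStepLog:
    """Toy stand-in for a W&B run's step handling."""

    def __init__(self):
        self.current_step = 0
        self.kept = []     # (step, data) pairs that would be recorded
        self.dropped = []  # (step, data) pairs W&B would ignore

    def log(self, data: dict, step: int) -> None:
        if step < self.current_step:
            self.dropped.append((step, data))
            return
        self.current_step = step
        self.kept.append((step, data))


run = MonotonicStepLog()
for global_step in range(0, 1000, 100):  # training logs every 100 steps
    run.log({"train/loss": 1.0 / (global_step + 1)}, step=global_step)
run.log({"val/loss": 0.5}, step=1)  # epoch 1 reused as the step
print(len(run.kept), len(run.dropped))  # 10 1
```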

Best Practices

DO

Recommended Practices

  • Log hyperparameters at the start of each run by passing them as config to wandb.init()
  • Use meaningful run names for easy identification
  • Tag runs with their experiment type (baseline, ablation, etc.) via the tags argument of wandb.init()
  • Save artifacts for reproducibility
  • Create reports to share results with your team
  • Use sweeps for systematic hyperparameter search

DON'T

Avoid These Mistakes

  • Don't log on every step; each call adds overhead, so log every N steps instead
  • Don't hardcode API keys; set the WANDB_API_KEY environment variable or run wandb login
  • Don't skip run names (auto-generated names are hard to track)
  • Don't log high-resolution images every step; subsample both the steps and the images
  • Don't forget to call wandb.finish(), or use the run returned by wandb.init() as a context manager
  • Don't create too many artifacts; version them strategically

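The naming and tagging advice can be applied mechanically by deriving both from the config. A minimal sketch; `make_run_identity` and its abbreviation scheme are illustrative choices, and the result would be passed to `wandb.init(name=..., tags=...)`:

```python
def make_run_identity(config: dict, experiment_type: str) -> tuple:
    """Build a descriptive run name and a tag list from a config dict."""
    model_type = config.get("model_type", "model")
    abbreviations = {"latent_dim": "ld", "learning_rate": "lr", "batch_size": "bs"}
    parts = [model_type] + [
        f"{short}{config[key]:g}" for key, short in abbreviations.items() if key in config
    ]
    return "-".join(parts), [experiment_type, model_type]


config = {"model_type": "vae", "latent_dim": 128, "learning_rate": 1e-4, "batch_size": 128}
name, tags = make_run_identity(config, "baseline")
print(name, tags)  # vae-ld128-lr0.0001-bs128 ['baseline', 'vae']
```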

Troubleshooting

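| Issue | Cause | Solution |
|---|---|---|
| Slow logging | Logging too frequently | Reduce log frequency to every 100-1000 steps |
| Missing metrics | wandb.log() not called | Make sure metrics are logged inside the training loop |
| Artifact upload fails | Large file size | Compress files or split them across artifacts |
| Sweep not starting | Invalid sweep config | Check the method, metric and parameters keys against the W&B sweep docs; grid search cannot sample continuous distributions |
| Run not appearing | Network issues | Check connectivity and retry, or set WANDB_MODE=offline and upload later with wandb sync |
| Memory leak | Runs never finished | Always call wandb.finish() |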

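Many sweep startup failures come from a handful of config mistakes. The hypothetical `check_sweep_config` below (not part of wandb, which performs its own, more complete validation) flags a few of them before `wandb.sweep()` is called. Its grid rule is deliberately conservative and requires an explicit `value` or `values` for every parameter.

```python
VALID_METHODS = {"grid", "random", "bayes"}


def check_sweep_config(config: dict) -> list:
    """Return human-readable problems found in a sweep config dict."""
    problems = []
    method = config.get("method")
    if method not in VALID_METHODS:
        problems.append(f"method must be one of {sorted(VALID_METHODS)}, got {method!r}")

    parameters = config.get("parameters")
    if not isinstance(parameters, dict) or not parameters:
        problems.append("parameters must be a non-empty dict")
        parameters = {}

    metric = config.get("metric", {})
    if method == "bayes" and "name" not in metric:
        problems.append("bayes search needs a metric name to optimize")
    if "goal" in metric and metric["goal"] not in {"minimize", "maximize"}:
        problems.append("metric.goal must be 'minimize' or 'maximize'")

    if method == "grid":
        for name, spec in parameters.items():
            if "value" not in spec and "values" not in spec:
                problems.append(f"grid search needs explicit values for {name!r}")
    return problems


bad = {"method": "grid", "parameters": {"beta": {"distribution": "uniform", "min": 0.1, "max": 10.0}}}
print(check_sweep_config(bad))  # ["grid search needs explicit values for 'beta'"]
```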
Summary

W&B integration provides powerful experiment tracking:

  • Metrics Logging: Track training progress in real-time
  • Hyperparameter Sweeps: Automate optimization
  • Artifacts: Version models and datasets
  • Reports: Share results with teammates
  • Visualization: Interactive charts and galleries

Start tracking your experiments systematically!


Next Steps