Loss Functions for Generative Models¤
Level: Intermediate | Runtime: ~30 seconds (CPU) | Format: Python + Jupyter
Prerequisites: Basic understanding of loss functions and JAX | Target Audience: Users learning Workshop's loss API
Overview¤
This example provides a comprehensive guide to loss functions in Workshop, covering everything from simple functional losses to advanced composable loss systems. Learn how to use built-in losses, create custom compositions, and apply specialized losses for VAEs, GANs, and geometric models.
What You'll Learn¤
- Functional Losses: Simple loss functions (MSE, MAE) with flexible reduction modes
- Composable System: Combine weighted losses with component tracking
- VAE Losses: Reconstruction + KL divergence for variational autoencoders
- GAN Losses: Generator and discriminator losses (standard, LS-GAN, Wasserstein)
- Scheduled Losses: Time-varying loss weights for curriculum learning
- Geometric Losses: Chamfer distance and mesh losses for 3D data
Files¤
This example is available in two formats:
- Python Script: loss_examples.py
- Jupyter Notebook: loss_examples.ipynb
Quick Start¤
Run the Python Script¤
```bash
# Activate environment
source activate.sh

# Run the example
python examples/generative_models/loss_examples.py
```
Run the Jupyter Notebook¤
```bash
# Activate environment
source activate.sh

# Launch Jupyter
jupyter lab examples/generative_models/loss_examples.ipynb
```
Key Concepts¤
1. Functional Losses¤
Simple, stateless loss functions for common use cases:
```python
from workshop.generative_models.core.losses import mse_loss, mae_loss

# Mean Squared Error
loss = mse_loss(predictions, targets, reduction="mean")

# Mean Absolute Error
loss = mae_loss(predictions, targets, reduction="sum")
```
Available Reductions:
"mean": Average over all elements (default)"sum": Sum all elements"none": Return per-element losses
2. Weighted Losses¤
Apply fixed weights to loss components:
```python
from workshop.generative_models.core.losses import WeightedLoss

# Create weighted loss
weighted_mse = WeightedLoss(
    loss_fn=mse_loss,
    weight=2.0,
    name="weighted_reconstruction",
)

# Compute weighted loss
loss_value = weighted_mse(predictions, targets)
```
3. Composite Losses¤
Combine multiple loss functions:
```python
from workshop.generative_models.core.losses import CompositeLoss

composite = CompositeLoss([
    WeightedLoss(mse_loss, weight=1.0, name="reconstruction"),
    WeightedLoss(mae_loss, weight=0.5, name="l1_penalty"),
], return_components=True)

# Get total loss and components
total_loss, components = composite(predictions, targets)
# components = {"reconstruction": 0.15, "l1_penalty": 0.08}
```
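Conceptually, a composite loss is just a weighted sum with per-component bookkeeping. A minimal sketch of the same behavior in plain Python (the function name and tuple layout here are illustrative, not Workshop's internals):

```python
def composite_loss_sketch(predictions, targets, weighted_losses):
    """weighted_losses: list of (name, weight, loss_fn) tuples."""
    components = {}
    total = 0.0
    for name, weight, loss_fn in weighted_losses:
        value = weight * loss_fn(predictions, targets)
        components[name] = value  # track each weighted component
        total = total + value
    return total, components
```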
4. VAE Losses¤
VAE loss combines reconstruction and KL divergence:
```python
import jax.numpy as jnp

def vae_loss(reconstruction, targets, mean, logvar, beta=1.0):
    # Reconstruction loss
    recon_loss = mse_loss(reconstruction, targets)

    # KL divergence (assuming standard normal prior)
    kl_loss = -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar))
    kl_loss = kl_loss / targets.shape[0]  # Normalize by batch size

    # Total VAE loss
    return recon_loss + beta * kl_loss
```
β Parameter:
- β = 1.0: Standard VAE
- β > 1.0: β-VAE (encourages disentanglement)
- β < 1.0: Less regularization, better reconstruction
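As a quick shape check, you can call the `vae_loss` defined above on random data (the array sizes below are arbitrary):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

batch, data_dim, latent_dim = 8, 784, 32
targets = jax.random.normal(k1, (batch, data_dim))
reconstruction = jax.random.normal(k2, (batch, data_dim))
mean = jax.random.normal(k3, (batch, latent_dim))
logvar = jnp.zeros((batch, latent_dim))  # unit variance: KL reduces to 0.5 * sum(mean**2) / batch

loss = vae_loss(reconstruction, targets, mean, logvar, beta=1.0)
```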
5. GAN Losses¤
Workshop provides pre-configured GAN loss suites:
```python
from workshop.generative_models.core.losses import create_gan_loss_suite

# Create GAN losses
gen_loss, disc_loss = create_gan_loss_suite(
    generator_loss_type="lsgan",
    discriminator_loss_type="lsgan",
)

# Generator loss (want discriminator to output 1 for fake)
g_loss = gen_loss(fake_scores)

# Discriminator loss (real→1, fake→0)
d_loss = disc_loss(real_scores, fake_scores)
```
Available GAN Loss Types:
"standard": Binary cross-entropy (original GAN)"lsgan": Least-squares GAN (more stable)"wgan": Wasserstein GAN (requires gradient penalty)
6. Scheduled Losses¤
Time-varying loss weights for curriculum learning:
```python
import jax.numpy as jnp
from workshop.generative_models.core.losses import ScheduledLoss

# Define schedule function
def warmup_schedule(step):
    """Linear warmup from 0 to 1 over 1000 steps."""
    return jnp.minimum(1.0, step / 1000.0)

# Create scheduled loss (perceptual_loss here stands in for any base loss)
scheduled_loss = ScheduledLoss(
    loss_fn=perceptual_loss,
    schedule_fn=warmup_schedule,
    name="scheduled_perceptual",
)

# Loss weight increases with training steps
loss_value = scheduled_loss(..., step=500)  # weight = 0.5
```
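Conceptually, the wrapper multiplies the base loss by the schedule's value at the current step. A minimal sketch of that behavior (illustrative, not `ScheduledLoss`'s actual code):

```python
def scheduled_loss_sketch(loss_fn, schedule_fn):
    def apply(*args, step, **kwargs):
        # Scale the base loss by the schedule weight at this step
        return schedule_fn(step) * loss_fn(*args, **kwargs)
    return apply

# With warmup_schedule above, the weight at step 500 is 500 / 1000 = 0.5
scheduled_mse = scheduled_loss_sketch(mse_loss, warmup_schedule)
# loss = scheduled_mse(predictions, targets, step=500)
```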
7. Geometric Losses¤
Specialized losses for 3D data:
Chamfer Distance¤
Measures point cloud similarity:
```python
import jax
from workshop.generative_models.core.losses import chamfer_distance

# Point clouds: (batch, num_points, 3)
# Use distinct keys so the clouds differ (the same key yields identical points and zero loss)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
pred_points = jax.random.normal(key1, (4, 1000, 3))
target_points = jax.random.normal(key2, (4, 1000, 3))

loss = chamfer_distance(pred_points, target_points)
```
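Chamfer distance averages, for each point, the squared distance to its nearest neighbor in the other cloud, in both directions. A straightforward JAX sketch (O(N·M) memory; Workshop's version may be optimized differently):

```python
import jax.numpy as jnp

def chamfer_distance_sketch(a, b):
    """a: (batch, N, 3), b: (batch, M, 3). Symmetric squared-distance Chamfer."""
    # Pairwise squared distances: (batch, N, M)
    diff = a[:, :, None, :] - b[:, None, :, :]
    d2 = jnp.sum(diff ** 2, axis=-1)
    a_to_b = jnp.mean(jnp.min(d2, axis=2), axis=1)  # nearest point in b for each point in a
    b_to_a = jnp.mean(jnp.min(d2, axis=1), axis=1)  # nearest point in a for each point in b
    return jnp.mean(a_to_b + b_to_a)
```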
Mesh Loss¤
Multi-component loss for mesh quality:
```python
from workshop.generative_models.core.losses import MeshLoss

mesh_loss = MeshLoss(
    vertex_weight=1.0,      # Vertex position accuracy
    normal_weight=0.1,      # Surface normal consistency
    edge_weight=0.1,        # Edge length preservation
    laplacian_weight=0.01,  # Smoothness regularization
)

# Mesh format: (vertices, faces, normals)
pred_mesh = (vertices_pred, faces, normals_pred)
target_mesh = (vertices_target, faces, normals_target)
loss = mesh_loss(pred_mesh, target_mesh)
```
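To illustrate one of these components, edge-length preservation can be computed directly from vertices and triangle indices. A hedged sketch for an unbatched mesh (`MeshLoss`'s actual terms and weighting may differ):

```python
import jax.numpy as jnp

def edge_length_loss_sketch(vertices_pred, vertices_target, faces):
    """vertices: (num_vertices, 3); faces: (num_faces, 3) integer indices."""
    # The three edges of each triangle
    edges = jnp.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])

    def edge_lengths(v):
        return jnp.linalg.norm(v[edges[:, 0]] - v[edges[:, 1]], axis=-1)

    # Penalize deviation between predicted and target edge lengths
    return jnp.mean((edge_lengths(vertices_pred) - edge_lengths(vertices_target)) ** 2)
```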
8. Perceptual Loss¤
Feature-based loss using pre-trained networks:
```python
from workshop.generative_models.core.losses import PerceptualLoss

perceptual = PerceptualLoss(
    content_weight=1.0,
    style_weight=0.01,
)

# Requires feature extraction from images
loss = perceptual(
    pred_images=generated_images,
    target_images=real_images,
    features_pred=extracted_features_pred,
    features_target=extracted_features_target,
)
```
9. Total Variation Loss¤
Smoothness regularization for images:
```python
from workshop.generative_models.core.losses import total_variation_loss

# Encourages spatial smoothness
tv_loss = total_variation_loss(generated_images)

# Often combined with other losses
total_loss = reconstruction_loss + 0.001 * tv_loss
```
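Total variation penalizes differences between neighboring pixels. A common anisotropic form in JAX for NHWC images (Workshop's version may normalize differently):

```python
import jax.numpy as jnp

def total_variation_sketch(images):
    """images: (batch, H, W, C). Mean absolute difference between neighbors."""
    dh = jnp.abs(images[:, 1:, :, :] - images[:, :-1, :, :])  # vertical neighbors
    dw = jnp.abs(images[:, :, 1:, :] - images[:, :, :-1, :])  # horizontal neighbors
    return jnp.mean(dh) + jnp.mean(dw)
```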
Code Structure¤
The example demonstrates seven loss usage patterns:
- Functional Usage - Simple MSE and MAE losses
- Composable Loss - Weighted loss combination
- VAE Training - Reconstruction + KL divergence
- GAN Training - Generator and discriminator losses
- Scheduled Loss - Progressive loss weight ramping
- Geometric Losses - Chamfer distance and mesh losses
- Complete Training - Full training loop with losses
Features Demonstrated¤
- ✅ Functional losses with flexible reduction modes
- ✅ Weighted loss composition with component tracking
- ✅ VAE loss (reconstruction + KL divergence)
- ✅ GAN loss suites (standard, LS-GAN, Wasserstein)
- ✅ Scheduled losses for curriculum learning
- ✅ Geometric losses for 3D data (Chamfer, mesh)
- ✅ Perceptual loss with feature extraction
- ✅ Total variation loss for smoothness
- ✅ Integration with Flax NNX training loops
Experiments to Try¤
- Adjust Loss Weights
```python
# Try different β values for VAE
composite = CompositeLoss([
    WeightedLoss(recon_loss, weight=1.0, name="recon"),
    WeightedLoss(kl_loss, weight=4.0, name="kl"),  # β = 4.0
])
```
- Compare GAN Loss Types
```python
# Standard GAN
gen_loss, disc_loss = create_gan_loss_suite("standard", "standard")

# LS-GAN (often more stable)
gen_loss, disc_loss = create_gan_loss_suite("lsgan", "lsgan")
```
- Custom Schedule Functions
```python
# Exponential warmup
def exp_schedule(step):
    return 1.0 - jnp.exp(-step / 1000.0)

# Cosine annealing (total_steps is the length of your training run)
def cosine_schedule(step):
    return 0.5 * (1 + jnp.cos(jnp.pi * step / total_steps))
```
- Geometric Loss Weights
```python
# Adjust mesh loss components
mesh_loss = MeshLoss(
    vertex_weight=2.0,      # Emphasize position accuracy
    normal_weight=0.5,      # More weight on normals
    edge_weight=0.1,
    laplacian_weight=0.01,
)
```
Next Steps¤
- VAE Examples: Apply losses in VAE training
- GAN Examples: Use GAN losses in training
- Geometric Models: Apply geometric losses
- Framework Features: Understand the composable design
Troubleshooting¤
Shape Mismatch Errors¤
Symptom: ValueError about incompatible shapes
Solution: Ensure predictions and targets have the same shape
print(f"Predictions: {predictions.shape}")
print(f"Targets: {targets.shape}")
# Reshape if needed
predictions = predictions.reshape(targets.shape)
NaN in KL Divergence¤
Symptom: KL loss becomes NaN during VAE training
Cause: Numerical instability in exp(logvar) for large logvar
Solution: Clip logvar values
```python
logvar = jnp.clip(logvar, -10.0, 10.0)
kl_loss = -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar))
```
GAN Loss Not Converging¤
Symptom: Generator or discriminator loss diverges
Solution: Try LS-GAN loss instead of standard GAN
Composite Loss Component Mismatch¤
Symptom: KeyError when accessing loss components
Solution: Set return_components=True in CompositeLoss
```python
composite = CompositeLoss([...], return_components=True)
total, components = composite(pred, target)  # Returns a tuple
```
Additional Resources¤
Documentation¤
- Loss Functions API Reference - Complete loss function documentation
- VAE Loss Theory - Mathematical derivation of VAE loss
- GAN Training Guide - Best practices for GAN losses
Related Examples¤
- Framework Features Demo - Composable loss system
- VAE MNIST Tutorial - VAE loss in practice
- GAN MNIST Tutorial - GAN loss in practice
- Geometric Benchmark - Geometric losses
Papers¤
- VAE: Auto-Encoding Variational Bayes (Kingma & Welling, 2013)
- β-VAE: β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework (Higgins et al., 2017)
- LS-GAN: Least Squares Generative Adversarial Networks (Mao et al., 2017)
- Perceptual Loss: Perceptual Losses for Real-Time Style Transfer and Super-Resolution (Johnson et al., 2016)
- Chamfer Distance: Learning Representations and Generative Models for 3D Point Clouds (Achlioptas et al., 2018)