Loss Functions¤
Workshop provides a comprehensive catalog of loss functions for training generative models. These losses are organized into categories based on their purpose and mathematical foundation.
Overview¤
- **40+ Loss Functions**: Extensive catalog covering all generative model types
- **Composable Framework**: Easily combine multiple losses with weights and scheduling
- **JAX-Optimized**: Fully JIT-compatible and vectorized implementations
- **Numerically Stable**: Careful handling of edge cases and numerical precision
Loss Categories¤
```mermaid
graph TD
    A[Loss Functions] --> B[Reconstruction Losses]
    A --> C[Adversarial Losses]
    A --> D[Divergence Losses]
    A --> E[Perceptual Losses]
    A --> F[Composable Framework]

    B --> B1[MSE]
    B --> B2[MAE]
    B --> B3[Huber]
    B --> B4[Charbonnier]

    C --> C1[Vanilla GAN]
    C --> C2[LSGAN]
    C --> C3[WGAN]
    C --> C4[Hinge]

    D --> D1[KL Divergence]
    D --> D2[JS Divergence]
    D --> D3[Wasserstein]
    D --> D4[MMD]

    E --> E1[Feature Reconstruction]
    E --> E2[Style Loss]
    E --> E3[Contextual Loss]

    F --> F1[CompositeLoss]
    F --> F2[WeightedLoss]
    F --> F3[ScheduledLoss]

    style A fill:#e1f5ff
    style B fill:#b3e5fc
    style C fill:#b3e5fc
    style D fill:#b3e5fc
    style E fill:#b3e5fc
    style F fill:#81d4fa
```
Reconstruction Losses¤
Reconstruction losses compare predictions directly with targets, typically used in autoencoders and regression tasks.
Location: src/workshop/generative_models/core/losses/reconstruction.py
mse_loss (L2 Loss)¤
Mean Squared Error - penalizes squared differences between predictions and targets.
```python
from workshop.generative_models.core.losses.reconstruction import mse_loss
import jax.numpy as jnp

predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.1, 1.9, 3.2])

loss = mse_loss(predictions, targets, reduction="mean")
print(f"MSE Loss: {loss}")  # ~0.02: mean of (0.01, 0.01, 0.04)
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `predictions` | `jax.Array` | Required | Model predictions |
| `targets` | `jax.Array` | Required | Ground truth values |
| `reduction` | `str` | `"mean"` | Reduction: `"none"`, `"mean"`, `"sum"` |
| `weights` | `jax.Array \| None` | `None` | Optional per-element weights |
| `axis` | `int \| tuple \| None` | `None` | Axis for reduction |
Use Cases:
- VAE reconstruction loss
- Image regression
- General reconstruction tasks
mae_loss (L1 Loss)¤
Mean Absolute Error - more robust to outliers than MSE.
```python
from workshop.generative_models.core.losses.reconstruction import mae_loss

loss = mae_loss(predictions, targets, reduction="mean")
print(f"MAE Loss: {loss}")  # ~0.133: mean of (0.1, 0.1, 0.2)
```
Use Cases:
- Robust regression
- Outlier-heavy datasets
- Image reconstruction with noise
huber_loss¤
Smooth combination of L1 and L2 losses - quadratic for small errors, linear for large errors.
```python
from workshop.generative_models.core.losses.reconstruction import huber_loss

# Delta controls the transition point
loss = huber_loss(predictions, targets, delta=1.0, reduction="mean")
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `delta` | `float` | `1.0` | Threshold between quadratic and linear regions |
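Formula: the standard Huber penalty is ½(x − y)² for |x − y| ≤ δ and δ(|x − y| − ½δ) otherwise, so the quadratic and linear branches meet smoothly at |x − y| = δ.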
Use Cases:
- Robust regression with outliers
- Reinforcement learning (value functions)
- Object detection
charbonnier_loss¤
Differentiable variant of L1 loss with smoother gradients.
```python
from workshop.generative_models.core.losses.reconstruction import charbonnier_loss

loss = charbonnier_loss(
    predictions,
    targets,
    epsilon=1e-3,  # Smoothing constant
    alpha=1.0,     # Exponent
    reduction="mean",
)
```
Use Cases:
- Optical flow estimation
- Image super-resolution
- Smooth optimization landscapes
psnr_loss¤
Peak Signal-to-Noise Ratio expressed as a loss (negative PSNR).
```python
from workshop.generative_models.core.losses.reconstruction import psnr_loss

# For normalized images (0-1)
loss = psnr_loss(pred_images, target_images, max_value=1.0)

# For images in range [0, 255]
loss = psnr_loss(pred_images, target_images, max_value=255.0)
```
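Formula: PSNR = 10 · log₁₀(max_value² / MSE), and `psnr_loss` returns −PSNR, so minimizing the loss maximizes fidelity.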
Use Cases:
- Image quality assessment
- Super-resolution evaluation
- Compression evaluation
Adversarial Losses¤
Adversarial losses for training GANs and adversarial networks.
Location: src/workshop/generative_models/core/losses/adversarial.py
Vanilla GAN Losses¤
Original GAN formulation with binary cross-entropy.
```python
from workshop.generative_models.core.losses.adversarial import (
    vanilla_generator_loss,
    vanilla_discriminator_loss,
)

# Generator loss: -log(D(G(z)))
g_loss = vanilla_generator_loss(fake_scores)

# Discriminator loss: -log(D(x)) - log(1 - D(G(z)))
d_loss = vanilla_discriminator_loss(real_scores, fake_scores)
```
**Pros:** Simple, well-studied.

**Cons:** Vanishing gradients, mode collapse.
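If the discriminator outputs raw logits, the generator term −log(D(G(z))) is best evaluated through the identity −log σ(s) = softplus(−s), which avoids overflow for extreme scores. A minimal sketch of that stable form (illustrative only, not necessarily how `vanilla_generator_loss` is implemented internally):

```python
import jax
import jax.numpy as jnp

def stable_vanilla_generator_loss(fake_logits):
    """-log(sigmoid(s)) computed via the numerically stable identity softplus(-s)."""
    return jnp.mean(jax.nn.softplus(-fake_logits))
```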
LSGAN Losses¤
Least Squares GAN - uses mean squared error instead of cross-entropy.
```python
from workshop.generative_models.core.losses.adversarial import (
    least_squares_generator_loss,
    least_squares_discriminator_loss,
)

# Generator: minimize (D(G(z)) - 1)^2
g_loss = least_squares_generator_loss(
    fake_scores,
    target_real=1.0,
)

# Discriminator: minimize (D(x) - 1)^2 + (D(G(z)) - 0)^2
d_loss = least_squares_discriminator_loss(
    real_scores,
    fake_scores,
    target_real=1.0,
    target_fake=0.0,
)
```
**Pros:** More stable training, better gradients.

**Cons:** May require more careful tuning.
Wasserstein GAN Losses¤
Wasserstein distance-based losses with better convergence properties.
```python
from workshop.generative_models.core.losses.adversarial import (
    wasserstein_generator_loss,
    wasserstein_discriminator_loss,
)

# Generator: minimize -D(G(z))
g_loss = wasserstein_generator_loss(fake_scores)

# Critic: minimize D(G(z)) - D(x)
d_loss = wasserstein_discriminator_loss(real_scores, fake_scores)

# Note: Requires gradient penalty or weight clipping
```
**Pros:** Stable training, meaningful loss curves.

**Cons:** Requires a gradient penalty or weight clipping (a sketch follows below).
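The gradient penalty itself is not one of the loss functions above. A minimal WGAN-GP-style sketch, assuming a `critic` callable that maps a batch to per-example scores (the name and signature are illustrative, not part of the Workshop API):

```python
import jax
import jax.numpy as jnp

def gradient_penalty(critic, key, real, fake, weight=10.0):
    """WGAN-GP: push the critic's input-gradient norm toward 1 on interpolates."""
    # One interpolation coefficient per example, broadcast over feature dims
    eps = jax.random.uniform(key, (real.shape[0],) + (1,) * (real.ndim - 1))
    interp = eps * real + (1.0 - eps) * fake
    # Per-example gradient of the scalar critic score w.r.t. the input
    grad_fn = jax.grad(lambda x: jnp.squeeze(critic(x[None])))
    grads = jax.vmap(grad_fn)(interp)
    norms = jnp.sqrt(jnp.sum(grads**2, axis=tuple(range(1, grads.ndim))) + 1e-12)
    return weight * jnp.mean((norms - 1.0) ** 2)
```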
Hinge Losses¤
Hinge loss formulation used in spectral normalization GANs.
```python
from workshop.generative_models.core.losses.adversarial import (
    hinge_generator_loss,
    hinge_discriminator_loss,
)

# Generator: -D(G(z))
g_loss = hinge_generator_loss(fake_scores)

# Discriminator: max(0, 1 - D(x)) + max(0, 1 + D(G(z)))
d_loss = hinge_discriminator_loss(real_scores, fake_scores)
```
**Pros:** Stable, works well with spectral normalization.

**Cons:** May need careful architecture design.
Divergence Losses¤
Statistical divergence measures between probability distributions.
Location: src/workshop/generative_models/core/losses/divergence.py
kl_divergence¤
Kullback-Leibler divergence - measures information loss.
```python
from workshop.generative_models.core.losses.divergence import kl_divergence
import distrax

# With distribution objects
p = distrax.Normal(loc=0.0, scale=1.0)
q = distrax.Normal(loc=0.5, scale=1.5)
kl = kl_divergence(p, q)

# With probability arrays
p_probs = jnp.array([0.2, 0.5, 0.3])
q_probs = jnp.array([0.3, 0.4, 0.3])
kl = kl_divergence(p_probs, q_probs, reduction="sum")
```
Formula: KL(P||Q) = Σ P(x) log(P(x) / Q(x))
Use Cases:
- VAE latent regularization
- Distribution matching
- Information theory applications
js_divergence¤
Jensen-Shannon divergence - symmetric variant of KL divergence.
```python
from workshop.generative_models.core.losses.divergence import js_divergence

p = jnp.array([0.2, 0.5, 0.3])
q = jnp.array([0.1, 0.7, 0.2])
js = js_divergence(p, q)
```
Formula: JS(P||Q) = 0.5 (KL(P||M) + KL(Q||M)) where M = 0.5 (P + Q)
Properties:
- Symmetric: JS(P||Q) = JS(Q||P)
- Bounded: 0 ≤ JS ≤ log(2)
- Its square root is a metric (the JS divergence itself does not satisfy the triangle inequality)
wasserstein_distance¤
Earth Mover's Distance - optimal transport-based metric.
```python
from workshop.generative_models.core.losses.divergence import wasserstein_distance

# 1D samples
p_samples = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
q_samples = jnp.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])

# W1 distance
w1 = wasserstein_distance(p_samples, q_samples, p=1, axis=1)

# W2 distance
w2 = wasserstein_distance(p_samples, q_samples, p=2, axis=1)
```
Use Cases:
- GAN training (WGAN)
- Distribution comparison
- Robust statistics
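For equal-sized 1D sample sets with uniform weights, the p-Wasserstein distance reduces to comparing sorted samples, which is a handy mental model (a sketch of that special case, not necessarily how `wasserstein_distance` is implemented):

```python
import jax.numpy as jnp

def wasserstein_1d(p_samples, q_samples, p=1):
    """W_p between two equal-sized 1D empirical distributions."""
    # Optimal transport in 1D matches the i-th smallest point to the i-th smallest point
    diffs = jnp.abs(jnp.sort(p_samples) - jnp.sort(q_samples))
    return jnp.mean(diffs**p) ** (1.0 / p)
```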
maximum_mean_discrepancy¤
Kernel-based distribution distance measure.
```python
from workshop.generative_models.core.losses.divergence import maximum_mean_discrepancy
import jax

# Generate samples with independent keys (reusing one key would
# produce identical samples and a trivially zero MMD)
key_p, key_q = jax.random.split(jax.random.key(0))
pred_samples = jax.random.normal(key_p, (2, 100, 5))
target_samples = jax.random.normal(key_q, (2, 100, 5))

# Compute MMD with RBF kernel
mmd = maximum_mean_discrepancy(
    pred_samples,
    target_samples,
    kernel_type="rbf",
    kernel_bandwidth=1.0,
)
```
Kernel Types:
"rbf": Radial Basis Function (Gaussian)"linear": Linear kernel"polynomial": Polynomial kernel
Use Cases:
- Two-sample testing
- Domain adaptation
- Generative model evaluation
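Under an RBF kernel, MMD² is the average within-set similarity minus twice the average cross-set similarity. A self-contained sketch of the biased estimator for sample sets of shape (n, d) and (m, d), independent of the library's implementation:

```python
import jax.numpy as jnp

def rbf_mmd2(x, y, bandwidth=1.0):
    """Biased MMD^2 estimate between samples x (n, d) and y (m, d)."""
    def kernel(a, b):
        # Pairwise squared distances: ||a_i||^2 + ||b_j||^2 - 2 a_i . b_j
        sq = jnp.sum(a**2, axis=1)[:, None] + jnp.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
        return jnp.exp(-sq / (2.0 * bandwidth**2))
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()
```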
energy_distance¤
Metric between probability distributions based on Euclidean distance.
```python
from workshop.generative_models.core.losses.divergence import energy_distance

energy = energy_distance(
    pred_samples,
    target_samples,
    beta=1.0,  # Power parameter (0 < beta <= 2)
)
```
Properties:
- Metric (satisfies triangle inequality)
- Generalizes Euclidean distance
- Computationally efficient
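Formula (β = 1): D²(P, Q) = 2 · E‖X − Y‖ − E‖X − X′‖ − E‖Y − Y′‖, where X, X′ ~ P and Y, Y′ ~ Q are independent; general β raises each norm to the power β.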
Perceptual Losses¤
Feature-based losses using pre-trained networks.
Location: src/workshop/generative_models/core/losses/perceptual.py
feature_reconstruction_loss¤
Compares features extracted from intermediate layers.
```python
from workshop.generative_models.core.losses.perceptual import feature_reconstruction_loss

# Dictionary of features (e.g., from VGG)
features_real = {
    "conv1": jnp.ones((2, 64, 64, 64)),
    "conv2": jnp.ones((2, 32, 32, 128)),
    "conv3": jnp.ones((2, 16, 16, 256)),
}
features_fake = {
    "conv1": jnp.zeros((2, 64, 64, 64)),
    "conv2": jnp.zeros((2, 32, 32, 128)),
    "conv3": jnp.zeros((2, 16, 16, 256)),
}

# Weighted feature loss
loss = feature_reconstruction_loss(
    features_real,
    features_fake,
    weights={"conv1": 1.0, "conv2": 0.5, "conv3": 0.25},
)
```
Use Cases:
- Image-to-image translation
- Style transfer (content loss)
- Super-resolution
style_loss¤
Gram matrix-based style matching.
```python
from workshop.generative_models.core.losses.perceptual import style_loss

# Captures texture/style information
loss = style_loss(
    features_real,
    features_fake,
    weights={"conv1": 1.0, "conv2": 1.0, "conv3": 1.0},
)
```
How it works:
- Computes Gram matrices of features
- Measures distance between Gram matrices
- Captures correlations between feature channels
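Concretely, the Gram matrix per layer is the channel-by-channel correlation of features averaged over spatial positions. A sketch assuming NHWC feature maps, as in the examples above:

```python
import jax.numpy as jnp

def gram_matrix(features):
    """Channel-correlation (Gram) matrix for NHWC feature maps."""
    b, h, w, c = features.shape
    flat = features.reshape(b, h * w, c)
    # (batch, channels, channels), normalized by the number of spatial positions
    return jnp.einsum("bnc,bnd->bcd", flat, flat) / (h * w)
```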
Use Cases:
- Style transfer
- Texture synthesis
- Artistic image generation
contextual_loss¤
Measures distributional similarity between feature sets and is robust to spatial misalignment.
```python
from workshop.generative_models.core.losses.perceptual import contextual_loss

# Single-layer features
feat_real = jnp.ones((2, 32, 32, 128))
feat_fake = jnp.zeros((2, 32, 32, 128))

loss = contextual_loss(
    feat_real,
    feat_fake,
    band_width=0.1,
    max_samples=512,  # Memory-efficient
)
```
Use Cases:
- Non-aligned image matching
- Texture transfer
- Semantic image editing
PerceptualLoss Module¤
Composable NNX module combining multiple perceptual losses.
```python
from workshop.generative_models.core.losses.perceptual import PerceptualLoss
from flax import nnx

# Create perceptual loss module (vgg_model: a pre-trained feature extractor)
perceptual = PerceptualLoss(
    feature_extractor=vgg_model,
    layer_weights={
        "conv1_2": 1.0,
        "conv2_2": 1.0,
        "conv3_3": 1.0,
        "conv4_3": 1.0,
    },
    content_weight=1.0,
    style_weight=10.0,
    contextual_weight=0.1,
)

# Compute combined loss
loss = perceptual(pred_images, target_images)
```
Composable Loss Framework¤
NNX-based framework for combining multiple losses.
Location: src/workshop/generative_models/core/losses/composable.py
WeightedLoss¤
Apply a weight to any loss function.
```python
from workshop.generative_models.core.losses.composable import WeightedLoss
from workshop.generative_models.core.losses.reconstruction import mse_loss

# Wrap loss with weight
weighted_mse = WeightedLoss(
    loss_fn=mse_loss,
    weight=0.5,
    name="weighted_mse",
)

loss = weighted_mse(predictions, targets)  # 0.5 * MSE
```
CompositeLoss¤
Combine multiple losses into a single module.
```python
from workshop.generative_models.core.losses.composable import CompositeLoss, WeightedLoss
from workshop.generative_models.core.losses.reconstruction import mse_loss, mae_loss

# Create weighted losses
loss1 = WeightedLoss(mse_loss, weight=1.0, name="mse")
loss2 = WeightedLoss(mae_loss, weight=0.5, name="mae")

# Combine losses
composite = CompositeLoss(
    losses=[loss1, loss2],
    return_components=True,
)

# Returns total loss and individual components
total_loss, loss_dict = composite(predictions, targets)
print(f"Total: {total_loss}")
print(f"Components: {loss_dict}")  # {"mse": ..., "mae": ...}
```
ScheduledLoss¤
Dynamic loss weight scheduling for curriculum learning.
```python
from workshop.generative_models.core.losses.composable import ScheduledLoss

def warmup_schedule(step):
    """Linear warmup over 1000 steps."""
    return min(1.0, step / 1000.0)

scheduled_loss = ScheduledLoss(
    loss_fn=mse_loss,
    schedule_fn=warmup_schedule,
    name="scheduled_mse",
)

# Weight increases over time
loss_step_0 = scheduled_loss(pred, target, step=0)        # weight=0.0
loss_step_500 = scheduled_loss(pred, target, step=500)    # weight=0.5
loss_step_1000 = scheduled_loss(pred, target, step=1000)  # weight=1.0
```
Base Utilities¤
Helper classes and functions for loss management.
Location: src/workshop/generative_models/core/losses/base.py
LossCollection¤
Functional collection of multiple losses (non-NNX).
```python
from workshop.generative_models.core.losses.base import LossCollection
from workshop.generative_models.core.losses.reconstruction import mse_loss, mae_loss

# Create collection
collection = LossCollection()
collection.add(mse_loss, weight=1.0, name="mse")
collection.add(mae_loss, weight=0.5, name="mae")

# Compute total loss
total_loss, loss_dict = collection(predictions, targets)
```
LossScheduler¤
Dynamically adjust loss weights during training.
```python
from workshop.generative_models.core.losses.base import LossScheduler

# Initialize with initial weights
scheduler = LossScheduler({
    "reconstruction": 1.0,
    "kl_divergence": 0.1,
    "perceptual": 0.5,
})

# Add linear warmup for KL divergence
scheduler.add_schedule(
    "kl_divergence",
    LossScheduler.linear_warmup(warmup_steps=1000, max_weight=1.0),
)

# Add cosine annealing for perceptual loss
scheduler.add_schedule(
    "perceptual",
    LossScheduler.cosine_annealing(period=5000, min_weight=0.1, max_weight=1.0),
)

# Update weights at each step
weights = scheduler.update(step=500)
print(weights)  # Updated weights based on schedules
```
LossMetrics¤
Track loss values during training with exponential moving average.
```python
from workshop.generative_models.core.losses.base import LossMetrics

# Create metrics tracker
metrics = LossMetrics(momentum=0.99)

# Update with loss values
for step in range(100):
    loss_dict = {
        "total": jnp.array(1.5),
        "reconstruction": jnp.array(1.0),
        "kl": jnp.array(0.5),
    }
    metrics.update(loss_dict)

# Get smoothed metrics
current_metrics = metrics.get_metrics()
print(current_metrics)  # Smoothed loss values
```
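An exponential moving average with momentum m conventionally updates each tracked value as v̂ ← m · v̂ + (1 − m) · v, so `momentum=0.99` smooths over roughly the last 1 / (1 − m) = 100 steps.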
Common Patterns¤
Pattern 1: VAE Loss¤
```python
from workshop.generative_models.core.losses.reconstruction import mse_loss
from workshop.generative_models.core.losses.divergence import kl_divergence
import distrax

def vae_loss(x, reconstruction, mean, logvar):
    """Complete VAE loss."""
    # Reconstruction loss
    recon_loss = mse_loss(reconstruction, x, reduction="mean")

    # KL divergence between the approximate posterior and a standard normal prior
    posterior = distrax.MultivariateNormalDiag(mean, jnp.exp(0.5 * logvar))
    prior = distrax.MultivariateNormalDiag(
        jnp.zeros_like(mean),
        jnp.ones_like(logvar),
    )
    kl_loss = kl_divergence(posterior, prior, reduction="mean")

    # Total loss
    total_loss = recon_loss + 0.5 * kl_loss
    return {
        "loss": total_loss,
        "reconstruction_loss": recon_loss,
        "kl_loss": kl_loss,
    }
```
Pattern 2: GAN with Multiple Losses¤
```python
from workshop.generative_models.core.losses.adversarial import (
    hinge_generator_loss,
    hinge_discriminator_loss,
)
from workshop.generative_models.core.losses.perceptual import PerceptualLoss

# Perceptual loss for generator (vgg: a pre-trained feature extractor)
perceptual_loss = PerceptualLoss(
    feature_extractor=vgg,
    content_weight=1.0,
    style_weight=10.0,
)

def generator_loss(fake_images, real_images, fake_scores):
    """Complete generator loss."""
    # Adversarial loss
    adv_loss = hinge_generator_loss(fake_scores)

    # Perceptual loss
    perc_loss = perceptual_loss(fake_images, real_images)

    # Total loss
    total = adv_loss + 0.1 * perc_loss
    return {
        "loss": total,
        "adversarial": adv_loss,
        "perceptual": perc_loss,
    }
```
Pattern 3: Composable Multi-Loss System¤
```python
from workshop.generative_models.core.losses.composable import (
    CompositeLoss,
    WeightedLoss,
    ScheduledLoss,
)

# Create individual weighted losses (perceptual_fn / kl_divergence_fn are user-supplied)
recon_loss = WeightedLoss(mse_loss, weight=1.0, name="reconstruction")
perc_loss = WeightedLoss(perceptual_fn, weight=0.1, name="perceptual")

# Scheduled KL loss with warmup
kl_loss = ScheduledLoss(
    loss_fn=kl_divergence_fn,
    schedule_fn=lambda step: min(1.0, step / 5000),
    name="kl",
)

# Combine all losses
composite = CompositeLoss(
    losses=[recon_loss, perc_loss, kl_loss],
    return_components=True,
)

# Use in training: (predictions, targets) ordering, as with the individual losses
total_loss, components = composite(outputs, batch, step=current_step)
```
Best Practices¤
DO¤
- ✅ Use `reduction="mean"` for stable gradients
- ✅ Scale losses to similar magnitudes when combining
- ✅ Use perceptual losses for visual quality
- ✅ Monitor individual loss components during training
- ✅ Use numerically stable variants (safe_log, safe_divide)
- ✅ Normalize inputs when using distance-based losses
- ✅ Use gradient clipping with adversarial losses
DON'T¤
- ❌ Mix different reduction types without careful consideration
- ❌ Use very different loss magnitudes without weighting
- ❌ Forget to normalize features for perceptual losses
- ❌ Use high-dimensional MMD without subsampling
- ❌ Apply losses on unnormalized data
- ❌ Ignore numerical stability (use eps parameters)
Performance Tips¤
Memory-Efficient Perceptual Loss¤
```python
# Use lower resolution or fewer samples
perceptual = PerceptualLoss(
    max_contextual_samples=256,  # Reduce memory usage
    ...
)
```
Vectorized Loss Computation¤
```python
# Use JAX's vmap for batched per-example loss computation
import jax
from flax import nnx

@nnx.jit
def batch_loss(predictions, targets):
    # Map the elementwise loss over the leading batch axis
    return jax.vmap(mse_loss)(predictions, targets)
```
Gradient Accumulation for Large Losses¤
```python
# For very large composite losses
@nnx.jit
@nnx.grad
def loss_with_accumulation(params, batch):
    # Compute losses in chunks
    ...
```
Troubleshooting¤
Issue: "NaN in loss computation"¤
Solutions:
- Use epsilon parameters for numerical stability
- Check input normalization
- Use `safe_log` and `safe_divide` utilities
- Clip gradients
```python
from workshop.generative_models.core.losses.base import safe_log, safe_divide

# Instead of jnp.log(x)
log_x = safe_log(x, eps=1e-8)

# Instead of x / y
ratio = safe_divide(x, y, eps=1e-8)
```
Issue: "Loss values differ greatly in magnitude"¤
Solution: Scale losses to similar ranges:
```python
# Bad: losses differ by orders of magnitude
total = recon_loss + kl_loss  # e.g., 100.0 + 0.01

# Good: scale to similar magnitudes
total = recon_loss + 10.0 * kl_loss  # e.g., 100.0 + 0.1
```
Issue: "Perceptual loss uses too much memory"¤
Solution: Reduce the sample count and feature resolution, e.g. by lowering `max_contextual_samples` as shown in the memory-efficient snippet under Performance Tips above.
Next Steps¤
- **Base Classes**: Learn about model base classes
- **Device Management**: Optimize loss computation on GPU/TPU
- **Training**: Use losses in training loops
References¤
- Source Code: `src/workshop/generative_models/core/losses/`
- Tests: `tests/workshop/generative_models/core/losses/`
- Research Papers:
    - GAN Paper
    - LSGAN Paper
    - WGAN Paper
    - Perceptual Losses Paper