GAN User Guide¤
This guide provides practical instructions for training and using Generative Adversarial Networks (GANs) in Workshop. It covers the supported GAN variants, training strategies, common issues, and best practices.
Quick Start¤
Here's a minimal example to get you started:
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.gan import GAN, Generator, Discriminator
# Initialize RNG
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
# Create simple config
class GANConfig:
    latent_dim = 100
    loss_type = "vanilla"  # or "wasserstein", "least_squares", "hinge"

    # Generator config
    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 1, 28, 28)  # MNIST shape
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    # Discriminator config
    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

config = GANConfig()
# Create GAN
gan = GAN(config, rngs=rngs)
# Generate samples
samples = gan.generate(n_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}") # (16, 1, 28, 28)
Creating GAN Components¤
Basic Generator¤
The generator transforms random noise into data samples:
from workshop.generative_models.models.gan import Generator
# Create generator
generator = Generator(
hidden_dims=[128, 256, 512], # Hidden layer sizes
output_shape=(1, 3, 32, 32), # Output: 3 channels, 32x32 images
latent_dim=100, # Latent space dimension
activation="relu", # Activation function
batch_norm=True, # Use batch normalization
dropout_rate=0.0, # Dropout rate (usually 0 for generator)
rngs=rngs,
)
# Generate samples from random noise
batch_size, latent_dim = 32, 100
z = jax.random.normal(rngs.sample(), (batch_size, latent_dim))
fake_samples = generator(z, training=True)
Key Parameters:
- hidden_dims: List of hidden layer dimensions (progressively increases capacity)
- output_shape: Target data shape (batch, channels, height, width)
- latent_dim: Size of input latent vector (typically 64-512)
- batch_norm: Stabilizes training (recommended for generator)
- activation: "relu" for generator, "leaky_relu" for discriminator
Basic Discriminator¤
The discriminator classifies samples as real or fake:
from workshop.generative_models.models.gan import Discriminator
# Create discriminator
discriminator = Discriminator(
hidden_dims=[512, 256, 128], # Hidden layer sizes (often mirrors generator)
activation="leaky_relu", # LeakyReLU prevents dying neurons
leaky_relu_slope=0.2, # Negative slope for LeakyReLU
batch_norm=False, # Usually False for discriminator
dropout_rate=0.3, # Dropout to prevent overfitting
use_spectral_norm=False, # Spectral normalization for stability
rngs=rngs,
)
# Classify samples
real_data = jnp.ones((batch_size, 3, 32, 32))
fake_data = generator(z, training=True)
real_scores = discriminator(real_data, training=True) # Should be close to 1
fake_scores = discriminator(fake_data, training=True) # Should be close to 0
Key Parameters:
- hidden_dims: Usually mirrors generator in reverse
- activation: "leaky_relu" is standard (slope 0.2)
- batch_norm: Usually False (can cause training issues)
- dropout_rate: 0.3-0.5 helps prevent overfitting
- use_spectral_norm: Improves training stability
GAN Variants¤
1. Vanilla GAN¤
The original GAN formulation with binary cross-entropy loss:
from workshop.generative_models.models.gan import GAN
class VanillaGANConfig:
    latent_dim = 100
    loss_type = "vanilla"

    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 1, 28, 28)
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

config = VanillaGANConfig()
gan = GAN(config, rngs=rngs)
# Training step
def train_step(gan, batch, rngs):
    # Compute loss dictionary from the model's built-in loss function
    losses = gan.loss_fn(batch, None, rngs=rngs)
    return losses["loss"], losses
loss, metrics = train_step(gan, batch_data, rngs)
When to use:
- Learning GANs for the first time
- Simple datasets (MNIST, simple shapes)
- Proof-of-concept experiments
Pros: Simple, well-understood
Cons: Training instability, mode collapse
2. Deep Convolutional GAN (DCGAN)¤
Uses convolutional architecture for images:
from workshop.generative_models.models.gan import (
DCGAN,
DCGANGenerator,
DCGANDiscriminator,
)
# Create DCGAN components directly
generator = DCGANGenerator(
output_shape=(3, 64, 64), # 3 channels, 64x64 output
latent_dim=100,
hidden_dims=(256, 128, 64, 32), # Progressive channel reduction
activation=jax.nn.relu,
batch_norm=True,
dropout_rate=0.0,
rngs=rngs,
)
discriminator = DCGANDiscriminator(
input_shape=(3, 64, 64),
hidden_dims=(32, 64, 128, 256), # Progressive channel increase
activation=jax.nn.leaky_relu,
leaky_relu_slope=0.2,
batch_norm=False, # DCGAN: no batch norm in discriminator
dropout_rate=0.3,
use_spectral_norm=True, # Recommended for stability
rngs=rngs,
)
# Or use the full DCGAN model
from workshop.generative_models.core.configuration.gan import DCGANConfiguration
dcgan_config = DCGANConfiguration(
image_size=64,
channels=3,
latent_dim=100,
gen_hidden_dims=(256, 128, 64, 32),
disc_hidden_dims=(32, 64, 128, 256),
loss_type="vanilla",
generator_lr=0.0002,
discriminator_lr=0.0002,
beta1=0.5,
beta2=0.999,
)
dcgan = DCGAN(dcgan_config, rngs=rngs)
# Generate high-quality images
samples = dcgan.generate(n_samples=64, rngs=rngs)
DCGAN Architecture Guidelines:
- Replace pooling with strided convolutions
- Use batch normalization (except discriminator input and generator output)
- Remove fully connected layers (except for latent projection)
- Use ReLU in generator, LeakyReLU in discriminator
- Use Tanh activation in generator output
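As a concrete illustration of these guidelines, here is a minimal sketch of one generator upsampling block in Flax NNX. The UpBlock class and its channel counts are illustrative only (not part of the Workshop API), and it uses Flax's default NHWC layout rather than the NCHW shapes quoted elsewhere in this guide:

import jax
import jax.numpy as jnp
from flax import nnx

class UpBlock(nnx.Module):
    """One DCGAN-style generator block: stride-2 transposed conv -> batch norm -> ReLU."""

    def __init__(self, in_features, out_features, *, rngs: nnx.Rngs):
        self.conv = nnx.ConvTranspose(
            in_features, out_features,
            kernel_size=(4, 4), strides=(2, 2), padding="SAME",
            rngs=rngs,
        )
        self.bn = nnx.BatchNorm(out_features, rngs=rngs)

    def __call__(self, x, training=True):
        x = self.conv(x)  # stride-2 transposed conv doubles the spatial resolution
        x = self.bn(x, use_running_average=not training)
        return jax.nn.relu(x)

# Example: an 8x8x256 feature map becomes 16x16x128
block = UpBlock(256, 128, rngs=nnx.Rngs(0))
y = block(jnp.zeros((1, 8, 8, 256)))  # -> (1, 16, 16, 128)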
When to use:
- Image generation tasks
- 64×64 to 128×128 resolution
- More stable training than vanilla GAN
Pros: More stable, better image quality
Cons: Still can suffer from mode collapse
3. Wasserstein GAN (WGAN)¤
Uses Wasserstein distance for more stable training:
from workshop.generative_models.models.gan import (
WGAN,
WGANGenerator,
WGANDiscriminator,
compute_gradient_penalty,
)
# Create WGAN model
from workshop.generative_models.core.configuration import ModelConfiguration
wgan_config = ModelConfiguration(
input_dim=100, # Latent dimension
output_dim=(3, 64, 64), # Output image shape
hidden_dims=None, # Will use defaults
metadata={
"gan_params": {
"gen_hidden_dims": (1024, 512, 256),
"disc_hidden_dims": (256, 512, 1024),
"gradient_penalty_weight": 10.0, # Lambda for gradient penalty
"critic_iterations": 5, # Update critic 5x per generator
}
}
)
wgan = WGAN(wgan_config, rngs=rngs)
# Training loop for WGAN
def train_wgan_step(wgan, real_samples, rngs, n_critic=5):
    """Train WGAN with proper critic/generator balance."""
    # Train critic n_critic times
    for _ in range(n_critic):
        # Sample latent vectors
        z = jax.random.normal(rngs.sample(), (real_samples.shape[0], wgan.latent_dim))

        # Generate fake samples
        fake_samples = wgan.generator(z, training=True)

        # Compute discriminator loss with gradient penalty
        disc_loss = wgan.discriminator_loss(real_samples, fake_samples, rngs)
        # Update discriminator here (in practice, use nnx.Optimizer)

    # Train generator once
    z = jax.random.normal(rngs.sample(), (real_samples.shape[0], wgan.latent_dim))
    fake_samples = wgan.generator(z, training=True)
    gen_loss = wgan.generator_loss(fake_samples)
    # Update generator here

    return {"disc_loss": disc_loss, "gen_loss": gen_loss}
Key Differences from Vanilla GAN:
- Critic instead of discriminator (no sigmoid at output)
- Wasserstein distance instead of JS divergence
- Gradient penalty enforces Lipschitz constraint
- Multiple critic updates per generator update (5:1 ratio)
- Instance normalization instead of batch norm in critic
When to use:
- Need stable training
- Want meaningful loss metric
- High-resolution images
- Research experiments
Pros: Very stable, meaningful loss, better mode coverage
Cons: Slower training, more complex
4. Least Squares GAN (LSGAN)¤
Uses least squares loss for smoother gradients:
from workshop.generative_models.models.gan import LSGAN, LSGANGenerator, LSGANDiscriminator
# Create LSGAN (similar interface to base GAN)
class LSGANConfig:
    latent_dim = 100
    loss_type = "least_squares"  # Key difference

    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 3, 32, 32)
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

lsgan_config = LSGANConfig()
lsgan = GAN(lsgan_config, rngs=rngs)  # Can use base GAN with loss_type
# Or use dedicated LSGAN classes
generator = LSGANGenerator(
output_shape=(3, 64, 64),
latent_dim=100,
rngs=rngs,
)
discriminator = LSGANDiscriminator(
input_shape=(3, 64, 64),
rngs=rngs,
)
# Training is similar to vanilla GAN
losses = lsgan.loss_fn(batch, None, rngs=rngs)
Key Difference:
Loss function uses squared error instead of log loss:
- Generator: Minimize \((D(G(z)) - 1)^2\)
- Discriminator: Minimize \((D(x) - 1)^2 + D(G(z))^2\)
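In code, these objectives reduce to mean squared error against targets of 1 (real) and 0 (fake). Below is a minimal sketch of the two losses, written independently of the Workshop implementation (the function names are illustrative, and the 1/2 factor is a common convention):

import jax.numpy as jnp

def lsgan_discriminator_loss(real_scores, fake_scores):
    # Push scores on real data toward 1 and scores on generated data toward 0
    return 0.5 * (jnp.mean((real_scores - 1.0) ** 2) + jnp.mean(fake_scores**2))

def lsgan_generator_loss(fake_scores):
    # Push discriminator scores on generated data toward 1
    return 0.5 * jnp.mean((fake_scores - 1.0) ** 2)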
When to use:
- Want smoother gradients than vanilla GAN
- Need more stable training than vanilla
- Image generation with less training instability
Pros: More stable than vanilla, penalizes far-from-boundary samples
Cons: Can still suffer from mode collapse
5. Conditional GAN (cGAN)¤
Conditions generation on labels or other information:
from workshop.generative_models.models.gan import (
ConditionalGAN,
ConditionalGenerator,
ConditionalDiscriminator,
)
# Create conditional generator
cond_generator = ConditionalGenerator(
output_shape=(1, 28, 28),
latent_dim=100,
num_classes=10, # MNIST has 10 classes
hidden_dims=[256, 512],
embedding_dim=50, # Class embedding size
rngs=rngs,
)
# Create conditional discriminator
cond_discriminator = ConditionalDiscriminator(
input_shape=(1, 28, 28),
num_classes=10,
hidden_dims=[512, 256],
embedding_dim=50,
rngs=rngs,
)
# Generate conditioned on class label
labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # One of each digit
z = jax.random.normal(rngs.sample(), (10, 100))
# Generate specific digits
samples = cond_generator(z, labels, training=False)
# Discriminate with labels
real_data = load_mnist_batch()
real_labels = jnp.array([...]) # True labels
real_scores = cond_discriminator(real_data, real_labels, training=True)
fake_scores = cond_discriminator(samples, labels, training=True)
Key Features:
- Controlled generation: Specify what to generate
- Class conditioning: Generate specific categories
- Embedding layer: Maps labels to high-dimensional space
- Concatenation: Combines embeddings with features
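Under the hood, conditioning usually means looking up a class embedding and concatenating it with the latent vector (or with flattened features in the discriminator). A rough sketch of the generator-side step; condition_latent and the embedding table are illustrative, not Workshop APIs:

import jax
import jax.numpy as jnp

def condition_latent(z, labels, embedding_table):
    """Concatenate a class embedding onto each latent vector.

    z:               (batch, latent_dim) latent noise
    labels:          (batch,) integer class ids
    embedding_table: (num_classes, embedding_dim) learned embedding matrix
    """
    class_embeddings = embedding_table[labels]  # (batch, embedding_dim)
    return jnp.concatenate([z, class_embeddings], axis=-1)

# Example matching the MNIST setup above: 100-dim latent + 50-dim embedding -> 150-dim input
key = jax.random.PRNGKey(0)
z_cond = condition_latent(
    jax.random.normal(key, (10, 100)),
    jnp.arange(10),
    jax.random.normal(key, (10, 50)),
)  # shape (10, 150)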
When to use:
- Need to control generation (class, attributes)
- Have labeled data
- Want to generate specific categories
- Image-to-image translation with labels
Pros: Controlled generation, useful for labeled datasets
Cons: Requires labels, more complex
6. CycleGAN¤
Unpaired image-to-image translation:
from workshop.generative_models.models.gan import (
CycleGAN,
CycleGANGenerator,
CycleGANDiscriminator,
)
# Create CycleGAN for domain transfer (e.g., horse ↔ zebra)
cyclegan = CycleGAN(
input_shape_x=(3, 256, 256), # Domain X (horses)
input_shape_y=(3, 256, 256), # Domain Y (zebras)
gen_hidden_dims=[64, 128, 256],
disc_hidden_dims=[64, 128, 256],
cycle_weight=10.0, # Cycle consistency weight
identity_weight=0.5, # Identity loss weight
rngs=rngs,
)
# Training step
def train_cyclegan_step(cyclegan, batch_x, batch_y, rngs):
    """Train CycleGAN with cycle consistency."""
    # Forward cycle: X -> Y -> X
    fake_y = cyclegan.generator_g(batch_x, training=True)
    reconstructed_x = cyclegan.generator_f(fake_y, training=True)

    # Backward cycle: Y -> X -> Y
    fake_x = cyclegan.generator_f(batch_y, training=True)
    reconstructed_y = cyclegan.generator_g(fake_x, training=True)

    # Adversarial losses (least-squares formulation, a common choice for CycleGAN)
    disc_y_real = cyclegan.discriminator_y(batch_y, training=True)
    disc_y_fake = cyclegan.discriminator_y(fake_y, training=True)
    disc_x_real = cyclegan.discriminator_x(batch_x, training=True)
    disc_x_fake = cyclegan.discriminator_x(fake_x, training=True)
    disc_loss_x = 0.5 * (jnp.mean((disc_x_real - 1.0) ** 2) + jnp.mean(disc_x_fake**2))
    disc_loss_y = 0.5 * (jnp.mean((disc_y_real - 1.0) ** 2) + jnp.mean(disc_y_fake**2))

    # Cycle consistency losses
    cycle_loss_x = jnp.mean(jnp.abs(reconstructed_x - batch_x))
    cycle_loss_y = jnp.mean(jnp.abs(reconstructed_y - batch_y))
    total_cycle_loss = cyclegan.cycle_weight * (cycle_loss_x + cycle_loss_y)

    # Identity losses (optional, helps preserve color)
    identity_x = cyclegan.generator_f(batch_x, training=True)
    identity_y = cyclegan.generator_g(batch_y, training=True)
    identity_loss_x = jnp.mean(jnp.abs(identity_x - batch_x))
    identity_loss_y = jnp.mean(jnp.abs(identity_y - batch_y))
    total_identity_loss = cyclegan.identity_weight * (identity_loss_x + identity_loss_y)

    return {
        "cycle_loss": total_cycle_loss,
        "identity_loss": total_identity_loss,
        "disc_x_loss": disc_loss_x,
        "disc_y_loss": disc_loss_y,
    }
Key Features:
- Two generators: G: X→Y and F: Y→X
- Two discriminators: D_X and D_Y
- Cycle consistency: x → G(x) → F(G(x)) ≈ x
- No paired data needed
When to use:
- Image-to-image translation without paired data
- Style transfer (photo ↔ painting)
- Domain adaptation (synthetic ↔ real)
- Seasonal changes (summer ↔ winter)
Pros: No paired data needed, flexible
Cons: Computationally expensive (4 networks), can fail if the domains are too different
7. PatchGAN¤
Discriminator operates on image patches:
from workshop.generative_models.models.gan import (
PatchGANDiscriminator,
MultiScalePatchGANDiscriminator,
)
# Single-scale PatchGAN
patch_discriminator = PatchGANDiscriminator(
input_shape=(3, 256, 256),
hidden_dims=[64, 128, 256, 512],
kernel_size=4,
stride=2,
rngs=rngs,
)
# Returns N×N array of patch predictions
patch_scores = patch_discriminator(images, training=True) # Shape: (batch, H', W', 1)
# Multi-scale PatchGAN (better for high-resolution)
multiscale_discriminator = MultiScalePatchGANDiscriminator(
input_shape=(3, 256, 256),
hidden_dims=[64, 128, 256],
num_scales=3, # 3 different scales
rngs=rngs,
)
# Returns predictions at multiple scales
predictions = multiscale_discriminator(images, training=True)
Key Features:
- Patch-based: Classifies overlapping patches
- Local texture: Better for texture quality
- Efficient: Fewer parameters than full-image discriminator
- Multi-scale: Can combine predictions at different resolutions
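Since the discriminator outputs a grid of per-patch scores rather than a single scalar, the adversarial loss is simply averaged over all patches. A hedged sketch using a least-squares objective (the loss Workshop actually pairs with PatchGAN may differ, and the function name is illustrative):

import jax.numpy as jnp

def patchgan_ls_loss(real_patch_scores, fake_patch_scores):
    """Least-squares PatchGAN loss averaged over every patch of every image.

    Both inputs have shape (batch, H', W', 1), as returned by the discriminator above.
    """
    real_loss = jnp.mean((real_patch_scores - 1.0) ** 2)  # real patches -> 1
    fake_loss = jnp.mean(fake_patch_scores**2)            # fake patches -> 0
    return 0.5 * (real_loss + fake_loss)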
When to use:
- High-resolution images (>256×256)
- Image-to-image translation (Pix2Pix)
- Focus on local texture quality
- With CycleGAN for better results
Pros: Efficient, good for textures, scales well
Cons: May miss global structure issues
Training GANs¤
Basic Training Loop¤
Here's a complete training loop for a vanilla GAN:
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.gan import GAN
# Create model
gan = GAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))
import optax

# Create optimizers (separate for generator and discriminator)
gen_optimizer = nnx.Optimizer(
    gan.generator,
    optax.adam(learning_rate=0.0002, b1=0.5, b2=0.999),
)
disc_optimizer = nnx.Optimizer(
    gan.discriminator,
    optax.adam(learning_rate=0.0002, b1=0.5, b2=0.999),
)
# Training step
@nnx.jit
def train_step(gan, gen_opt, disc_opt, batch, rngs):
    """Single training step for vanilla GAN."""

    # Discriminator update
    def disc_loss_fn(disc):
        # Get generator samples (stop gradient to not update generator)
        z = jax.random.normal(rngs.sample(), (batch.shape[0], gan.latent_dim))
        fake_samples = gan.generator(z, training=True)
        fake_samples = jax.lax.stop_gradient(fake_samples)

        # Discriminator scores
        real_scores = disc(batch, training=True)
        fake_scores = disc(fake_samples, training=True)

        # Vanilla GAN discriminator loss
        real_loss = -jnp.log(jnp.clip(real_scores, 1e-7, 1.0))
        fake_loss = -jnp.log(jnp.clip(1.0 - fake_scores, 1e-7, 1.0))
        return jnp.mean(real_loss + fake_loss)

    # Compute discriminator loss and update
    disc_loss, disc_grads = nnx.value_and_grad(disc_loss_fn)(gan.discriminator)
    disc_opt.update(disc_grads)

    # Generator update
    def gen_loss_fn(gen):
        # Generate samples
        z = jax.random.normal(rngs.sample(), (batch.shape[0], gan.latent_dim))
        fake_samples = gen(z, training=True)

        # Discriminator scores; only `gen` is differentiated here,
        # so the discriminator receives no gradient updates
        fake_scores = gan.discriminator(fake_samples, training=True)

        # Non-saturating generator loss
        return -jnp.mean(jnp.log(jnp.clip(fake_scores, 1e-7, 1.0)))

    # Compute generator loss and update
    gen_loss, gen_grads = nnx.value_and_grad(gen_loss_fn)(gan.generator)
    gen_opt.update(gen_grads)

    return {
        "disc_loss": disc_loss,
        "gen_loss": gen_loss,
    }
# Training loop
step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        # Preprocess: scale to [-1, 1] for tanh output
        batch = (batch / 127.5) - 1.0

        # Training step
        metrics = train_step(gan, gen_optimizer, disc_optimizer, batch, rngs)
        step += 1

        # Log metrics
        if step % log_interval == 0:
            print(f"Epoch {epoch}, Step {step}")
            print(f"  Discriminator Loss: {metrics['disc_loss']:.4f}")
            print(f"  Generator Loss: {metrics['gen_loss']:.4f}")

        # Generate samples for visualization
        if step % sample_interval == 0:
            samples = gan.generate(n_samples=16, rngs=rngs)
            save_images(samples, f"samples_step_{step}.png")
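The loop above calls a save_images helper that is not part of Workshop; here is one possible minimal implementation with matplotlib, assuming samples in [-1, 1] with shape (n, channels, height, width):

import math
import jax.numpy as jnp
import matplotlib.pyplot as plt

def save_images(samples, path, images_per_row=4):
    """Save a grid of generated samples to disk (illustrative helper, not a Workshop API)."""
    samples = (samples + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    n = samples.shape[0]
    n_rows = math.ceil(n / images_per_row)
    fig, axes = plt.subplots(n_rows, images_per_row,
                             figsize=(2 * images_per_row, 2 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i < n:
            img = jnp.transpose(samples[i], (1, 2, 0))  # (C, H, W) -> (H, W, C)
            ax.imshow(img[:, :, 0] if img.shape[-1] == 1 else img, cmap="gray")
        ax.axis("off")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)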
WGAN Training Loop¤
WGAN requires multiple discriminator updates per generator update:
@nnx.jit
def train_wgan_step(wgan, gen_opt, critic_opt, batch, rngs, n_critic=5):
    """Training step for WGAN-GP."""

    # Train critic n_critic times
    critic_losses = []
    for i in range(n_critic):

        def critic_loss_fn(critic):
            # Generate fake samples
            z = jax.random.normal(rngs.sample(), (batch.shape[0], wgan.latent_dim))
            fake_samples = wgan.generator(z, training=True)
            fake_samples = jax.lax.stop_gradient(fake_samples)

            # Get critic outputs
            real_validity = critic(batch, training=True)
            fake_validity = critic(fake_samples, training=True)

            # Wasserstein loss
            wasserstein_distance = jnp.mean(fake_validity) - jnp.mean(real_validity)

            # Gradient penalty
            alpha = jax.random.uniform(
                rngs.sample(),
                shape=(batch.shape[0], 1, 1, 1),
                minval=0.0,
                maxval=1.0,
            )
            interpolated = alpha * batch + (1 - alpha) * fake_samples

            def critic_interp_fn(x):
                return jnp.sum(critic(x, training=True))

            gradients = jax.grad(critic_interp_fn)(interpolated)
            gradients = jnp.reshape(gradients, (batch.shape[0], -1))
            gradient_norm = jnp.sqrt(jnp.sum(gradients**2, axis=1) + 1e-12)
            gradient_penalty = jnp.mean((gradient_norm - 1.0) ** 2) * 10.0

            return wasserstein_distance + gradient_penalty

        # Update critic
        critic_loss, critic_grads = nnx.value_and_grad(critic_loss_fn)(wgan.discriminator)
        critic_opt.update(critic_grads)
        critic_losses.append(critic_loss)

    # Train generator once
    def gen_loss_fn(gen):
        z = jax.random.normal(rngs.sample(), (batch.shape[0], wgan.latent_dim))
        fake_samples = gen(z, training=True)

        # Only `gen` is differentiated here, so the critic is not updated
        fake_validity = wgan.discriminator(fake_samples, training=True)

        # WGAN generator loss: maximize critic output
        return -jnp.mean(fake_validity)

    gen_loss, gen_grads = nnx.value_and_grad(gen_loss_fn)(wgan.generator)
    gen_opt.update(gen_grads)

    return {
        "critic_loss": jnp.mean(jnp.array(critic_losses)),
        "gen_loss": gen_loss,
    }
Two-Timescale Update Rule (TTUR)¤
Use different learning rates for generator and discriminator:
import optax

# Generator: slower learning rate
gen_optimizer = nnx.Optimizer(
    gan.generator,
    optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999),
)

# Discriminator: faster learning rate
disc_optimizer = nnx.Optimizer(
    gan.discriminator,
    optax.adam(learning_rate=0.0004, b1=0.5, b2=0.999),
)
Why it works:
- Discriminator needs to stay ahead to provide useful signal
- Prevents generator from overwhelming discriminator
- More stable training dynamics
Generation and Sampling¤
Basic Generation¤
# Generate samples
n_samples = 64
samples = gan.generate(n_samples=n_samples, rngs=rngs)
# Samples are in [-1, 1] range (from Tanh)
# Convert to [0, 255] for visualization
samples = ((samples + 1) / 2 * 255).astype(jnp.uint8)
Latent Space Interpolation¤
Smoothly interpolate between two points in latent space:
def interpolate_latent(gan, z1, z2, num_steps=10):
    """Interpolate linearly between two latent vectors."""
    # Create interpolation weights
    alphas = jnp.linspace(0, 1, num_steps)

    # Interpolate
    interpolated_samples = []
    for alpha in alphas:
        z_interp = alpha * z2 + (1 - alpha) * z1
        sample = gan.generator(z_interp[None, :], training=False)
        interpolated_samples.append(sample[0])

    return jnp.stack(interpolated_samples)
# Generate two random latent vectors
z1 = jax.random.normal(rngs.sample(), (latent_dim,))
z2 = jax.random.normal(rngs.sample(), (latent_dim,))
# Interpolate
interpolated = interpolate_latent(gan, z1, z2, num_steps=20)
Latent Space Exploration¤
Explore the latent space by varying dimensions:
def explore_latent_dimension(gan, dim_idx, num_samples=10, range_scale=3.0):
    """Explore a specific latent dimension."""
    # Fixed random vector
    z_base = jax.random.normal(rngs.sample(), (latent_dim,))

    # Vary single dimension
    values = jnp.linspace(-range_scale, range_scale, num_samples)
    samples = []
    for value in values:
        z = z_base.at[dim_idx].set(value)
        sample = gan.generator(z[None, :], training=False)
        samples.append(sample[0])

    return jnp.stack(samples)
# Explore dimension 0
samples_dim0 = explore_latent_dimension(gan, dim_idx=0, num_samples=10)
Conditional Generation¤
For conditional GANs, specify the condition:
# Generate specific digits (MNIST)
labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
z = jax.random.normal(rngs.sample(), (10, latent_dim))
samples = cond_generator(z, labels, training=False)
# Each sample corresponds to its label
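To check the conditioning visually, it helps to generate several samples per class. The sketch below (reusing cond_generator, rngs, and latent_dim from above) builds a 10-class by 8-sample grid:

n_classes, samples_per_class = 10, 8

# Repeat each class id samples_per_class times: [0, 0, ..., 1, 1, ..., 9]
grid_labels = jnp.repeat(jnp.arange(n_classes), samples_per_class)  # (80,)
grid_z = jax.random.normal(rngs.sample(), (n_classes * samples_per_class, latent_dim))

grid_samples = cond_generator(grid_z, grid_labels, training=False)  # (80, 1, 28, 28)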
Evaluation and Monitoring¤
Visual Inspection¤
The most important evaluation method for GANs:
import matplotlib.pyplot as plt
def visualize_samples(samples, nrow=8, title="Generated Samples"):
    """Visualize a grid of samples (nrow images per row)."""
    n_samples = samples.shape[0]
    n_rows = (n_samples + nrow - 1) // nrow

    # Convert from [-1, 1] to [0, 1]
    samples = (samples + 1) / 2

    fig, axes = plt.subplots(n_rows, nrow, figsize=(nrow * 2, n_rows * 2), squeeze=False)
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        if i < n_samples:
            # Transpose from (C, H, W) to (H, W, C)
            img = jnp.transpose(samples[i], (1, 2, 0))
            # Handle grayscale
            if img.shape[-1] == 1:
                img = img[:, :, 0]
                ax.imshow(img, cmap='gray')
            else:
                ax.imshow(img)
        ax.axis('off')
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
# Generate and visualize
samples = gan.generate(n_samples=64, rngs=rngs)
visualize_samples(samples)
Loss Monitoring¤
Track both generator and discriminator losses:
# During training
history = {
"gen_loss": [],
"disc_loss": [],
"real_scores": [],
"fake_scores": [],
}
for epoch in range(num_epochs):
    for batch in dataloader:
        metrics = train_step(gan, gen_opt, disc_opt, batch, rngs)
        history["gen_loss"].append(float(metrics["gen_loss"]))
        history["disc_loss"].append(float(metrics["disc_loss"]))
# Plot losses
plt.figure(figsize=(10, 5))
plt.plot(history["gen_loss"], label="Generator Loss")
plt.plot(history["disc_loss"], label="Discriminator Loss")
plt.xlabel("Training Step")
plt.ylabel("Loss")
plt.legend()
plt.title("GAN Training Losses")
plt.show()
Healthy training signs:
- Both losses decrease initially then stabilize
- Losses oscillate but don't diverge
- Real scores stay around 0.7-0.9
- Fake scores start low, gradually increase
- Visual quality improves over time
Warning signs:
- Discriminator loss → 0 (too strong)
- Generator loss → ∞ (gradient vanishing)
- Mode collapse (all samples look same)
- Training instability (wild oscillations)
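To fill the real_scores / fake_scores entries of the history dict above, one option is a small diagnostic run outside the training step. A minimal sketch (it assumes the vanilla GAN from earlier, whose discriminator outputs probabilities; the helper name is illustrative):

def score_diagnostics(gan, real_batch, rngs, n_fake=64):
    """Mean discriminator scores on real data and on freshly generated samples."""
    real_scores = gan.discriminator(real_batch, training=False)
    z = jax.random.normal(rngs.sample(), (n_fake, gan.latent_dim))
    fake_scores = gan.discriminator(gan.generator(z, training=False), training=False)
    return float(jnp.mean(real_scores)), float(jnp.mean(fake_scores))

# Inside the monitoring loop
real_mean, fake_mean = score_diagnostics(gan, batch, rngs)
history["real_scores"].append(real_mean)
history["fake_scores"].append(fake_mean)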
Inception Score (IS)¤
Measures quality and diversity:
def inception_score(samples, num_splits=10):
    """Compute Inception Score for generated samples.

    Requires a pre-trained Inception model (`inception_model` below).
    """
    # Get class predictions from the Inception model
    predictions = inception_model(samples)

    # Split into groups
    split_scores = []
    for k in range(num_splits):
        part = predictions[k * (len(predictions) // num_splits):
                           (k + 1) * (len(predictions) // num_splits)]

        # Compute KL divergence between p(y|x) and the marginal p(y)
        py = jnp.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i]
            scores.append(jnp.sum(pyx * jnp.log(pyx / py)))
        split_scores.append(jnp.exp(jnp.mean(jnp.array(scores))))

    return jnp.mean(jnp.array(split_scores)), jnp.std(jnp.array(split_scores))
# Compute IS
mean_is, std_is = inception_score(generated_samples)
print(f"Inception Score: {mean_is:.2f} ± {std_is:.2f}")
Higher is better (strong models reach roughly 8-10 on CIFAR-10)
Fréchet Inception Distance (FID)¤
Measures similarity to real data:
from scipy.linalg import sqrtm  # matrix square root (runs on host, not JIT-compatible)

def frechet_inception_distance(real_samples, fake_samples):
    """Compute FID between real and generated samples. Lower is better."""
    # Get features from the Inception model
    real_features = inception_model.get_features(real_samples)
    fake_features = inception_model.get_features(fake_samples)

    # Compute feature statistics
    mu_real = jnp.mean(real_features, axis=0)
    mu_fake = jnp.mean(fake_features, axis=0)
    sigma_real = jnp.cov(real_features.T)
    sigma_fake = jnp.cov(fake_features.T)

    # Compute FID
    diff = mu_real - mu_fake
    covmean = sqrtm(sigma_real @ sigma_fake)
    covmean = jnp.real(covmean)  # discard tiny imaginary parts from numerical error
    fid = jnp.sum(diff**2) + jnp.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid
# Compute FID
fid_score = frechet_inception_distance(real_data, generated_samples)
print(f"FID Score: {fid_score:.2f}")
Lower is better (good models: < 50, excellent: < 10)
Common Issues and Solutions¤
Mode Collapse¤
Symptom: Generator produces limited variety of samples.
Detection:
# Check sample diversity
samples = gan.generate(n_samples=100, rngs=rngs)
samples_flat = samples.reshape(samples.shape[0], -1)
# Compute pairwise distances
from scipy.spatial.distance import pdist
distances = pdist(samples_flat)
# Compare against a threshold calibrated on real-data pairwise distances
if jnp.mean(distances) < threshold:
    print("Warning: Possible mode collapse detected!")
Solutions:
- Use WGAN or LSGAN (see the variants above)
- Minibatch discrimination:
# Add minibatch statistics to discriminator
def minibatch_stddev(x):
    """Compute standard deviation across batch."""
    batch_std = jnp.std(x, axis=0, keepdims=True)
    return jnp.mean(batch_std)
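The statistic only helps if the discriminator can see it, so it is typically broadcast and appended as an extra feature before the final layers. A rough sketch over flattened features (append_minibatch_stddev is an illustrative helper, not a Workshop API):

def append_minibatch_stddev(features):
    """Append the mean per-feature batch std as one extra feature per sample.

    features: (batch, num_features) intermediate discriminator activations.
    """
    stat = jnp.mean(jnp.std(features, axis=0))             # scalar diversity statistic
    stat = jnp.broadcast_to(stat, (features.shape[0], 1))  # (batch, 1)
    return jnp.concatenate([features, stat], axis=-1)      # (batch, num_features + 1)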
- Add noise to discriminator inputs:
# Gradually decay noise
noise_std = 0.1 * (1 - epoch / num_epochs)
noisy_real = real_data + jax.random.normal(key, real_data.shape) * noise_std
noisy_fake = fake_data + jax.random.normal(key, fake_data.shape) * noise_std
- Use feature matching:
# Match discriminator feature statistics
def feature_matching_loss(real_features, fake_features):
    return jnp.mean((jnp.mean(real_features, axis=0) -
                     jnp.mean(fake_features, axis=0)) ** 2)
Training Instability¤
Symptom: Losses oscillate wildly, training doesn't converge.
Solutions:
- Use spectral normalization:
discriminator = Discriminator(
hidden_dims=[512, 256, 128],
use_spectral_norm=True, # Enable spectral norm
rngs=rngs,
)
- Two-timescale update rule (see the TTUR section above)
- Gradient penalty, as in the WGAN-GP training loop above
- Label smoothing:
# Smooth labels for discriminator
real_labels = jnp.ones((batch_size, 1)) * 0.9 # Instead of 1.0
fake_labels = jnp.zeros((batch_size, 1)) + 0.1 # Instead of 0.0
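With smoothed targets, the discriminator loss becomes a cross-entropy against those soft labels. A sketch using optax's sigmoid cross-entropy; note it assumes the discriminator returns raw logits (unlike the clipped-log formulation earlier), and the helper name is illustrative:

import jax.numpy as jnp
import optax

def smoothed_disc_loss(real_logits, fake_logits, real_target=0.9, fake_target=0.1):
    """Discriminator BCE loss with smoothed real/fake targets."""
    real_loss = optax.sigmoid_binary_cross_entropy(
        real_logits, jnp.full_like(real_logits, real_target))
    fake_loss = optax.sigmoid_binary_cross_entropy(
        fake_logits, jnp.full_like(fake_logits, fake_target))
    return jnp.mean(real_loss) + jnp.mean(fake_loss)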
Vanishing Gradients¤
Symptom: Generator loss stops decreasing, samples don't improve.
Solutions:
- Use non-saturating loss:
# Instead of: -log(1 - D(G(z)))
# Use: -log(D(G(z)))
gen_loss = -jnp.mean(jnp.log(jnp.clip(fake_scores, 1e-7, 1.0)))
- Reduce discriminator capacity (fewer or smaller hidden layers)
- Update discriminator less frequently:
# Update discriminator every 2 generator updates
if step % 2 == 0:
    disc_loss = train_discriminator(...)
gen_loss = train_generator(...)
Poor Sample Quality¤
Symptom: Blurry or unrealistic samples.
Solutions:
- Use DCGAN architecture:
# Replace MLP with convolutional architecture
from workshop.generative_models.models.gan import DCGAN
gan = DCGAN(config, rngs=rngs)
- Increase model capacity (wider or deeper hidden_dims)
- Train longer
- Better data preprocessing:
# Normalize to [-1, 1] for Tanh
data = (data / 127.5) - 1.0
# Ensure consistent shape
data = jnp.transpose(data, (0, 3, 1, 2)) # NHWC → NCHW
Best Practices¤
DO¤
✅ Use DCGAN guidelines for image generation:
# Strided convolutions, batch norm, LeakyReLU
generator = DCGANGenerator(...)
discriminator = DCGANDiscriminator(...)
✅ Scale data to [-1, 1] to match the Tanh output (as in the training loop above)
✅ Use the Adam optimizer with β₁=0.5
✅ Monitor both losses and generated samples
✅ Use two-timescale updates (TTUR)
✅ Start with WGAN if you need stable training
✅ Save checkpoints regularly, for example as sketched below:
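Most of the points above are demonstrated earlier in the guide; checkpointing is not, so here is a minimal sketch using Flax NNX state extraction and Orbax. The paths, interval, and save_checkpoint helper are illustrative, and Workshop may ship its own checkpoint utilities:

import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()

def save_checkpoint(gan, step, ckpt_dir="/tmp/gan_checkpoints"):
    """Save the GAN's current state (illustrative helper)."""
    _, state = nnx.split(gan)  # separate the graph definition from the parameter state
    checkpointer.save(f"{ckpt_dir}/step_{step}", state)

# In the training loop
if step % checkpoint_interval == 0:
    save_checkpoint(gan, step)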
DON'T¤
❌ Don't use batch norm in discriminator input:
# BAD
discriminator.layers[0] = BatchNorm(...)
# GOOD
discriminator.batch_norm = False # Or skip first layer
❌ Don't use same learning rate for G and D:
# BAD
gen_lr = disc_lr = 0.0002
# GOOD
gen_lr = 0.0001
disc_lr = 0.0004 # Discriminator learns faster
❌ Don't forget to scale data:
# BAD
data = data / 255.0 # [0, 1] doesn't match Tanh [-1, 1]
# GOOD
data = (data / 127.5) - 1.0 # [-1, 1] matches Tanh
❌ Don't ignore mode collapse warnings (see Common Issues and Solutions above)
❌ Don't use batch sizes that are too small
Summary¤
This guide covered:
- Creating GANs: Generators, discriminators, and full GAN models
- Variants: Vanilla, DCGAN, WGAN, LSGAN, cGAN, CycleGAN, PatchGAN
- Training: Basic loops, WGAN training, two-timescale updates
- Generation: Basic sampling, interpolation, conditional generation
- Evaluation: Visual inspection, IS, FID
- Troubleshooting: Mode collapse, instability, vanishing gradients
- Best practices: What to do and what to avoid
Next Steps¤
- Theory: See GAN Concepts for mathematical foundations
- API Reference: Check GAN API Documentation for detailed specifications
- Example: Follow MNIST GAN Tutorial for hands-on training
- Advanced: Explore StyleGAN and Progressive GAN for state-of-the-art results