GAN User Guide

This guide provides practical instructions for training and using Generative Adversarial Networks (GANs) in Workshop. It covers the GAN variants available in Workshop, training strategies, common issues, and best practices.

Quick Start

Here's a minimal example to get you started:

import jax
import jax.numpy as jnp
from flax import nnx

from workshop.generative_models.models.gan import GAN, Generator, Discriminator

# Initialize RNG
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Create simple config
class GANConfig:
    latent_dim = 100
    loss_type = "vanilla"  # or "wasserstein", "least_squares", "hinge"

    # Generator config
    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 1, 28, 28)  # MNIST shape
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    # Discriminator config
    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

config = GANConfig()

# Create GAN
gan = GAN(config, rngs=rngs)

# Generate samples
samples = gan.generate(n_samples=16, rngs=rngs)
print(f"Generated samples shape: {samples.shape}")  # (16, 1, 28, 28)

Creating GAN Components

Basic Generator

The generator transforms random noise into data samples:

from workshop.generative_models.models.gan import Generator

# Create generator
generator = Generator(
    hidden_dims=[128, 256, 512],      # Hidden layer sizes
    output_shape=(1, 3, 32, 32),      # Output: 3 channels, 32x32 images
    latent_dim=100,                    # Latent space dimension
    activation="relu",                 # Activation function
    batch_norm=True,                   # Use batch normalization
    dropout_rate=0.0,                  # Dropout rate (usually 0 for generator)
    rngs=rngs,
)

# Generate samples from random noise (draw noise from the "sample" RNG stream)
batch_size, latent_dim = 16, 100
z = jax.random.normal(rngs.sample(), (batch_size, latent_dim))
fake_samples = generator(z, training=True)

Key Parameters:

  • hidden_dims: List of hidden layer dimensions (progressively increases capacity)
  • output_shape: Target data shape (batch, channels, height, width)
  • latent_dim: Size of input latent vector (typically 64-512)
  • batch_norm: Stabilizes training (recommended for generator)
  • activation: "relu" for generator, "leaky_relu" for discriminator

Basic Discriminator

The discriminator classifies samples as real or fake:

from workshop.generative_models.models.gan import Discriminator

# Create discriminator
discriminator = Discriminator(
    hidden_dims=[512, 256, 128],      # Hidden layer sizes (often mirrors generator)
    activation="leaky_relu",           # LeakyReLU prevents dying neurons
    leaky_relu_slope=0.2,              # Negative slope for LeakyReLU
    batch_norm=False,                  # Usually False for discriminator
    dropout_rate=0.3,                  # Dropout to prevent overfitting
    use_spectral_norm=False,           # Spectral normalization for stability
    rngs=rngs,
)

# Classify samples (jnp.ones is a stand-in for a batch of real images)
real_data = jnp.ones((batch_size, 3, 32, 32))
fake_data = generator(z, training=True)

real_scores = discriminator(real_data, training=True)  # Training pushes these toward 1
fake_scores = discriminator(fake_data, training=True)  # Training pushes these toward 0

Key Parameters:

  • hidden_dims: Usually mirrors generator in reverse
  • activation: "leaky_relu" is standard (slope 0.2)
  • batch_norm: Usually False (can cause training issues)
  • dropout_rate: 0.3-0.5 helps prevent overfitting
  • use_spectral_norm: Divides each layer's weights by their largest singular value, which improves training stability
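
Spectral normalization divides each weight matrix by an estimate of its largest singular value, which bounds that layer's Lipschitz constant by roughly 1. The estimate is usually obtained with a step or two of power iteration per training step, reusing the vector u between steps. A plain jax.numpy sketch of the idea; it is an illustration, not Workshop's implementation, and the function name is hypothetical:

```python
import jax
import jax.numpy as jnp


def spectral_normalize(w, u, n_iters=1, eps=1e-12):
    """Return (w / sigma, updated u), where sigma estimates the top singular value of w.

    w: weight matrix of shape (out_features, in_features)
    u: running estimate of the top left-singular vector, shape (out_features,)
    """
    for _ in range(n_iters):
        # Power iteration: alternate multiplication by w^T and w, normalizing each time
        v = w.T @ u
        v = v / (jnp.linalg.norm(v) + eps)
        u = w @ v
        u = u / (jnp.linalg.norm(u) + eps)
    # With unit vectors u and v, u^T w v approximates the largest singular value
    sigma = u @ (w @ v)
    return w / sigma, u


key_w, key_u = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (64, 32))
u = jax.random.normal(key_u, (64,))

# In training one iteration per step suffices because u is carried over;
# here many iterations are run at once so the estimate converges
w_sn, u = spectral_normalize(w, u, n_iters=200)
top_singular_value = jnp.linalg.svd(w_sn, compute_uv=False)[0]  # close to 1.0
```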

GAN Variants

1. Vanilla GAN

The original GAN formulation with binary cross-entropy loss:

from workshop.generative_models.models.gan import GAN

class VanillaGANConfig:
    latent_dim = 100
    loss_type = "vanilla"

    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 1, 28, 28)
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

config = VanillaGANConfig()
gan = GAN(config, rngs=rngs)

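# Loss computation for one batch (optimizer updates are covered under Training GANs)
def compute_losses(gan, batch, rngs):
    # loss_fn returns a dict of losses; losses["loss"] is the value to optimize
    losses = gan.loss_fn(batch, None, rngs=rngs)

    return losses["loss"], losses

# batch_data: a batch of real images scaled to [-1, 1]
loss, metrics = compute_losses(gan, batch_data, rngs)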

When to use:

  • Learning GANs for the first time
  • Simple datasets (MNIST, simple shapes)
  • Proof-of-concept experiments

Pros: Simple, well-understood
Cons: Training instability, mode collapse
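
The vanilla objective is binary cross-entropy on the discriminator's real/fake decision, with the generator usually trained on the non-saturating form \(-\log D(G(z))\). The sketch below writes both losses in terms of logits using jax.nn.log_sigmoid, which is numerically stable and avoids the explicit clipping used in the training loop later in this guide. It assumes a discriminator that returns logits (no final sigmoid), which may not match Workshop's Discriminator; the function names are illustrative.

```python
import jax
import jax.numpy as jnp


def vanilla_d_loss(real_logits, fake_logits):
    """Discriminator BCE: -log D(x) - log(1 - D(G(z))), with D = sigmoid(logits)."""
    # log(1 - sigmoid(l)) == log_sigmoid(-l)
    return (-jnp.mean(jax.nn.log_sigmoid(real_logits))
            - jnp.mean(jax.nn.log_sigmoid(-fake_logits)))


def vanilla_g_loss(fake_logits):
    """Non-saturating generator loss: -log D(G(z))."""
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))


# A discriminator that cannot tell real from fake outputs logit 0 (probability 0.5)
zeros = jnp.zeros((8, 1))
d_loss = vanilla_d_loss(zeros, zeros)  # 2 * log(2), about 1.386
g_loss = vanilla_g_loss(zeros)         # log(2), about 0.693
```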

2. Deep Convolutional GAN (DCGAN)

Uses convolutional architecture for images:

from workshop.generative_models.models.gan import (
    DCGAN,
    DCGANGenerator,
    DCGANDiscriminator,
)

# Create DCGAN components directly
generator = DCGANGenerator(
    output_shape=(3, 64, 64),           # 3 channels, 64x64 output
    latent_dim=100,
    hidden_dims=(256, 128, 64, 32),     # Progressive channel reduction
    activation=jax.nn.relu,
    batch_norm=True,
    dropout_rate=0.0,
    rngs=rngs,
)

discriminator = DCGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(32, 64, 128, 256),     # Progressive channel increase
    activation=jax.nn.leaky_relu,
    leaky_relu_slope=0.2,
    batch_norm=False,                    # Off here; the DCGAN paper omits it only on the input layer
    dropout_rate=0.3,
    use_spectral_norm=True,              # Recommended for stability
    rngs=rngs,
)

# Or use the full DCGAN model
from workshop.generative_models.core.configuration.gan import DCGANConfiguration

dcgan_config = DCGANConfiguration(
    image_size=64,
    channels=3,
    latent_dim=100,
    gen_hidden_dims=(256, 128, 64, 32),
    disc_hidden_dims=(32, 64, 128, 256),
    loss_type="vanilla",
    generator_lr=0.0002,
    discriminator_lr=0.0002,
    beta1=0.5,
    beta2=0.999,
)

dcgan = DCGAN(dcgan_config, rngs=rngs)

# Generate samples (they resemble the data only after training)
samples = dcgan.generate(n_samples=64, rngs=rngs)

DCGAN Architecture Guidelines:

  1. Replace pooling with strided convolutions
  2. Use batch normalization (except discriminator input and generator output)
  3. Remove fully connected layers (except for latent projection)
  4. Use ReLU in generator, LeakyReLU in discriminator
  5. Use Tanh activation in generator output
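
Guideline 1 means the generator upsamples with strided transposed convolutions. With no output padding or dilation, their output size is out = (in - 1) * stride - 2 * padding + kernel (the convention of, for example, PyTorch's ConvTranspose2d; JAX/Flax express padding differently). With kernel 4, stride 2, and padding 1 each layer doubles the spatial size, so four layers take the 4×4 projection used in the DCGAN paper to 64×64. The helpers below are a hypothetical illustration, not part of Workshop:

```python
def conv_transpose_output_size(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a transposed convolution (no output padding, no dilation)."""
    return (size - 1) * stride - 2 * padding + kernel


def upsampling_schedule(start=4, n_layers=4, **conv_kwargs):
    """Spatial sizes after each transposed-convolution layer, starting from start x start."""
    sizes = [start]
    for _ in range(n_layers):
        sizes.append(conv_transpose_output_size(sizes[-1], **conv_kwargs))
    return sizes


print(upsampling_schedule())  # [4, 8, 16, 32, 64]
```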

When to use:

  • Image generation tasks
  • 64×64 to 128×128 resolution
  • More stable training than vanilla GAN

Pros: More stable, better image quality
Cons: Can still suffer from mode collapse

3. Wasserstein GAN (WGAN)

Uses Wasserstein distance for more stable training:

from workshop.generative_models.models.gan import (
    WGAN,
    WGANGenerator,
    WGANDiscriminator,
    compute_gradient_penalty,
)

# Create WGAN model
from workshop.generative_models.core.configuration import ModelConfiguration

wgan_config = ModelConfiguration(
    input_dim=100,                       # Latent dimension
    output_dim=(3, 64, 64),              # Output image shape
    hidden_dims=None,                    # Will use defaults
    metadata={
        "gan_params": {
            "gen_hidden_dims": (1024, 512, 256),
            "disc_hidden_dims": (256, 512, 1024),
            "gradient_penalty_weight": 10.0,    # Lambda for gradient penalty
            "critic_iterations": 5,              # Update critic 5x per generator
        }
    }
)

wgan = WGAN(wgan_config, rngs=rngs)

# Training loop for WGAN
def train_wgan_step(wgan, real_samples, rngs, n_critic=5):
    """Outline of one WGAN step; optimizer updates are omitted (see WGAN Training Loop)."""

    # Train critic n_critic times
    for _ in range(n_critic):
        # Sample latent vectors
        z = jax.random.normal(rngs.sample(), (real_samples.shape[0], wgan.latent_dim))

        # Generate fake samples
        fake_samples = wgan.generator(z, training=True)

        # Compute discriminator loss with gradient penalty
        disc_loss = wgan.discriminator_loss(real_samples, fake_samples, rngs)

        # Update discriminator
        # (In practice, use nnx.Optimizer)

    # Train generator once
    z = jax.random.normal(rngs.sample(), (real_samples.shape[0], wgan.latent_dim))
    fake_samples = wgan.generator(z, training=True)
    gen_loss = wgan.generator_loss(fake_samples)

    # Update generator

    return {"disc_loss": disc_loss, "gen_loss": gen_loss}

Key Differences from Vanilla GAN:

  1. Critic instead of discriminator (no sigmoid at output)
  2. Wasserstein distance instead of JS divergence
  3. Gradient penalty enforces Lipschitz constraint
  4. Multiple critic updates per generator update (typically 5:1)
  5. No batch normalization in the critic, because the gradient penalty treats each sample independently; instance or layer normalization is used instead
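
The gradient penalty in item 3 is \(\lambda\,\mathbb{E}[(\|\nabla D(\hat{x})\|_2 - 1)^2]\), evaluated at random interpolates \(\hat{x}\) between real and fake samples (\(\lambda = 10\) in the WGAN-GP paper). The sketch below computes per-sample input gradients with jax.vmap(jax.grad(...)). The toy linear critic is an assumption made for checking: its input gradient is its weight vector everywhere, so the penalty is known in closed form.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(critic_fn, real, fake, key, weight=10.0):
    """WGAN-GP penalty for a critic that maps a single sample to a scalar."""
    # One interpolation coefficient per sample, broadcast over the feature axes
    alpha_shape = (real.shape[0],) + (1,) * (real.ndim - 1)
    alpha = jax.random.uniform(key, alpha_shape)
    interpolated = alpha * real + (1.0 - alpha) * fake

    # Per-sample gradient of the critic with respect to its input
    grads = jax.vmap(jax.grad(critic_fn))(interpolated)
    grads = grads.reshape(grads.shape[0], -1)
    grad_norms = jnp.sqrt(jnp.sum(grads ** 2, axis=1) + 1e-12)
    return weight * jnp.mean((grad_norms - 1.0) ** 2)


# Toy linear critic D(x) = w . x with ||w|| = 5, so the penalty is 10 * (5 - 1)^2 = 160
w = jnp.array([3.0, 4.0])
linear_critic = lambda x: jnp.dot(w, x)

real = jnp.ones((16, 2))
fake = jnp.zeros((16, 2))
gp = gradient_penalty(linear_critic, real, fake, jax.random.PRNGKey(0))
```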

When to use:

  • Need stable training
  • Want meaningful loss metric
  • High-resolution images
  • Research experiments

Pros: Very stable, meaningful loss, better mode coverage
Cons: Slower training, more complex

4. Least Squares GAN (LSGAN)

Uses least squares loss for smoother gradients:

from workshop.generative_models.models.gan import LSGAN, LSGANGenerator, LSGANDiscriminator

# Create LSGAN (similar interface to base GAN)
class LSGANConfig:
    latent_dim = 100
    loss_type = "least_squares"    # Key difference

    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 3, 32, 32)
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

lsgan_config = LSGANConfig()
lsgan = GAN(lsgan_config, rngs=rngs)  # Can use base GAN with loss_type

# Or use dedicated LSGAN classes
generator = LSGANGenerator(
    output_shape=(3, 64, 64),
    latent_dim=100,
    rngs=rngs,
)

discriminator = LSGANDiscriminator(
    input_shape=(3, 64, 64),
    rngs=rngs,
)

# Training is similar to vanilla GAN
losses = lsgan.loss_fn(batch, None, rngs=rngs)

Key Difference:

Loss function uses squared error instead of log loss:

  • Generator: Minimize \((D(G(z)) - 1)^2\)
  • Discriminator: Minimize \((D(x) - 1)^2 + D(G(z))^2\)
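
Both objectives translate directly into code. A minimal jax.numpy sketch, assuming the discriminator returns unbounded scores (no sigmoid), as in the LSGAN paper; the paper also puts a factor of 1/2 on each term, which only rescales the losses. The function names are illustrative, not Workshop's API.

```python
import jax.numpy as jnp


def lsgan_d_loss(real_scores, fake_scores):
    """Discriminator: (D(x) - 1)^2 + D(G(z))^2, averaged over the batch."""
    return jnp.mean((real_scores - 1.0) ** 2) + jnp.mean(fake_scores ** 2)


def lsgan_g_loss(fake_scores):
    """Generator: (D(G(z)) - 1)^2, averaged over the batch."""
    return jnp.mean((fake_scores - 1.0) ** 2)


# A perfect discriminator (real -> 1, fake -> 0) has zero loss,
# while the generator's loss is then 1 per sample
real_scores = jnp.ones((4, 1))
fake_scores = jnp.zeros((4, 1))
d_loss = lsgan_d_loss(real_scores, fake_scores)  # 0.0
g_loss = lsgan_g_loss(fake_scores)               # 1.0
```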

When to use:

  • Want smoother gradients than vanilla GAN
  • Need more stable training than vanilla
  • Image generation with less training instability

Pros: More stable than vanilla; penalizes samples that lie far from the decision boundary even when they are on the correct side
Cons: Can still suffer from mode collapse

5. Conditional GAN (cGAN)

Conditions generation on labels or other information:

from workshop.generative_models.models.gan import (
    ConditionalGAN,
    ConditionalGenerator,
    ConditionalDiscriminator,
)

# Create conditional generator
cond_generator = ConditionalGenerator(
    output_shape=(1, 28, 28),
    latent_dim=100,
    num_classes=10,                 # MNIST has 10 classes
    hidden_dims=[256, 512],
    embedding_dim=50,               # Class embedding size
    rngs=rngs,
)

# Create conditional discriminator
cond_discriminator = ConditionalDiscriminator(
    input_shape=(1, 28, 28),
    num_classes=10,
    hidden_dims=[512, 256],
    embedding_dim=50,
    rngs=rngs,
)

# Generate conditioned on class label
labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  # One of each digit
z = jax.random.normal(rngs.sample(), (10, 100))

# Generate specific digits
samples = cond_generator(z, labels, training=False)

# Discriminate with labels
real_data = load_mnist_batch()  # Placeholder for your own data loading
real_labels = jnp.array([...])  # True labels

real_scores = cond_discriminator(real_data, real_labels, training=True)
fake_scores = cond_discriminator(samples, labels, training=True)

Key Features:

  • Controlled generation: Specify what to generate
  • Class conditioning: Generate specific categories
  • Embedding layer: Maps each integer label to a dense vector of size embedding_dim
  • Concatenation: Combines embeddings with features
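
One common way to implement the embedding-and-concatenation scheme above is to look up a vector per class and append it to the latent code. The jax.numpy sketch below uses a randomly initialized table in place of learned parameters; it illustrates the shapes involved and is not Workshop's ConditionalGenerator code.

```python
import jax
import jax.numpy as jnp

num_classes, embedding_dim, latent_dim = 10, 50, 100
key_table, key_z = jax.random.split(jax.random.PRNGKey(0))

# Embedding table: one row per class (learned during training in a real model)
embedding_table = jax.random.normal(key_table, (num_classes, embedding_dim))

labels = jnp.arange(num_classes)  # one label per digit 0..9
z = jax.random.normal(key_z, (num_classes, latent_dim))

# Look up each label's embedding and concatenate it with the noise vector
label_embeddings = embedding_table[labels]                          # (10, 50)
generator_input = jnp.concatenate([z, label_embeddings], axis=-1)   # (10, 150)
```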

When to use:

  • Need to control generation (class, attributes)
  • Have labeled data
  • Want to generate specific categories
  • Image-to-image translation with labels

Pros: Controlled generation, useful for labeled datasets
Cons: Requires labels, more complex

6. CycleGAN

Unpaired image-to-image translation:

from workshop.generative_models.models.gan import (
    CycleGAN,
    CycleGANGenerator,
    CycleGANDiscriminator,
)

# Create CycleGAN for domain transfer (e.g., horse ↔ zebra)
cyclegan = CycleGAN(
    input_shape_x=(3, 256, 256),       # Domain X (horses)
    input_shape_y=(3, 256, 256),       # Domain Y (zebras)
    gen_hidden_dims=[64, 128, 256],
    disc_hidden_dims=[64, 128, 256],
    cycle_weight=10.0,                 # Cycle consistency weight
    identity_weight=0.5,               # Identity loss weight
    rngs=rngs,
)

# Loss computation for one training step
def train_cyclegan_step(cyclegan, batch_x, batch_y, rngs):
    """Compute CycleGAN losses for one batch (optimizer updates omitted)."""

    # Forward cycle: X -> Y -> X
    fake_y = cyclegan.generator_g(batch_x, training=True)
    reconstructed_x = cyclegan.generator_f(fake_y, training=True)

    # Backward cycle: Y -> X -> Y
    fake_x = cyclegan.generator_f(batch_y, training=True)
    reconstructed_y = cyclegan.generator_g(fake_x, training=True)

    # Discriminator scores on real and translated images
    disc_y_real = cyclegan.discriminator_y(batch_y, training=True)
    disc_y_fake = cyclegan.discriminator_y(fake_y, training=True)
    disc_x_real = cyclegan.discriminator_x(batch_x, training=True)
    disc_x_fake = cyclegan.discriminator_x(fake_x, training=True)

    # Adversarial losses in least-squares form (as in the CycleGAN paper).
    # When updating the discriminators, apply jax.lax.stop_gradient to fake_x / fake_y.
    disc_loss_y = jnp.mean((disc_y_real - 1.0) ** 2) + jnp.mean(disc_y_fake ** 2)
    disc_loss_x = jnp.mean((disc_x_real - 1.0) ** 2) + jnp.mean(disc_x_fake ** 2)
    gen_adv_loss = jnp.mean((disc_y_fake - 1.0) ** 2) + jnp.mean((disc_x_fake - 1.0) ** 2)

    # Cycle consistency losses (L1)
    cycle_loss_x = jnp.mean(jnp.abs(reconstructed_x - batch_x))
    cycle_loss_y = jnp.mean(jnp.abs(reconstructed_y - batch_y))

    total_cycle_loss = cyclegan.cycle_weight * (cycle_loss_x + cycle_loss_y)

    # Identity losses (optional, help preserve color)
    identity_x = cyclegan.generator_f(batch_x, training=True)
    identity_y = cyclegan.generator_g(batch_y, training=True)

    identity_loss_x = jnp.mean(jnp.abs(identity_x - batch_x))
    identity_loss_y = jnp.mean(jnp.abs(identity_y - batch_y))

    total_identity_loss = cyclegan.identity_weight * (identity_loss_x + identity_loss_y)

    return {
        "gen_adv_loss": gen_adv_loss,
        "cycle_loss": total_cycle_loss,
        "identity_loss": total_identity_loss,
        "disc_x_loss": disc_loss_x,
        "disc_y_loss": disc_loss_y,
    }

Key Features:

  • Two generators: G: X→Y and F: Y→X
  • Two discriminators: D_X and D_Y
  • Cycle consistency: x → G(x) → F(G(x)) ≈ x
  • No paired data needed

When to use:

  • Image-to-image translation without paired data
  • Style transfer (photo ↔ painting)
  • Domain adaptation (synthetic ↔ real)
  • Seasonal changes (summer ↔ winter)

Pros: No paired data needed, flexible
Cons: Computationally expensive (4 networks), can fail if the domains are too different

7. PatchGAN

Discriminator operates on image patches:

from workshop.generative_models.models.gan import (
    PatchGANDiscriminator,
    MultiScalePatchGANDiscriminator,
)

# Single-scale PatchGAN
patch_discriminator = PatchGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=[64, 128, 256, 512],
    kernel_size=4,
    stride=2,
    rngs=rngs,
)

# Returns N×N array of patch predictions
patch_scores = patch_discriminator(images, training=True)  # Shape: (batch, H', W', 1)

# Multi-scale PatchGAN (better for high-resolution)
multiscale_discriminator = MultiScalePatchGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=[64, 128, 256],
    num_scales=3,                   # 3 different scales
    rngs=rngs,
)

# Returns predictions at multiple scales
predictions = multiscale_discriminator(images, training=True)

Key Features:

  • Patch-based: Classifies overlapping patches
  • Local texture: Better for texture quality
  • Efficient: Fewer parameters than full-image discriminator
  • Multi-scale: Can combine predictions at different resolutions

When to use:

  • High-resolution images (>256×256)
  • Image-to-image translation (Pix2Pix)
  • Focus on local texture quality
  • With CycleGAN for better results

Pros: Efficient, good for textures, scales well
Cons: May miss global structure issues

Training GANs

Basic Training Loop

Here's a training loop for a vanilla GAN (dataloader, num_epochs, log_interval, sample_interval, and save_images are placeholders for your own setup):

import jax
import jax.numpy as jnp
import optax
from flax import nnx

from workshop.generative_models.models.gan import GAN

# Create model
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
gan = GAN(config, rngs=rngs)

# Create optimizers (separate for generator and discriminator).
# The optimizers themselves come from optax; Flax >= 0.11 also expects
# wrt=nnx.Param here and opt.update(model, grads) in the step below.
gen_optimizer = nnx.Optimizer(
    gan.generator,
    optax.adam(learning_rate=0.0002, b1=0.5, b2=0.999)
)

disc_optimizer = nnx.Optimizer(
    gan.discriminator,
    optax.adam(learning_rate=0.0002, b1=0.5, b2=0.999)
)

# Training step
@nnx.jit
def train_step(gan, gen_opt, disc_opt, batch, rngs):
    """Single training step for vanilla GAN (discriminator outputs probabilities)."""
    batch_size = batch.shape[0]

    # Draw noise outside the differentiated functions so RNG state is not
    # mutated inside nnx.value_and_grad
    z_disc = jax.random.normal(rngs.sample(), (batch_size, gan.latent_dim))
    z_gen = jax.random.normal(rngs.sample(), (batch_size, gan.latent_dim))

    # Discriminator update: fake samples are constants (generator not updated)
    fake_samples = jax.lax.stop_gradient(gan.generator(z_disc, training=True))

    def disc_loss_fn(disc):
        real_scores = disc(batch, training=True)
        fake_scores = disc(fake_samples, training=True)

        # Vanilla GAN discriminator loss: -log D(x) - log(1 - D(G(z)))
        real_loss = -jnp.log(jnp.clip(real_scores, 1e-7, 1.0))
        fake_loss = -jnp.log(jnp.clip(1.0 - fake_scores, 1e-7, 1.0))

        return jnp.mean(real_loss + fake_loss)

    disc_loss, disc_grads = nnx.value_and_grad(disc_loss_fn)(gan.discriminator)
    disc_opt.update(disc_grads)

    # Generator update: gradients are taken w.r.t. the first argument only,
    # so the discriminator is evaluated but not updated
    def gen_loss_fn(gen, disc):
        fake = gen(z_gen, training=True)
        fake_scores = disc(fake, training=True)

        # Non-saturating generator loss: -log D(G(z))
        return -jnp.mean(jnp.log(jnp.clip(fake_scores, 1e-7, 1.0)))

    gen_loss, gen_grads = nnx.value_and_grad(gen_loss_fn)(gan.generator, gan.discriminator)
    gen_opt.update(gen_grads)

    return {
        "disc_loss": disc_loss,
        "gen_loss": gen_loss,
    }

# Training loop
step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        # Preprocess: scale uint8 pixels [0, 255] to [-1, 1] for tanh output
        batch = (batch / 127.5) - 1.0

        # Training step
        metrics = train_step(gan, gen_optimizer, disc_optimizer, batch, rngs)

        # Log metrics
        if step % log_interval == 0:
            print(f"Epoch {epoch}, Step {step}")
            print(f"  Discriminator Loss: {metrics['disc_loss']:.4f}")
            print(f"  Generator Loss: {metrics['gen_loss']:.4f}")

        # Generate samples for visualization
        if step % sample_interval == 0:
            samples = gan.generate(n_samples=16, rngs=rngs)
            save_images(samples, f"samples_step_{step}.png")

        step += 1

WGAN Training Loop

WGAN requires multiple discriminator updates per generator update:

@nnx.jit
def train_wgan_step(wgan, gen_opt, critic_opt, batch, rngs, n_critic=5):
    """Training step for WGAN-GP (gradient penalty weight 10)."""
    batch_size = batch.shape[0]

    def critic_loss_fn(critic, fake_samples, alpha):
        # Get critic outputs (unbounded scores, no sigmoid)
        real_validity = critic(batch, training=True)
        fake_validity = critic(fake_samples, training=True)

        # Wasserstein loss: the critic maximizes E[D(x)] - E[D(G(z))]
        wasserstein_loss = jnp.mean(fake_validity) - jnp.mean(real_validity)

        # Gradient penalty on random interpolates between real and fake samples
        interpolated = alpha * batch + (1 - alpha) * fake_samples

        # Summing is valid because each output depends only on its own input;
        # assumes the critic call does not update module state (e.g. dropout RNG)
        def critic_interp_fn(x):
            return jnp.sum(critic(x, training=True))

        gradients = jax.grad(critic_interp_fn)(interpolated)
        gradients = jnp.reshape(gradients, (batch_size, -1))
        gradient_norm = jnp.sqrt(jnp.sum(gradients**2, axis=1) + 1e-12)
        gradient_penalty = jnp.mean((gradient_norm - 1.0) ** 2) * 10.0

        return wasserstein_loss + gradient_penalty

    # Train critic n_critic times
    critic_losses = []
    for _ in range(n_critic):
        # Sample randomness outside the differentiated function so RNG state is
        # not mutated inside nnx.value_and_grad; alpha assumes (N, C, H, W) batches
        z = jax.random.normal(rngs.sample(), (batch_size, wgan.latent_dim))
        alpha = jax.random.uniform(rngs.sample(), (batch_size, 1, 1, 1))
        fake_samples = jax.lax.stop_gradient(wgan.generator(z, training=True))

        # Update critic (gradients w.r.t. the first argument only)
        critic_loss, critic_grads = nnx.value_and_grad(critic_loss_fn)(
            wgan.discriminator, fake_samples, alpha
        )
        critic_opt.update(critic_grads)
        critic_losses.append(critic_loss)

    # Train generator once; the critic is passed in but not differentiated
    def gen_loss_fn(gen, critic, z):
        fake_samples = gen(z, training=True)
        fake_validity = critic(fake_samples, training=True)

        # WGAN generator loss: maximize critic output
        return -jnp.mean(fake_validity)

    z = jax.random.normal(rngs.sample(), (batch_size, wgan.latent_dim))
    gen_loss, gen_grads = nnx.value_and_grad(gen_loss_fn)(wgan.generator, wgan.discriminator, z)
    gen_opt.update(gen_grads)

    return {
        "critic_loss": jnp.mean(jnp.array(critic_losses)),
        "gen_loss": gen_loss,
    }

Two-Timescale Update Rule (TTUR)

Use different learning rates for generator and discriminator:

import optax

# Generator: slower learning rate
gen_optimizer = nnx.Optimizer(
    gan.generator,
    optax.adam(learning_rate=0.0001, b1=0.5, b2=0.999)  # lr = 0.0001
)

# Discriminator: faster learning rate
disc_optimizer = nnx.Optimizer(
    gan.discriminator,
    optax.adam(learning_rate=0.0004, b1=0.5, b2=0.999)  # lr = 0.0004
)

Why it works:

  • Discriminator needs to stay ahead to provide useful signal
  • Prevents generator from overwhelming discriminator
  • More stable training dynamics

Generation and Sampling

Basic Generation

# Generate samples
n_samples = 64
samples = gan.generate(n_samples=n_samples, rngs=rngs)

# Samples are in [-1, 1] range (from Tanh)
# Convert to [0, 255] for visualization (clip and round before casting)
samples = jnp.round(jnp.clip((samples + 1) / 2, 0.0, 1.0) * 255).astype(jnp.uint8)

Latent Space Interpolation

Smoothly interpolate between two points in latent space:

def interpolate_latent(gan, z1, z2, num_steps=10):
    """Interpolate between two latent vectors."""
    # Create interpolation weights
    alphas = jnp.linspace(0, 1, num_steps)

    # Interpolate
    interpolated_samples = []
    for alpha in alphas:
        z_interp = alpha * z2 + (1 - alpha) * z1
        sample = gan.generator(z_interp[None, :], training=False)
        interpolated_samples.append(sample[0])

    return jnp.stack(interpolated_samples)

# Generate two random latent vectors
z1 = jax.random.normal(rngs.sample(), (gan.latent_dim,))
z2 = jax.random.normal(rngs.sample(), (gan.latent_dim,))

# Interpolate
interpolated = interpolate_latent(gan, z1, z2, num_steps=20)
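
Linear interpolation between two Gaussian latent vectors passes through points with a smaller norm than typical samples (the midpoint of two independent N(0, I) vectors has about 1/√2 of their norm). A frequently used alternative is spherical linear interpolation (slerp). The function below is a sketch that could replace the line computing z_interp above; it is not part of Workshop.

```python
import jax
import jax.numpy as jnp


def slerp(z1, z2, t, eps=1e-7):
    """Spherical linear interpolation between latent vectors z1 and z2 at fraction t."""
    cos_omega = jnp.dot(z1, z2) / (jnp.linalg.norm(z1) * jnp.linalg.norm(z2))
    omega = jnp.arccos(jnp.clip(cos_omega, -1.0, 1.0))
    sin_omega = jnp.sin(omega)

    lerp = (1.0 - t) * z1 + t * z2
    # Avoid dividing by zero; nearly parallel vectors fall back to linear interpolation
    safe_sin = jnp.where(sin_omega < eps, 1.0, sin_omega)
    spherical = (jnp.sin((1.0 - t) * omega) * z1 + jnp.sin(t * omega) * z2) / safe_sin
    return jnp.where(sin_omega < eps, lerp, spherical)


z1 = jax.random.normal(jax.random.PRNGKey(1), (100,))
z2 = jax.random.normal(jax.random.PRNGKey(2), (100,))
path = jnp.stack([slerp(z1, z2, t) for t in jnp.linspace(0.0, 1.0, 10)])  # (10, 100)
```

Each row of path can then be passed to gan.generator (with a leading batch axis) exactly as z_interp is in interpolate_latent.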

Latent Space Exploration

Explore the latent space by varying dimensions:

def explore_latent_dimension(gan, dim_idx, rngs, num_samples=10, range_scale=3.0):
    """Explore a specific latent dimension."""
    # Fixed random vector
    z_base = jax.random.normal(rngs.sample(), (gan.latent_dim,))

    # Vary single dimension
    values = jnp.linspace(-range_scale, range_scale, num_samples)

    samples = []
    for value in values:
        z = z_base.at[dim_idx].set(value)
        sample = gan.generator(z[None, :], training=False)
        samples.append(sample[0])

    return jnp.stack(samples)

# Explore dimension 0
samples_dim0 = explore_latent_dimension(gan, dim_idx=0, rngs=rngs, num_samples=10)

Conditional Generation

For conditional GANs, specify the condition:

# Generate specific digits (MNIST)
labels = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
z = jax.random.normal(rngs.sample(), (10, 100))  # latent_dim = 100, as in cond_generator

samples = cond_generator(z, labels, training=False)
# Each sample corresponds to its label
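
To compare classes side by side, each label can be repeated a fixed number of times. A short sketch; per_class and the grid layout are illustrative choices. Combined with the visualize_samples helper from Visual Inspection and nrow=per_class, each row of the grid then shows one class.

```python
import jax
import jax.numpy as jnp

num_classes, per_class, latent_dim = 10, 8, 100

# Each class label repeated per_class times: 0,0,...,0,1,1,...,9
labels = jnp.repeat(jnp.arange(num_classes), per_class)  # shape (80,)
z = jax.random.normal(jax.random.PRNGKey(0), (labels.shape[0], latent_dim))

# samples = cond_generator(z, labels, training=False)
# visualize_samples(samples, nrow=per_class)  # one class per row
```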

Evaluation and Monitoring

Visual Inspection

The most direct way to spot artifacts and mode collapse, and usually the first check to run:

import matplotlib.pyplot as plt
import numpy as np

def visualize_samples(samples, nrow=8, title="Generated Samples"):
    """Visualize a grid of samples with nrow images per row."""
    samples = np.asarray(samples)
    n_samples = samples.shape[0]
    n_rows = (n_samples + nrow - 1) // nrow

    # Convert from [-1, 1] to [0, 1]
    samples = np.clip((samples + 1) / 2, 0.0, 1.0)

    # squeeze=False keeps a 2D axes array even for a single row
    fig, axes = plt.subplots(n_rows, nrow, figsize=(nrow * 2, n_rows * 2), squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < n_samples:
            # Transpose from (C, H, W) to (H, W, C)
            img = np.transpose(samples[i], (1, 2, 0))
            # Handle grayscale
            if img.shape[-1] == 1:
                ax.imshow(img[:, :, 0], cmap='gray')
            else:
                ax.imshow(img)
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Generate and visualize
samples = gan.generate(n_samples=64, rngs=rngs)
visualize_samples(samples)

Loss Monitoring

Track both generator and discriminator losses:

# During training
history = {
    "gen_loss": [],
    "disc_loss": [],
}

for epoch in range(num_epochs):
    for batch in dataloader:
        metrics = train_step(gan, gen_optimizer, disc_optimizer, batch, rngs)

        history["gen_loss"].append(float(metrics["gen_loss"]))
        history["disc_loss"].append(float(metrics["disc_loss"]))

# Plot losses
plt.figure(figsize=(10, 5))
plt.plot(history["gen_loss"], label="Generator Loss")
plt.plot(history["disc_loss"], label="Discriminator Loss")
plt.xlabel("Training Step")
plt.ylabel("Loss")
plt.legend()
plt.title("GAN Training Losses")
plt.show()

Healthy training signs:

  • Both losses decrease initially then stabilize
  • Losses oscillate but don't diverge
  • Real scores stay around 0.7-0.9
  • Fake scores start low, gradually increase
  • Visual quality improves over time

Warning signs:

  • Discriminator loss → 0 (too strong)
  • Generator loss → ∞ (gradient vanishing)
  • Mode collapse (all samples look same)
  • Training instability (wild oscillations)
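
Per-step GAN losses are noisy, which can hide the trends listed above. A sliding-window mean is a simple way to smooth the history curves before plotting; the smooth helper below is a hypothetical NumPy sketch, not a Workshop utility.

```python
import numpy as np


def smooth(values, window=50):
    """Sliding-window mean; returns len(values) - window + 1 points.

    Sequences shorter than the window are returned unchanged.
    """
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return values
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")
```

For example, plt.plot(smooth(history["gen_loss"]), label="Generator Loss (smoothed)") can replace the raw curve in the plot above.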

Inception Score (IS)

Measures quality and diversity:

def inception_score(samples, num_splits=10, eps=1e-12):
    """
    Compute Inception Score for generated samples.
    Requires a pre-trained Inception model; inception_model is a placeholder
    that must return class probabilities of shape (N, num_classes).
    """
    # Get predictions from Inception model
    predictions = inception_model(samples)
    split_size = predictions.shape[0] // num_splits

    split_scores = []
    for k in range(num_splits):
        part = predictions[k * split_size:(k + 1) * split_size]
        # KL(p(y|x) || p(y)) per sample; eps guards against log(0)
        py = jnp.mean(part, axis=0, keepdims=True)
        kl = jnp.sum(part * (jnp.log(part + eps) - jnp.log(py + eps)), axis=1)
        split_scores.append(jnp.exp(jnp.mean(kl)))

    split_scores = jnp.array(split_scores)
    return jnp.mean(split_scores), jnp.std(split_scores)

# Compute IS
mean_is, std_is = inception_score(generated_samples)
print(f"Inception Score: {mean_is:.2f} ± {std_is:.2f}")

Higher is better. Values are dataset-dependent: on CIFAR-10, real images score about 11.2 and strong GANs roughly 8-10.
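
The score itself is \(\mathrm{IS} = \exp(\mathbb{E}_x[\mathrm{KL}(p(y|x)\,\|\,p(y))])\). The NumPy sketch below computes it from a matrix of class probabilities, which would normally come from a pre-trained Inception-v3 (not included here). The two toy inputs show the score's range: 1 for uninformative predictions, up to the number of classes for confident predictions spread evenly over all classes.

```python
import numpy as np


def inception_score_from_probs(probs, eps=1e-12):
    """IS = exp(E_x[KL(p(y|x) || p(y))]) for an (N, num_classes) array of probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))


# Uniform predictions carry no information
uniform = np.full((100, 10), 0.1)
# Confident predictions spread evenly over 10 classes
confident = np.eye(10)[np.arange(100) % 10]

print(inception_score_from_probs(uniform))    # about 1.0
print(inception_score_from_probs(confident))  # about 10.0
```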

Fréchet Inception Distance (FID)

Measures similarity to real data:

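import numpy as np
from scipy.linalg import sqrtm

def frechet_inception_distance(real_samples, fake_samples):
    """
    Compute FID between real and generated samples.
    Lower is better. inception_model.get_features is a placeholder for a
    pre-trained Inception network returning pooled features of shape (N, D).
    """
    # Get features from Inception model (as float64 NumPy arrays)
    real_features = np.asarray(inception_model.get_features(real_samples), dtype=np.float64)
    fake_features = np.asarray(inception_model.get_features(fake_samples), dtype=np.float64)

    # Compute statistics
    mu_real = np.mean(real_features, axis=0)
    mu_fake = np.mean(fake_features, axis=0)

    sigma_real = np.cov(real_features, rowvar=False)
    sigma_fake = np.cov(fake_features, rowvar=False)

    # Compute FID; sqrtm can return tiny imaginary parts from numerical error
    diff = mu_real - mu_fake
    covmean = np.real(sqrtm(sigma_real @ sigma_fake))

    fid = np.sum(diff**2) + np.trace(sigma_real + sigma_fake - 2 * covmean)

    return float(fid)

# Compute FID
fid_score = frechet_inception_distance(real_data, generated_samples)
print(f"FID Score: {fid_score:.2f}")

Lower is better. Scores depend on the dataset, resolution, and number of samples (50,000 is a common choice), so compare FIDs only under identical settings; as a rough guide, below 50 is reasonable and below 10 is excellent.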
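
FID fits a Gaussian to each feature set and computes \(\mathrm{FID} = \|\mu_r - \mu_g\|^2 + \mathrm{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})\). Separating this formula from feature extraction makes it testable on Gaussians with known statistics. A NumPy/SciPy sketch; the function name is illustrative.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Squared Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Matrix square root of sigma1 @ sigma2; drop numerical imaginary parts
    covmean = np.real(linalg.sqrtm(sigma1 @ sigma2))
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))


mu, sigma = np.zeros(3), np.eye(3)
print(frechet_distance(mu, sigma, mu, sigma))                             # about 0.0
print(frechet_distance(mu, sigma, mu + np.array([1.0, 2.0, 2.0]), sigma))  # about 9.0 (squared mean shift)
```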

Common Issues and Solutions

Mode Collapse

Symptom: Generator produces limited variety of samples.

Detection:

import numpy as np
from scipy.spatial.distance import pdist

# Check sample diversity
samples = gan.generate(n_samples=100, rngs=rngs)
samples_flat = np.asarray(samples).reshape(samples.shape[0], -1)

# Pairwise distances between generated samples
distances = pdist(samples_flat)

# threshold is dataset-dependent, e.g. a fraction of the mean
# pairwise distance between real samples
if np.mean(distances) < threshold:
    print("Warning: Possible mode collapse detected!")

Solutions:

  1. Use WGAN or LSGAN:
config.loss_type = "wasserstein"  # or "least_squares"
  2. Minibatch discrimination (here, the simpler minibatch standard-deviation statistic):
# Statistic to add to the discriminator's features (e.g. as an extra channel)
def minibatch_stddev(x):
    """Compute standard deviation across batch."""
    batch_std = jnp.std(x, axis=0, keepdims=True)
    return jnp.mean(batch_std)
  3. Add noise to discriminator inputs (instance noise):
# Gradually decay noise; use separate keys for real and fake inputs
noise_std = 0.1 * (1 - epoch / num_epochs)
key_real, key_fake = jax.random.split(key)
noisy_real = real_data + jax.random.normal(key_real, real_data.shape) * noise_std
noisy_fake = fake_data + jax.random.normal(key_fake, fake_data.shape) * noise_std
  4. Use feature matching:
# Match the mean of intermediate discriminator features on real vs. fake batches
def feature_matching_loss(real_features, fake_features):
    return jnp.mean((jnp.mean(real_features, axis=0) -
                     jnp.mean(fake_features, axis=0)) ** 2)
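
The minibatch_stddev helper above only computes the statistic. In the ProGAN-style minibatch standard-deviation layer (Karras et al., 2018), that scalar is broadcast to an extra feature map and concatenated to the discriminator's activations near its final layers, so the discriminator can detect batches of fakes that lack variation. A sketch for NCHW activations; the function name is illustrative, and where to place it in a Workshop discriminator is not covered here.

```python
import jax
import jax.numpy as jnp


def append_minibatch_stddev(x, eps=1e-8):
    """Append the batch-averaged standard deviation as one extra channel.

    x: discriminator activations of shape (batch, channels, height, width)
    """
    # Std of every feature across the batch, averaged into a single scalar
    std = jnp.sqrt(jnp.var(x, axis=0) + eps)
    mean_std = jnp.mean(std)
    # Broadcast the scalar to one (height, width) map per sample and concatenate
    stat_map = jnp.full((x.shape[0], 1, x.shape[2], x.shape[3]), mean_std, dtype=x.dtype)
    return jnp.concatenate([x, stat_map], axis=1)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 16, 4, 4))
out = append_minibatch_stddev(x)  # (8, 17, 4, 4)
```

For a collapsed batch (identical samples) the extra channel is close to zero, which is exactly the signal the discriminator can learn to use.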

Training Instability

Symptom: Losses oscillate wildly, training doesn't converge.

Solutions:

  1. Use spectral normalization:
discriminator = Discriminator(
    hidden_dims=[512, 256, 128],
    use_spectral_norm=True,  # Enable spectral norm
    rngs=rngs,
)
  2. Two-timescale update rule:
# Different learning rates
gen_lr = 0.0001
disc_lr = 0.0004
  3. Gradient penalty (WGAN-GP):
# Use WGAN with gradient penalty (set in the WGAN configuration's gan_params)
wgan_config.metadata["gan_params"]["gradient_penalty_weight"] = 10.0
  4. One-sided label smoothing:
# Smooth only the real labels; keep fake labels at 0
real_labels = jnp.ones((batch_size, 1)) * 0.9  # Instead of 1.0
fake_labels = jnp.zeros((batch_size, 1))
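
One-sided label smoothing (Salimans et al., 2016) softens only the real targets; smoothing the fake targets as well removes the incentive for poor samples to move toward the data. The sketch below assumes a discriminator that returns logits and uses a hand-written binary cross-entropy; the function names are illustrative.

```python
import math

import jax
import jax.numpy as jnp


def bce_with_logits(logits, targets):
    """Binary cross-entropy computed from logits for numerical stability."""
    return -jnp.mean(targets * jax.nn.log_sigmoid(logits)
                     + (1.0 - targets) * jax.nn.log_sigmoid(-logits))


def smoothed_d_loss(real_logits, fake_logits, real_target=0.9):
    """Discriminator loss with one-sided label smoothing (fake targets stay at 0)."""
    real_loss = bce_with_logits(real_logits, jnp.full_like(real_logits, real_target))
    fake_loss = bce_with_logits(fake_logits, jnp.zeros_like(fake_logits))
    return real_loss + fake_loss


# The smoothed real loss is minimized where sigmoid(logit) = 0.9, i.e. logit = log(9),
# so the discriminator gains nothing by pushing real logits further
grad_at_optimum = jax.grad(lambda l: bce_with_logits(l, jnp.array(0.9)))(jnp.array(math.log(9.0)))
```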

Vanishing Gradients

Symptom: Generator loss stops decreasing, samples don't improve.

Solutions:

  1. Use non-saturating loss:
# Instead of: -log(1 - D(G(z)))
# Use: -log(D(G(z)))
gen_loss = -jnp.mean(jnp.log(jnp.clip(fake_scores, 1e-7, 1.0)))
  2. Reduce discriminator capacity:
# Make discriminator weaker
config.discriminator.hidden_dims = [256, 128]  # Smaller than [512, 256]
  3. Update discriminator less frequently:
# Update discriminator every 2 generator updates
if step % 2 == 0:
    disc_loss = train_discriminator(...)
gen_loss = train_generator(...)
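
Why the non-saturating loss in item 1 helps can be checked numerically. Early in training the discriminator rejects fakes confidently, so its logit on a fake is strongly negative. The JAX sketch below, assuming D = sigmoid(logit), compares the gradient of both generator losses with respect to that logit: the original log(1 - D(G(z))) form gives almost no signal, while -log D(G(z)) gives a gradient near -1.

```python
import jax
import jax.numpy as jnp


def saturating_g_loss(logit):
    """Original minimax generator loss log(1 - D(G(z))), to be minimized."""
    return jax.nn.log_sigmoid(-logit)  # log(1 - sigmoid(logit))


def non_saturating_g_loss(logit):
    """Non-saturating generator loss -log D(G(z))."""
    return -jax.nn.log_sigmoid(logit)


# Confidently rejected fake: D(G(z)) = sigmoid(-10), about 4.5e-5
logit = jnp.array(-10.0)
grad_saturating = jax.grad(saturating_g_loss)(logit)          # about -4.5e-5: almost no signal
grad_non_saturating = jax.grad(non_saturating_g_loss)(logit)  # about -1.0: strong signal
```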

Poor Sample Quality

Symptom: Blurry or unrealistic samples.

Solutions:

  1. Use DCGAN architecture:
# Replace MLP with convolutional architecture
from workshop.generative_models.models.gan import DCGAN
gan = DCGAN(dcgan_config, rngs=rngs)  # DCGANConfiguration, as in the DCGAN section
  2. Increase model capacity:
config.generator.hidden_dims = [512, 1024, 2048]  # Larger
  3. Train longer:
num_epochs = 200  # GANs often need many epochs; the right number is dataset-dependent
  4. Better data preprocessing:
# Normalize to [-1, 1] for Tanh
data = (data / 127.5) - 1.0

# Ensure consistent shape
data = jnp.transpose(data, (0, 3, 1, 2))  # NHWC → NCHW
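
The preprocessing in item 4 and its inverse (for saving or plotting samples) can be kept together as a pair of helpers. A NumPy sketch, assuming uint8 images stored as NHWC; the function names are illustrative.

```python
import numpy as np


def to_model_input(images_uint8):
    """uint8 NHWC images in [0, 255] -> float32 NCHW in [-1, 1] (matches a Tanh output)."""
    x = images_uint8.astype(np.float32) / 127.5 - 1.0
    return np.transpose(x, (0, 3, 1, 2))


def to_display(samples):
    """Float NCHW samples in [-1, 1] -> uint8 NHWC in [0, 255] for saving or plotting."""
    x = np.clip((np.asarray(samples) + 1.0) * 127.5, 0.0, 255.0)
    return np.transpose(np.round(x).astype(np.uint8), (0, 2, 3, 1))


images = np.random.randint(0, 256, size=(2, 8, 8, 3), dtype=np.uint8)
model_input = to_model_input(images)  # shape (2, 3, 8, 8), values in [-1, 1]
```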

Best Practices

DO

Use DCGAN guidelines for image generation:

# Strided convolutions, batch norm, LeakyReLU
generator = DCGANGenerator(...)
discriminator = DCGANDiscriminator(...)

Scale data to [-1, 1] for Tanh output:

data = (data / 127.5) - 1.0

Use Adam optimizer with β₁=0.5:

import optax
optimizer = optax.adam(learning_rate=0.0002, b1=0.5, b2=0.999)

Monitor both losses and samples:

if step % 100 == 0:
    visualize_samples(gan.generate(16, rngs=rngs))

Use two-timescale updates (TTUR):

gen_lr = 0.0001
disc_lr = 0.0004

Start with WGAN for stable training:

config.loss_type = "wasserstein"

Save checkpoints regularly:

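import orbax.checkpoint as ocp  # Flax NNX checkpoints are saved with Orbax
checkpointer = ocp.StandardCheckpointer()
if epoch % 10 == 0:
    checkpointer.save(f"{ckpt_dir}/gan_epoch_{epoch}", nnx.state(gan))  # ckpt_dir: hypothetical absolute path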

DON'T

Don't use batch norm in discriminator input:

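# BAD: normalizing the discriminator's input layer (illustrative)
discriminator.layers[0] = BatchNorm(...)

# GOOD: configure normalization when constructing the discriminator
discriminator = Discriminator(hidden_dims=[512, 256], batch_norm=False, rngs=rngs)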

Don't assume the same learning rate is best for G and D:

# Common default (e.g. DCGAN), but often less stable
gen_lr = disc_lr = 0.0002

# Often more stable (TTUR)
gen_lr = 0.0001
disc_lr = 0.0004  # Discriminator learns faster

Don't forget to scale data:

# BAD
data = data / 255.0  # [0, 1] doesn't match Tanh [-1, 1]

# GOOD
data = (data / 127.5) - 1.0  # [-1, 1] matches Tanh

Don't ignore mode collapse warnings:

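# Check diversity regularly: per-pixel std across the batch, averaged
# (0.1 is a heuristic threshold for data scaled to [-1, 1])
if jnp.mean(jnp.std(samples, axis=0)) < 0.1:
    print("Warning: Possible mode collapse!")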

Don't use too small batch sizes:

# BAD
batch_size = 8  # Too small, unstable

# GOOD
batch_size = 64  # Better stability

Summary

This guide covered:

  • Creating GANs: Generators, discriminators, and full GAN models
  • Variants: Vanilla, DCGAN, WGAN, LSGAN, cGAN, CycleGAN, PatchGAN
  • Training: Basic loops, WGAN training, two-timescale updates
  • Generation: Basic sampling, interpolation, conditional generation
  • Evaluation: Visual inspection, IS, FID
  • Troubleshooting: Mode collapse, instability, vanishing gradients
  • Best practices: What to do and what to avoid

Next Steps

  • Theory: See GAN Concepts for mathematical foundations
  • API Reference: Check GAN API Documentation for detailed specifications
  • Example: Follow MNIST GAN Tutorial for hands-on training
  • Advanced: Explore StyleGAN and Progressive GAN for state-of-the-art results