Simple GAN Example - Training a GAN on 2D Data

Level: Beginner | Runtime: ~2-3 minutes (CPU/GPU) | Format: Python + Jupyter

This example trains a basic Generative Adversarial Network (GAN) on 2D circular data using Workshop's Generator, Discriminator, and adversarial loss functions instead of reimplementing them from scratch.

Files

Dual-Format Implementation

This example is available in two synchronized formats:

  • Python Script (.py) - For version control, IDE development, and CI/CD integration
  • Jupyter Notebook (.ipynb) - For interactive learning, experimentation, and exploration

Both formats contain identical content and can be used interchangeably. Choose the format that best suits your workflow.

Quick Start

# Activate Workshop environment
source activate.sh

# Run the Python script
python examples/generative_models/image/gan/simple_gan.py

# Or launch Jupyter notebook
jupyter lab examples/generative_models/image/gan/simple_gan.ipynb

Overview

Learning Objectives:

  • Use Workshop's Generator and Discriminator classes
  • Apply Workshop's adversarial loss functions
  • Implement alternating generator/discriminator training
  • Visualize GAN training progress
  • Evaluate generated samples

Prerequisites:

  • Basic understanding of neural networks
  • Familiarity with JAX and Flax NNX basics
  • Understanding of adversarial training concepts
  • Workshop installed

Estimated Time: 10-15 minutes

What's Covered

  • Workshop APIs: Using pre-built Generator, Discriminator, and loss functions from Workshop
  • Adversarial Training: Alternating updates between the generator and discriminator networks
  • Training Loop: 100-step training with loss tracking and visualization
  • Visualization: Comparing real and generated 2D data distributions

Expected Results:

  • Quick training (~2-3 minutes on CPU/GPU)
  • Discriminator and generator loss curves (adversarial losses oscillate rather than decreasing steadily)
  • Visualizations comparing real vs. generated circular data

Theory: Generative Adversarial Networks

A GAN consists of two neural networks trained in opposition:

  • Generator (G): Creates fake samples from random noise
  • Discriminator (D): Distinguishes real samples from fake ones

The two networks play a minimax game:

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]

Where:

  • \(x\) is real data
  • \(z\) is random noise (latent vector)
  • \(D(x)\) is the discriminator's probability that \(x\) is real
  • \(G(z)\) is the generator's output given noise \(z\)
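For a fixed generator, the inner maximization has a closed-form solution (Goodfellow et al., 2014), which also tells you what values to expect when the two networks are evenly matched. Writing \(p_g\) for the distribution of \(G(z)\) with \(z \sim p_z\), the objective becomes an integral over \(x\):

\[V(D, G) = \int \big[ p_{data}(x) \log D(x) + p_g(x) \log(1 - D(x)) \big] \, dx\]

For \(a, b > 0\), the function \(a \log y + b \log(1 - y)\) is maximized at \(y = a / (a + b)\), so pointwise maximization gives the optimal discriminator

\[D^*_G(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]

When the generator matches the data (\(p_g = p_{data}\)), \(D^*_G(x) = \tfrac{1}{2}\) everywhere and \(V(D^*_G, G) = \log\tfrac{1}{2} + \log\tfrac{1}{2} = -\log 4 \approx -1.386\).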

Step 1: Setup and Imports

Import Workshop's GAN components and loss functions:

import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax import nnx

from workshop.generative_models.core.losses.adversarial import (
    vanilla_discriminator_loss,
    vanilla_generator_loss,
)
from workshop.generative_models.models.gan import Discriminator, Generator

print("✅ All libraries imported successfully")

Key Imports:

  • Generator, Discriminator: Workshop's configurable MLP networks for GANs
  • vanilla_generator_loss, vanilla_discriminator_loss: Workshop's vanilla GAN loss functions

Step 2: Configure Random Number Generators

Initialize RNGs for reproducibility (see examples/generative_models/image/gan/simple_gan.py, lines 79-91):

# Set random seed for reproducibility
SEED = 42

# Create RNG keys for different components
key = jax.random.key(SEED)
gen_key, disc_key, data_key, sample_key = jax.random.split(key, 4)

# Initialize RNGs for generator and discriminator
gen_rngs = nnx.Rngs(params=gen_key, sample=data_key)
disc_rngs = nnx.Rngs(params=disc_key)
sample_rngs = nnx.Rngs(sample=sample_key)

print(f"✅ RNGs initialized with seed={SEED}")
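JAX random keys are explicit values rather than hidden global state, so the same seed always yields the same split keys. The following standalone check uses only `jax.random` (nothing Workshop-specific) to show that property, and that the four subkeys used above are distinct from each other:

```python
import jax
import jax.numpy as jnp

SEED = 42

# Splitting the same root key twice yields identical subkeys:
# JAX's RNG is a pure function of the key, not of hidden global state.
keys_a = jax.random.split(jax.random.key(SEED), 4)
keys_b = jax.random.split(jax.random.key(SEED), 4)

# Typed keys are compared through their underlying uint32 data.
same = bool(jnp.all(jax.random.key_data(keys_a) == jax.random.key_data(keys_b)))

# The four subkeys differ from each other, so the generator, discriminator,
# data, and sampling streams do not share randomness.
data = jax.random.key_data(keys_a)
distinct = len({tuple(row.tolist()) for row in data}) == 4

print(same, distinct)  # True True
```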

Step 3: Data Generation Function

Create synthetic 2D data following a circular distribution (see examples/generative_models/image/gan/simple_gan.py, lines 102-128):

def generate_real_data(key, batch_size=32):
    """Generate 2D data points arranged in a circle.

    Args:
        key: JAX random key
        batch_size: Number of samples to generate

    Returns:
        Array of shape (batch_size, 2) containing 2D points
    """
    # Generate angles uniformly around the circle
    theta_key, noise_key = jax.random.split(key)
    theta = jax.random.uniform(theta_key, (batch_size,)) * 2 * jnp.pi

    # Add small Gaussian noise to radius for variety
    r = 1.0 + jax.random.normal(noise_key, (batch_size,)) * 0.1

    # Convert polar to Cartesian coordinates
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)

    return jnp.stack([x, y], axis=-1)

Why Circular Data?

  • Simple 2D distribution is easy to visualize
  • Clear success metric: generated points should form a circle
  • Fast training for demonstration purposes
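To make "forms a circle" concrete, the sketch below copies `generate_real_data` from above (so it runs on its own with just JAX) and measures the statistics of a large sample. For a uniform angle, \(\mathbb{E}[\cos^2\theta] = 1/2\), so each coordinate has standard deviation \(\sqrt{\mathbb{E}[r^2]/2} = \sqrt{(1 + 0.1^2)/2} \approx 0.71\); these are the reference values the real-data column in Step 12 should show.

```python
import jax
import jax.numpy as jnp

def generate_real_data(key, batch_size=32):
    """Same sampler as above: angle ~ U[0, 2*pi), radius ~ N(1, 0.1^2)."""
    theta_key, noise_key = jax.random.split(key)
    theta = jax.random.uniform(theta_key, (batch_size,)) * 2 * jnp.pi
    r = 1.0 + jax.random.normal(noise_key, (batch_size,)) * 0.1
    return jnp.stack([r * jnp.cos(theta), r * jnp.sin(theta)], axis=-1)

points = generate_real_data(jax.random.key(0), batch_size=10_000)
radius = jnp.linalg.norm(points, axis=1)  # distance from the origin

stats = {
    "mean_xy": jnp.mean(points, axis=0),   # close to (0, 0): the circle is centered
    "std_xy": jnp.std(points, axis=0),     # close to sqrt((1 + 0.1**2) / 2) ~ 0.71 per axis
    "mean_radius": jnp.mean(radius),       # close to 1.0
    "std_radius": jnp.std(radius),         # close to 0.1
}
for name, value in stats.items():
    print(name, value)
```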

Step 4: Create Generator Using Workshop's API

Create the generator with Workshop's Generator class instead of implementing it from scratch (see examples/generative_models/image/gan/simple_gan.py, lines 150-168):

# Create generator using Workshop's Generator class
generator = Generator(
    hidden_dims=[32, 32],           # Two hidden layers with 32 neurons each
    output_shape=(1, 2),             # Output shape: (batch, 2) for 2D points
    latent_dim=10,                   # 10D latent space
    activation="relu",               # ReLU activation for hidden layers
    batch_norm=False,                # No batch normalization for simplicity
    dropout_rate=0.0,                # No dropout
    rngs=gen_rngs,
)

Generator Architecture:

  • Input: 10D latent vector (random noise)
  • Hidden layers: [32, 32] with ReLU activation
  • Output: 2D point (shape: (1, 2))
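  • Final activation: tanh (bounds each output coordinate to [-1, 1]; real points near the axes with radius above 1, such as (1.05, 0), therefore lie slightly outside the generator's range)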

Workshop Generator Features:

  • Configurable MLP architecture
  • Optional batch normalization and dropout
  • Automatic initialization with RNGs

Step 5: Create Discriminator Using Workshop's API

Use Workshop's Discriminator class as the binary real/fake classifier (see examples/generative_models/image/gan/simple_gan.py, lines 185-205):

# Create discriminator using Workshop's Discriminator class
discriminator = Discriminator(
    hidden_dims=[32, 32],            # Two hidden layers with 32 neurons each
    activation="relu",               # ReLU activation (can also use leaky_relu)
    batch_norm=False,                # No batch normalization
    dropout_rate=0.0,                # No dropout for simplicity
    rngs=disc_rngs,
)

Discriminator Architecture:

  • Input: 2D point (automatically flattened)
  • Hidden layers: [32, 32] with ReLU activation
  • Output: one probability per sample (after sigmoid)
  • Note: Layers are initialized lazily on the first forward pass

Workshop Discriminator Features:

  • Automatic input flattening
  • Sigmoid output for probability in [0, 1] range
  • Lazy layer initialization based on input shape

Step 6: Define Training Functions Using Workshop's Loss Functions

Compute both losses with Workshop's pre-built loss functions (see examples/generative_models/image/gan/simple_gan.py, lines 225-267):

def compute_discriminator_loss(discriminator, generator, real_batch, z):
    """Compute discriminator loss using Workshop's loss function.

    Args:
        discriminator: Discriminator model
        generator: Generator model (fixed during D update)
        real_batch: Real data samples
        z: Latent noise vectors

    Returns:
        Discriminator loss (scalar)
    """
    # Get discriminator scores (probabilities after sigmoid)
    real_scores = discriminator(real_batch, training=True)
    fake_batch = generator(z, training=False)  # Generator not being trained here
    fake_scores = discriminator(fake_batch, training=True)

    # Use Workshop's vanilla discriminator loss
    return vanilla_discriminator_loss(real_scores, fake_scores)


def compute_generator_loss(generator, discriminator, z):
    """Compute generator loss using Workshop's loss function.

    Args:
        generator: Generator model
        discriminator: Discriminator model (fixed during G update)
        z: Latent noise vectors

    Returns:
        Generator loss (scalar)
    """
    # Generate fake samples
    fake_batch = generator(z, training=True)

    # Get discriminator's opinion (probabilities after sigmoid)
    fake_scores = discriminator(fake_batch, training=False)  # Discriminator not being trained

    # Use Workshop's vanilla generator loss
    return vanilla_generator_loss(fake_scores)

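Vanilla GAN Loss Functions:

Discriminator loss (from Workshop):

\[\mathcal{L}_D = -\mathbb{E}_{x \sim p_{data}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]

Generator loss (from Workshop):

\[\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]\]

Minimizing \(\mathcal{L}_D\) is equivalent to maximizing \(V(D, G)\) from the minimax objective. The generator loss is the non-saturating variant proposed by Goodfellow et al. (2014): instead of minimizing \(\mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\), which gives weak gradients early in training when \(D\) easily rejects fake samples, the generator maximizes \(\log D(G(z))\).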

Workshop Loss Function Features:

  • Expect probabilities (sigmoid outputs), not logits
  • Automatic clipping for numerical stability
  • Support for different reduction modes (mean, sum, none)
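To see what these formulas compute, here is a hypothetical reference implementation with mean reduction and clipping. `ref_discriminator_loss`, `ref_generator_loss`, and the clipping bound `EPS` are illustrative assumptions, not Workshop's code. Feeding in a discriminator that outputs 0.5 everywhere gives the values to expect when the two networks are evenly matched.

```python
import jax.numpy as jnp

EPS = 1e-7  # assumed clipping bound; Workshop's actual value may differ

def ref_discriminator_loss(real_probs, fake_probs):
    """-E[log D(x)] - E[log(1 - D(G(z)))] with mean reduction; inputs are probabilities."""
    real_probs = jnp.clip(real_probs, EPS, 1.0 - EPS)
    fake_probs = jnp.clip(fake_probs, EPS, 1.0 - EPS)
    return -jnp.mean(jnp.log(real_probs)) - jnp.mean(jnp.log(1.0 - fake_probs))

def ref_generator_loss(fake_probs):
    """Non-saturating generator loss -E[log D(G(z))] with mean reduction."""
    fake_probs = jnp.clip(fake_probs, EPS, 1.0 - EPS)
    return -jnp.mean(jnp.log(fake_probs))

# A discriminator that cannot tell real from fake outputs 0.5 for every sample
half = jnp.full((32, 1), 0.5)
d_balanced = float(ref_discriminator_loss(half, half))
g_balanced = float(ref_generator_loss(half))
print(d_balanced)  # about 1.3863 (2 ln 2)
print(g_balanced)  # about 0.6931 (ln 2)
```

Under the mean-reduction assumption, the logged losses in Step 8 (D_loss around 1.2 to 1.7, G_loss around 0.4 to 0.8) sit near these balanced levels.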

Step 7: Create Optimizers

Create Adam optimizers for both networks (see examples/generative_models/image/gan/simple_gan.py, lines 281-287):

# Create optimizers
learning_rate = 1e-3

gen_optimizer = nnx.Optimizer(generator, optax.adam(learning_rate), wrt=nnx.Param)
disc_optimizer = nnx.Optimizer(discriminator, optax.adam(learning_rate), wrt=nnx.Param)

print(f"✅ Optimizers created with learning rate: {learning_rate}")

Optimizer Configuration

Workshop provides OptimizerConfig for configuration-based optimizer creation in full training pipelines. For this simple example, we create optimizers directly using optax.


Step 8: Training Loop

Train the GAN for 100 steps with alternating updates (see examples/generative_models/image/gan/simple_gan.py, lines 305-369):

# Training configuration
num_steps = 100
batch_size = 32
log_interval = 20

print("=" * 60)
print("Training GAN")
print("=" * 60)
print(f"Steps: {num_steps}")
print(f"Batch size: {batch_size}")
print(f"Learning rate: {learning_rate}")
print("-" * 60)

# Training history
history = {
    "step": [],
    "d_loss": [],
    "g_loss": [],
}

# Training loop
train_key = jax.random.key(999)

for step in range(num_steps):
    # Generate keys for this step
    train_key, data_key, z_disc_key, z_gen_key = jax.random.split(train_key, 4)

    # Sample real data
    real_batch = generate_real_data(data_key, batch_size=batch_size)

    # Sample latent vectors for discriminator update
    z_disc = jax.random.normal(z_disc_key, (batch_size, 10))

    # ========================================
    # Update Discriminator
    # ========================================
    def disc_loss_wrapper(disc):
        return compute_discriminator_loss(disc, generator, real_batch, z_disc)

    # Gradients are taken only w.r.t. the first argument (the discriminator)
    disc_loss, disc_grads = nnx.value_and_grad(disc_loss_wrapper)(discriminator)
    disc_optimizer.update(discriminator, disc_grads)

    # ========================================
    # Update Generator
    # ========================================
    # Sample new latent vectors for generator update
    z_gen = jax.random.normal(z_gen_key, (batch_size, 10))

    def gen_loss_wrapper(gen):
        return compute_generator_loss(gen, discriminator, z_gen)

    # Gradients are taken only w.r.t. the generator; the discriminator stays fixed
    gen_loss, gen_grads = nnx.value_and_grad(gen_loss_wrapper)(generator)
    gen_optimizer.update(generator, gen_grads)

    # Store history
    history["step"].append(step)
    history["d_loss"].append(float(disc_loss))
    history["g_loss"].append(float(gen_loss))

    # Log progress
    if step % log_interval == 0 or step == num_steps - 1:
        print(f"Step {step:3d} | D_loss: {disc_loss:.4f} | G_loss: {gen_loss:.4f}")

print("-" * 60)
print("Training complete!")

Training Procedure per Step:

  1. Sample real data batch
  2. Sample latent noise for discriminator update
  3. Update Discriminator: Compute gradients and update parameters
  4. Sample new latent noise for generator update
  5. Update Generator: Compute gradients and update parameters
  6. Log losses every 20 steps and at the final step

Output:

============================================================
Training GAN
============================================================
Steps: 100
Batch size: 32
Learning rate: 0.001
------------------------------------------------------------
Step   0 | D_loss: 1.3414 | G_loss: 0.7156
Step  20 | D_loss: 1.3936 | G_loss: 0.6099
Step  40 | D_loss: 1.3762 | G_loss: 0.6786
Step  60 | D_loss: 1.2609 | G_loss: 0.7105
Step  80 | D_loss: 1.1737 | G_loss: 0.8118
Step  99 | D_loss: 1.7007 | G_loss: 0.4157
------------------------------------------------------------
Training complete!

Step 9: Visualize Training Progress

Plot discriminator and generator losses:

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# Plot losses
ax.plot(history["step"], history["d_loss"], label="Discriminator Loss", linewidth=2)
ax.plot(history["step"], history["g_loss"], label="Generator Loss", linewidth=2)

ax.set_xlabel("Training Step", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.set_title("GAN Training Progress", fontsize=14, fontweight="bold")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

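plt.tight_layout()
# Create the output directory first; savefig does not create missing folders
os.makedirs("examples_output", exist_ok=True)
plt.savefig("examples_output/gan_training_curves.png", dpi=150, bbox_inches="tight")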

GAN Training Curves

What to Look For:

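  • Losses near their balanced levels: a discriminator that cannot tell real from fake outputs about 0.5, which (assuming mean-reduced vanilla losses) gives D_loss ≈ 2 ln 2 ≈ 1.39 and G_loss ≈ ln 2 ≈ 0.69; the values logged in Step 8 fluctuate around these levels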
  • Oscillations are normal in adversarial training
  • Neither loss should dominate completely
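Because adversarial losses oscillate, a moving average often makes the trend easier to read. A small NumPy sketch; `moving_average` and the two reference constants are hypothetical helpers, and the reference levels assume mean-reduced vanilla losses with a discriminator that outputs 0.5:

```python
import numpy as np

# Reference levels for an evenly matched discriminator (outputs 0.5)
D_LOSS_BALANCED = 2 * np.log(2)  # about 1.386
G_LOSS_BALANCED = np.log(2)      # about 0.693

def moving_average(values, window=10):
    """Mean over each run of `window` consecutive entries ("valid" mode,
    so the result has len(values) - window + 1 entries)."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

smoothed = moving_average([1.0, 2.0, 3.0, 4.0], window=2)
print(smoothed)  # [1.5 2.5 3.5]
```

With the `history` dict from Step 8, you could then add `ax.plot(history["step"][9:], moving_average(history["d_loss"], 10))` and `ax.axhline(D_LOSS_BALANCED, linestyle="--")` before saving the figure.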

Step 10: Generate Samples Using Trained Generator

Generate samples from the trained generator:

# Generate samples
num_viz_samples = 500

final_real = generate_real_data(jax.random.key(5000), batch_size=num_viz_samples)

# Generate fake samples using the generator
z_final = jax.random.normal(jax.random.key(6000), (num_viz_samples, 10))
final_fake = generator(z_final, training=False)

print(f"✅ Generated {num_viz_samples} samples")

Step 11: Visualize Real vs. Generated Data

Compare real and generated distributions:

# Create visualization
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

# Plot real data
ax1.scatter(final_real[:, 0], final_real[:, 1], alpha=0.5, s=10, color='blue')
ax1.set_title("Real Data Distribution", fontsize=14, fontweight="bold")

# Plot generated data
ax2.scatter(final_fake[:, 0], final_fake[:, 1], alpha=0.5, s=10, color='orange')
ax2.set_title("Generated Data Distribution", fontsize=14, fontweight="bold")

# Overlay comparison
ax3.scatter(final_real[:, 0], final_real[:, 1], alpha=0.4, s=10,
            color='blue', label='Real')
ax3.scatter(final_fake[:, 0], final_fake[:, 1], alpha=0.4, s=10,
            color='orange', label='Generated')
ax3.set_title("Overlay Comparison", fontsize=14, fontweight="bold")
ax3.legend(fontsize=10)

# Same limits and equal aspect on every panel, so circles are not drawn as ellipses
for ax in (ax1, ax2, ax3):
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_aspect("equal")

plt.tight_layout()
os.makedirs("examples_output", exist_ok=True)  # savefig does not create missing folders
fig.savefig("examples_output/simple_gan_results.png", dpi=150, bbox_inches="tight")

GAN Results

Interpretation:

  • Left: Real data forms a noisy ring of radius ≈ 1 (radial noise with standard deviation 0.1)
  • Middle: Generated data should approximate the circle
  • Right: Overlay shows how well the generator matches the real distribution

Step 12: Evaluate Generator Quality

Compute statistical metrics to quantify generator performance:

# Generate large sample for statistics
eval_key = jax.random.key(7777)
eval_real = generate_real_data(eval_key, batch_size=1000)
z_eval = jax.random.normal(jax.random.key(8888), (1000, 10))
eval_fake = generator(z_eval, training=False)

# Calculate statistics
real_mean = jnp.mean(eval_real, axis=0)
fake_mean = jnp.mean(eval_fake, axis=0)

real_std = jnp.std(eval_real, axis=0)
fake_std = jnp.std(eval_fake, axis=0)

# Calculate radius (distance from origin)
real_radius = jnp.sqrt(jnp.sum(eval_real**2, axis=1))
fake_radius = jnp.sqrt(jnp.sum(eval_fake**2, axis=1))

print("\n" + "=" * 60)
print("Generator Evaluation")
print("=" * 60)
print("\nDistribution Statistics (1000 samples):")
print(f"{'Metric':<20} {'Real Data':>15} {'Generated':>15} {'Difference':>15}")
print("-" * 65)
print(f"{'Mean X':<20} {real_mean[0]:>15.4f} {fake_mean[0]:>15.4f} {abs(real_mean[0] - fake_mean[0]):>15.4f}")
print(f"{'Mean Y':<20} {real_mean[1]:>15.4f} {fake_mean[1]:>15.4f} {abs(real_mean[1] - fake_mean[1]):>15.4f}")
print(f"{'Std X':<20} {real_std[0]:>15.4f} {fake_std[0]:>15.4f} {abs(real_std[0] - fake_std[0]):>15.4f}")
print(f"{'Std Y':<20} {real_std[1]:>15.4f} {fake_std[1]:>15.4f} {abs(real_std[1] - fake_std[1]):>15.4f}")
print(f"{'Mean Radius':<20} {jnp.mean(real_radius):>15.4f} {jnp.mean(fake_radius):>15.4f} {abs(jnp.mean(real_radius) - jnp.mean(fake_radius)):>15.4f}")
print(f"{'Std Radius':<20} {jnp.std(real_radius):>15.4f} {jnp.std(fake_radius):>15.4f} {abs(jnp.std(real_radius) - jnp.std(fake_radius)):>15.4f}")
print("=" * 60)

Output:

============================================================
Generator Evaluation
============================================================

Distribution Statistics (1000 samples):
Metric                     Real Data       Generated      Difference
-----------------------------------------------------------------
Mean X                       -0.0227          0.0865          0.1092
Mean Y                       -0.0261          0.9504          0.9765
Std X                         0.7131          0.2550          0.4581
Std Y                         0.7096          0.0622          0.6474
Mean Radius                   1.0015          0.9861          0.0154
Std Radius                    0.1009          0.0852          0.0157
============================================================

What to Look For:

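  • Mean X/Y should be close to zero (centered circle)
  • Std X/Y should be similar to each other (circular, not elliptical) and close to the real data's ≈ 0.71: for points spread uniformly around a circle of radius ≈ 1, each coordinate has standard deviation ≈ 1/√2
  • Mean Radius should be close to 1.0, but it is not sufficient on its own: a generator that places every sample on one short arc of the circle matches it perfectly. In the run above, the radius matches (0.99 vs. 1.00) while Mean Y ≈ 0.95 and Std Y ≈ 0.06 show that the generated points are concentrated on an arc near the top of the circle, a partial mode collapse after only 100 steps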
  • Lower differences indicate better generator quality
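One way to catch the "right radius, wrong coverage" failure described above is to measure how much of the circle the samples actually occupy. The sketch below is a hypothetical metric, not part of Workshop: it counts the fraction of equal-width angle bins around the origin that contain at least one sample.

```python
import jax.numpy as jnp

def angular_coverage(points, num_bins=16):
    """Hypothetical diversity metric: fraction of angle bins (around the origin)
    containing at least one sample. Full circle -> 1.0; one short arc -> about 1/num_bins."""
    angles = jnp.arctan2(points[:, 1], points[:, 0])  # in [-pi, pi]
    bins = jnp.floor((angles + jnp.pi) / (2 * jnp.pi) * num_bins).astype(jnp.int32)
    bins = jnp.clip(bins, 0, num_bins - 1)            # an angle of exactly pi goes to the last bin
    counts = jnp.bincount(bins, length=num_bins)
    return float(jnp.mean(counts > 0))

# Points spread around the whole unit circle vs. points confined to a narrow arc
theta_full = jnp.linspace(0.0, 2 * jnp.pi, 1000, endpoint=False)
full = jnp.stack([jnp.cos(theta_full), jnp.sin(theta_full)], axis=-1)
theta_arc = jnp.linspace(1.3, 1.4, 1000)
arc = jnp.stack([jnp.cos(theta_arc), jnp.sin(theta_arc)], axis=-1)

print(angular_coverage(full))  # 1.0
print(angular_coverage(arc))   # 0.0625 (1 of 16 bins)
```

Assuming `eval_fake` has shape (N, 2), comparing `angular_coverage(eval_real)` with `angular_coverage(eval_fake)` gives a single number that the mean radius alone cannot provide.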

Summary

In this example, you learned:

  • Workshop GAN Components: Using Generator and Discriminator classes from Workshop
  • Workshop Loss Functions: Using vanilla_generator_loss and vanilla_discriminator_loss
  • Training: Alternating updates for discriminator and generator
  • Visualization: Comparing real and generated data distributions
  • Evaluation: Quantifying generator quality with statistical metrics

Key Takeaways:

  1. Workshop APIs: Use Workshop's pre-built components instead of reimplementing from scratch
  2. Generator: Configurable MLP with tanh output activation (bounded to [-1, 1])
  3. Discriminator: Configurable MLP with sigmoid output (probability in [0, 1])
  4. Adversarial Training: The generator and discriminator are trained in opposition
  5. Balance is Critical: If one network becomes too strong, training can fail

Experiments to Try

  1. Increase Training Steps: Train for 500-1000 steps to see better convergence
num_steps = 500
  2. Adjust Learning Rates: Try different learning rates for G and D
gen_optimizer = nnx.Optimizer(generator, optax.adam(1e-4), wrt=nnx.Param)
disc_optimizer = nnx.Optimizer(discriminator, optax.adam(2e-3), wrt=nnx.Param)
  3. Modify Architecture: Add more layers or neurons, then re-create gen_optimizer so its state matches the new generator's parameters
generator = Generator(
    hidden_dims=[64, 64, 32],  # Deeper network
    output_shape=(1, 2),
    latent_dim=10,
    activation="relu",
    rngs=gen_rngs,
)
  4. Different Data Distribution: Try a different target distribution
# In generate_real_data, create a different shape
# Example: two circles, a line, etc.
  5. Add Batch Normalization: Stabilize training with batch norm (again, re-create gen_optimizer afterwards)
generator = Generator(
    hidden_dims=[32, 32],
    output_shape=(1, 2),
    latent_dim=10,
    activation="relu",
    batch_norm=True,  # Enable batch normalization
    rngs=gen_rngs,
)
  6. Try Different Loss Functions: Use Workshop's other GAN loss variants
from workshop.generative_models.core.losses.adversarial import (
    least_squares_generator_loss,
    least_squares_discriminator_loss,
)
# Then replace vanilla losses with least_squares losses
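For the "Different Data Distribution" experiment, a two-ring target is a natural next step, because a generator that covers only one ring is an easy-to-spot form of mode collapse. The sketch below is a hypothetical drop-in replacement for `generate_real_data` (same `key, batch_size` arguments, same (batch_size, 2) output); the radii are kept inside [-1, 1] so a tanh-bounded generator can reach both rings.

```python
import jax
import jax.numpy as jnp

def generate_two_rings(key, batch_size=32, radii=(0.4, 0.9), noise=0.05):
    """Hypothetical alternative target: each point lies on one of two concentric
    rings, chosen with equal probability, with Gaussian noise on the radius."""
    ring_key, theta_key, noise_key = jax.random.split(key, 3)
    # Choose a ring per sample: False -> inner radius, True -> outer radius
    on_outer = jax.random.bernoulli(ring_key, 0.5, (batch_size,))
    r = jnp.where(on_outer, radii[1], radii[0])
    r = r + noise * jax.random.normal(noise_key, (batch_size,))
    theta = jax.random.uniform(theta_key, (batch_size,)) * 2 * jnp.pi
    return jnp.stack([r * jnp.cos(theta), r * jnp.sin(theta)], axis=-1)

points = generate_two_rings(jax.random.key(0), batch_size=1000)
print(points.shape)  # (1000, 2)
```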

Troubleshooting

Training is unstable

  • Solution: Reduce learning rates, especially for the discriminator
  • Explanation: The discriminator can become too strong too quickly
  • Fix: For example, build disc_optimizer with optax.adam(5e-4) while keeping the generator at 1e-3

Generated samples don't match real data

  • Solution: Train for more steps (500-1000)
  • Note: 100 steps is minimal for demonstration
  • Fix: Increase num_steps and monitor loss curves

Mode collapse (all generated samples are similar)

  • Solution: Add batch normalization or try different loss function
  • Monitoring: Check diversity in generated samples
  • Fix: Use batch_norm=True or try least squares loss
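To monitor diversity numerically rather than by eye, one simple (hypothetical, not Workshop-provided) indicator compares the average pairwise distance of generated samples with that of real samples. A ratio near 1.0 means the generated points are about as spread out as the data; a ratio near 0 means they have collapsed toward a single point.

```python
import jax
import jax.numpy as jnp

def mean_pairwise_distance(points):
    """Average Euclidean distance between distinct points (O(N^2) memory; fine for ~1000 points)."""
    diffs = points[:, None, :] - points[None, :, :]
    dists = jnp.sqrt(jnp.sum(diffs**2, axis=-1))
    n = points.shape[0]
    return jnp.sum(dists) / (n * (n - 1))  # diagonal self-distances are zero

def diversity_ratio(generated, real):
    """Hypothetical mode-collapse indicator: about 1.0 for matching spread, near 0.0 for collapse."""
    return float(mean_pairwise_distance(generated) / mean_pairwise_distance(real))

# Samples spread around the unit circle vs. samples collapsed around the point (0, 1)
theta = jax.random.uniform(jax.random.key(0), (500,)) * 2 * jnp.pi
real = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
collapsed = jnp.array([0.0, 1.0]) + 0.01 * jax.random.normal(jax.random.key(1), (500, 2))

print(diversity_ratio(real, real))       # 1.0
print(diversity_ratio(collapsed, real))  # much smaller than 1
```

Note that this indicator would not flag the partial collapse onto an arc as strongly as a full collapse; it complements, rather than replaces, looking at the overlay plot.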

Next Steps

After mastering this basic GAN example using Workshop APIs, explore:

  • VAE Example: Compare GAN with VAE generative approach
  • Advanced GANs: DCGAN, Wasserstein GAN, Conditional GAN
  • Image GANs: Apply GANs to MNIST and other image datasets

Documentation Resources

Papers

  1. Generative Adversarial Networks (Goodfellow et al., 2014). Original GAN paper: https://arxiv.org/abs/1406.2661
  2. Improved Techniques for Training GANs (Salimans et al., 2016). Training stability: https://arxiv.org/abs/1606.03498
  3. Wasserstein GAN (Arjovsky et al., 2017). Improved training: https://arxiv.org/abs/1701.07875

Congratulations! You've successfully trained a GAN using Workshop's APIs! 🎉