Simple GAN Example - Training a GAN on 2D Data¤
Level: Beginner | Runtime: ~2-3 minutes (CPU/GPU) | Format: Python + Jupyter
This example demonstrates how to use Workshop's GAN components to train a basic Generative Adversarial Network on 2D circular data. It shows how to use Workshop's Generator, Discriminator, and loss functions rather than reimplementing them from scratch.
Files¤
- Python Script: examples/generative_models/image/gan/simple_gan.py
- Jupyter Notebook: examples/generative_models/image/gan/simple_gan.ipynb
Dual-Format Implementation
This example is available in two synchronized formats:
- Python Script (.py) - For version control, IDE development, and CI/CD integration
- Jupyter Notebook (.ipynb) - For interactive learning, experimentation, and exploration
Both formats contain identical content and can be used interchangeably. Choose the format that best suits your workflow.
Quick Start¤
# Activate Workshop environment
source activate.sh
# Run the Python script
python examples/generative_models/image/gan/simple_gan.py
# Or launch Jupyter notebook
jupyter lab examples/generative_models/image/gan/simple_gan.ipynb
Overview¤
Learning Objectives:
- Use Workshop's Generator and Discriminator classes
- Apply Workshop's adversarial loss functions
- Implement alternating generator/discriminator training
- Visualize GAN training progress
- Evaluate generated samples
Prerequisites:
- Basic understanding of neural networks
- Familiarity with JAX and Flax NNX basics
- Understanding of adversarial training concepts
- Workshop installed
Estimated Time: 10-15 minutes
What's Covered¤
- Workshop APIs: Using pre-built Generator, Discriminator, and loss functions from Workshop
- Adversarial Training: Alternating updates between generator and discriminator networks
- Training Loop: 100-step training with loss tracking and visualization
- Visualization: Comparing real and generated 2D data distributions
Expected Results:
- Quick training (~2-3 minutes on CPU/GPU)
- Training loss curves showing GAN convergence
- Visualizations comparing real vs. generated circular data
Theory: Generative Adversarial Networks¤
A GAN consists of two neural networks trained in opposition:
- Generator (G): Creates fake samples from random noise
- Discriminator (D): Distinguishes real samples from fake ones
The two networks play a minimax game:
\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]
Where:
- \(x\) is real data
- \(z\) is random noise (latent vector)
- \(D(x)\) is the discriminator's probability that \(x\) is real
- \(G(z)\) is the generator's output given noise \(z\)
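To make the objective concrete, here is a small illustrative JAX sketch (not Workshop code) that evaluates \(V(D, G)\) from batches of discriminator probabilities; real_probs and fake_probs stand in for \(D(x)\) and \(D(G(z))\):
import jax.numpy as jnp
def minimax_value(real_probs, fake_probs, eps=1e-7):
    # V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]; D maximizes this, G minimizes it
    real_probs = jnp.clip(real_probs, eps, 1 - eps)  # guard against log(0)
    fake_probs = jnp.clip(fake_probs, eps, 1 - eps)
    return jnp.mean(jnp.log(real_probs)) + jnp.mean(jnp.log(1 - fake_probs))
# A confident discriminator pushes V toward its maximum of 0
print(minimax_value(jnp.array([0.9, 0.8]), jnp.array([0.1, 0.2])))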
Step 1: Setup and Imports¤
Import Workshop's GAN components and loss functions:
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax import nnx
from workshop.generative_models.core.losses.adversarial import (
vanilla_discriminator_loss,
vanilla_generator_loss,
)
from workshop.generative_models.models.gan import Discriminator, Generator
print("✅ All libraries imported successfully")
Key Imports:
- Generator, Discriminator: Workshop's configurable MLP networks for GANs
- vanilla_generator_loss, vanilla_discriminator_loss: Workshop's vanilla GAN loss functions
Step 2: Configure Random Number Generators¤
Initialize RNGs for reproducibility at examples/generative_models/image/gan/simple_gan.py:79-91:
# Set random seed for reproducibility
SEED = 42
# Create RNG keys for different components
key = jax.random.key(SEED)
gen_key, disc_key, data_key, sample_key = jax.random.split(key, 4)
# Initialize RNGs for generator and discriminator
gen_rngs = nnx.Rngs(params=gen_key, sample=data_key)
disc_rngs = nnx.Rngs(params=disc_key)
sample_rngs = nnx.Rngs(sample=sample_key)
print(f"✅ RNGs initialized with seed={SEED}")
Step 3: Data Generation Function¤
Create synthetic 2D data following a circular distribution at examples/generative_models/image/gan/simple_gan.py:102-128:
def generate_real_data(key, batch_size=32):
    """Generate 2D data points arranged in a circle.
    Args:
        key: JAX random key
        batch_size: Number of samples to generate
    Returns:
        Array of shape (batch_size, 2) containing 2D points
    """
    # Generate angles uniformly around the circle
    theta_key, noise_key = jax.random.split(key)
    theta = jax.random.uniform(theta_key, (batch_size,)) * 2 * jnp.pi
    # Add small Gaussian noise to radius for variety
    r = 1.0 + jax.random.normal(noise_key, (batch_size,)) * 0.1
    # Convert polar to Cartesian coordinates
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)
    return jnp.stack([x, y], axis=-1)
Why Circular Data?
- Simple 2D distribution is easy to visualize
- Clear success metric: generated points should form a circle
- Fast training for demonstration purposes
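A quick sanity check (a short sketch reusing generate_real_data and the imports from Step 1) confirms the shape and the roughly unit radius of the samples:
demo_batch = generate_real_data(jax.random.key(0), batch_size=8)
print(demo_batch.shape)                          # (8, 2): eight 2D points
radii = jnp.sqrt(jnp.sum(demo_batch**2, axis=1))
print(float(jnp.mean(radii)))                    # close to 1.0 (unit radius plus small noise)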
Step 4: Create Generator Using Workshop's API¤
Use Workshop's Generator class instead of implementing from scratch at examples/generative_models/image/gan/simple_gan.py:150-168:
# Create generator using Workshop's Generator class
generator = Generator(
hidden_dims=[32, 32], # Two hidden layers with 32 neurons each
output_shape=(1, 2), # Output shape: (batch, 2) for 2D points
latent_dim=10, # 10D latent space
activation="relu", # ReLU activation for hidden layers
batch_norm=False, # No batch normalization for simplicity
dropout_rate=0.0, # No dropout
rngs=gen_rngs,
)
Generator Architecture:
- Input: 10D latent vector (random noise)
- Hidden layers: [32, 32] with ReLU activation
- Output: 2D point (shape: (1, 2))
- Final activation: tanh (bounds outputs to [-1, 1])
Workshop Generator Features:
- Configurable MLP architecture
- Optional batch normalization and dropout
- Automatic initialization with RNGs
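As a quick check (a sketch that assumes the generator(z, training=...) call signature used later in this example), push a batch of noise through the freshly built generator and inspect the output:
z_check = jax.random.normal(jax.random.key(123), (4, 10))  # 4 latent vectors, latent_dim=10
x_check = generator(z_check, training=False)
print(x_check.shape)                                # a batch of 2D points (exact shape follows output_shape)
print(float(x_check.min()), float(x_check.max()))   # tanh output, so values lie in [-1, 1]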
Step 5: Create Discriminator Using Workshop's API¤
Use Workshop's Discriminator class for binary classification at examples/generative_models/image/gan/simple_gan.py:185-205:
# Create discriminator using Workshop's Discriminator class
discriminator = Discriminator(
hidden_dims=[32, 32], # Two hidden layers with 32 neurons each
activation="relu", # ReLU activation (can also use leaky_relu)
batch_norm=False, # No batch normalization
dropout_rate=0.0, # No dropout for simplicity
rngs=disc_rngs,
)
Discriminator Architecture:
- Input: 2D point (automatically flattened)
- Hidden layers: [32, 32] with ReLU activation
- Output: 1D probability (after sigmoid)
- Note: Layers initialized lazily on first forward pass
Workshop Discriminator Features:
- Automatic input flattening
- Sigmoid output for probability in [0, 1] range
- Lazy layer initialization based on input shape
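A similar check (again assuming the discriminator(x, training=...) call signature used later in this example) shows the discriminator mapping 2D points to probabilities; the first call also triggers the lazy initialization mentioned above:
x_probe = generate_real_data(jax.random.key(321), batch_size=4)
scores = discriminator(x_probe, training=False)
print(scores.shape)                                 # one score per input point
print(float(scores.min()), float(scores.max()))     # sigmoid output, so values lie in [0, 1]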
Step 6: Define Training Functions Using Workshop's Loss Functions¤
Use Workshop's pre-built loss functions at examples/generative_models/image/gan/simple_gan.py:225-267:
def compute_discriminator_loss(discriminator, generator, real_batch, z):
    """Compute discriminator loss using Workshop's loss function.
    Args:
        discriminator: Discriminator model
        generator: Generator model (fixed during D update)
        real_batch: Real data samples
        z: Latent noise vectors
    Returns:
        Discriminator loss (scalar)
    """
    # Get discriminator scores (probabilities after sigmoid)
    real_scores = discriminator(real_batch, training=True)
    fake_batch = generator(z, training=False)  # Generator not being trained here
    fake_scores = discriminator(fake_batch, training=True)
    # Use Workshop's vanilla discriminator loss
    return vanilla_discriminator_loss(real_scores, fake_scores)
def compute_generator_loss(generator, discriminator, z):
    """Compute generator loss using Workshop's loss function.
    Args:
        generator: Generator model
        discriminator: Discriminator model (fixed during G update)
        z: Latent noise vectors
    Returns:
        Generator loss (scalar)
    """
    # Generate fake samples
    fake_batch = generator(z, training=True)
    # Get discriminator's opinion (probabilities after sigmoid)
    fake_scores = discriminator(fake_batch, training=False)  # Discriminator not being trained
    # Use Workshop's vanilla generator loss
    return vanilla_generator_loss(fake_scores)
Vanilla GAN Loss Functions:
Discriminator loss (from Workshop):
\[\mathcal{L}_D = -\mathbb{E}_{x \sim p_{data}}[\log D(x)] - \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]
Generator loss (from Workshop):
\[\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]\]
Workshop Loss Function Features:
- Expect probabilities (sigmoid outputs), not logits
- Automatic clipping for numerical stability
- Support for different reduction modes (mean, sum, none)
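For intuition, this is roughly what the probability-based vanilla losses compute (a sketch of the formulas above, not Workshop's actual implementation; Workshop's versions also support the reduction modes mentioned):
def sketch_discriminator_loss(real_probs, fake_probs, eps=1e-7):
    # -E[log D(x)] - E[log(1 - D(G(z)))], clipped for numerical stability
    real_probs = jnp.clip(real_probs, eps, 1 - eps)
    fake_probs = jnp.clip(fake_probs, eps, 1 - eps)
    return -jnp.mean(jnp.log(real_probs)) - jnp.mean(jnp.log(1 - fake_probs))
def sketch_generator_loss(fake_probs, eps=1e-7):
    # -E[log D(G(z))]: the generator is rewarded when D labels its samples as real
    fake_probs = jnp.clip(fake_probs, eps, 1 - eps)
    return -jnp.mean(jnp.log(fake_probs))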
Step 7: Create Optimizers¤
Create Adam optimizers for both networks at examples/generative_models/image/gan/simple_gan.py:281-287:
# Create optimizers
learning_rate = 1e-3
gen_optimizer = nnx.Optimizer(generator, optax.adam(learning_rate), wrt=nnx.Param)
disc_optimizer = nnx.Optimizer(discriminator, optax.adam(learning_rate), wrt=nnx.Param)
print(f"✅ Optimizers created with learning rate: {learning_rate}")
Optimizer Configuration
Workshop provides OptimizerConfig for configuration-based optimizer creation in full training pipelines. For this simple example, we create optimizers directly using optax.
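If training proves unstable, one common variation (a sketch using standard optax APIs, not a Workshop-specific recipe) is to chain gradient clipping with Adam and give the discriminator a slightly lower learning rate than the generator:
# Alternative optimizers: clipped gradients and per-network learning rates (illustrative values)
gen_tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
disc_tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(5e-4))
gen_optimizer = nnx.Optimizer(generator, gen_tx, wrt=nnx.Param)
disc_optimizer = nnx.Optimizer(discriminator, disc_tx, wrt=nnx.Param)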
Step 8: Training Loop¤
Train the GAN for 100 steps with alternating updates at examples/generative_models/image/gan/simple_gan.py:305-369:
# Training configuration
num_steps = 100
batch_size = 32
log_interval = 20
print("=" * 60)
print("Training GAN")
print("=" * 60)
# Training history
history = {
"step": [],
"d_loss": [],
"g_loss": [],
}
# Training loop
train_key = jax.random.key(999)
for step in range(num_steps):
    # Generate keys for this step
    train_key, data_key, z_disc_key, z_gen_key = jax.random.split(train_key, 4)
    # Sample real data
    real_batch = generate_real_data(data_key, batch_size=batch_size)
    # Sample latent vectors for discriminator update
    z_disc = jax.random.normal(z_disc_key, (batch_size, 10))
    # ========================================
    # Update Discriminator
    # ========================================
    def disc_loss_wrapper(disc):
        return compute_discriminator_loss(disc, generator, real_batch, z_disc)
    disc_loss, disc_grads = nnx.value_and_grad(disc_loss_wrapper)(discriminator)
    disc_optimizer.update(discriminator, disc_grads)
    # ========================================
    # Update Generator
    # ========================================
    # Sample new latent vectors for generator update
    z_gen = jax.random.normal(z_gen_key, (batch_size, 10))
    def gen_loss_wrapper(gen):
        return compute_generator_loss(gen, discriminator, z_gen)
    gen_loss, gen_grads = nnx.value_and_grad(gen_loss_wrapper)(generator)
    gen_optimizer.update(generator, gen_grads)
    # Store history
    history["step"].append(step)
    history["d_loss"].append(float(disc_loss))
    history["g_loss"].append(float(gen_loss))
    # Log progress
    if step % log_interval == 0 or step == num_steps - 1:
        print(f"Step {step:3d} | D_loss: {disc_loss:.4f} | G_loss: {gen_loss:.4f}")
print("Training complete!")
Training Procedure per Step:
- Sample real data batch
- Sample latent noise for discriminator update
- Update Discriminator: Compute gradients and update parameters
- Sample new latent noise for generator update
- Update Generator: Compute gradients and update parameters
- Log losses every 20 steps
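For longer runs, the two updates can be fused into one compiled step. The sketch below assumes standard Flax NNX behavior (nnx.jit tracing Workshop's models without issue) and reuses the loss wrappers from Step 6 and the optimizers from Step 7:
@nnx.jit
def train_step(generator, discriminator, gen_optimizer, disc_optimizer, real_batch, z_disc, z_gen):
    # Discriminator update (generator held fixed)
    def disc_loss_fn(disc):
        return compute_discriminator_loss(disc, generator, real_batch, z_disc)
    d_loss, d_grads = nnx.value_and_grad(disc_loss_fn)(discriminator)
    disc_optimizer.update(discriminator, d_grads)
    # Generator update (discriminator held fixed)
    def gen_loss_fn(gen):
        return compute_generator_loss(gen, discriminator, z_gen)
    g_loss, g_grads = nnx.value_and_grad(gen_loss_fn)(generator)
    gen_optimizer.update(generator, g_grads)
    return d_loss, g_loss
# Inside the loop: d_loss, g_loss = train_step(generator, discriminator, gen_optimizer, disc_optimizer, real_batch, z_disc, z_gen)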
Output:
============================================================
Training GAN
============================================================
Steps: 100
Batch size: 32
Learning rate: 0.001
------------------------------------------------------------
Step 0 | D_loss: 1.3414 | G_loss: 0.7156
Step 20 | D_loss: 1.3936 | G_loss: 0.6099
Step 40 | D_loss: 1.3762 | G_loss: 0.6786
Step 60 | D_loss: 1.2609 | G_loss: 0.7105
Step 80 | D_loss: 1.1737 | G_loss: 0.8118
Step 99 | D_loss: 1.7007 | G_loss: 0.4157
------------------------------------------------------------
Training complete!
Step 9: Visualize Training Progress¤
Plot discriminator and generator losses:
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# Plot losses
ax.plot(history["step"], history["d_loss"], label="Discriminator Loss", linewidth=2)
ax.plot(history["step"], history["g_loss"], label="Generator Loss", linewidth=2)
ax.set_xlabel("Training Step", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.set_title("GAN Training Progress", fontsize=14, fontweight="bold")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("examples_output/gan_training_curves.png", dpi=150, bbox_inches="tight")

What to Look For:
- Discriminator and generator losses should balance
- Oscillations are normal in adversarial training
- Neither loss should dominate completely
Step 10: Generate Samples Using Trained Generator¤
Generate samples from the trained generator:
# Generate samples
num_viz_samples = 500
final_real = generate_real_data(jax.random.key(5000), batch_size=num_viz_samples)
# Generate fake samples using the generator
z_final = jax.random.normal(jax.random.key(6000), (num_viz_samples, 10))
final_fake = generator(z_final, training=False)
print(f"✅ Generated {num_viz_samples} samples")
Step 11: Visualize Real vs. Generated Data¤
Compare real and generated distributions:
# Create visualization
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
# Plot real data
ax1.scatter(final_real[:, 0], final_real[:, 1], alpha=0.5, s=10, color='blue')
ax1.set_title("Real Data Distribution", fontsize=14, fontweight="bold")
ax1.set_xlim(-2, 2)
ax1.set_ylim(-2, 2)
# Plot generated data
ax2.scatter(final_fake[:, 0], final_fake[:, 1], alpha=0.5, s=10, color='orange')
ax2.set_title("Generated Data Distribution", fontsize=14, fontweight="bold")
ax2.set_xlim(-2, 2)
ax2.set_ylim(-2, 2)
# Overlay comparison
ax3.scatter(final_real[:, 0], final_real[:, 1], alpha=0.4, s=10,
color='blue', label='Real')
ax3.scatter(final_fake[:, 0], final_fake[:, 1], alpha=0.4, s=10,
color='orange', label='Generated')
ax3.set_title("Overlay Comparison", fontsize=14, fontweight="bold")
ax3.legend(fontsize=10)
plt.tight_layout()
fig.savefig("examples_output/simple_gan_results.png", dpi=150, bbox_inches="tight")

Interpretation:
- Left: Real data forms a clean circle
- Middle: Generated data should approximate the circle
- Right: Overlay shows how well the generator matches the real distribution
Step 12: Evaluate Generator Quality¤
Compute statistical metrics to quantify generator performance:
# Generate large sample for statistics
eval_key = jax.random.key(7777)
eval_real = generate_real_data(eval_key, batch_size=1000)
z_eval = jax.random.normal(jax.random.key(8888), (1000, 10))
eval_fake = generator(z_eval, training=False)
# Calculate statistics
real_mean = jnp.mean(eval_real, axis=0)
fake_mean = jnp.mean(eval_fake, axis=0)
real_std = jnp.std(eval_real, axis=0)
fake_std = jnp.std(eval_fake, axis=0)
# Calculate radius (distance from origin)
real_radius = jnp.sqrt(jnp.sum(eval_real**2, axis=1))
fake_radius = jnp.sqrt(jnp.sum(eval_fake**2, axis=1))
print("\n" + "=" * 60)
print("Generator Evaluation")
print("=" * 60)
print(f"{'Metric':<20} {'Real Data':>15} {'Generated':>15} {'Difference':>15}")
print("-" * 65)
print(f"{'Mean X':<20} {real_mean[0]:>15.4f} {fake_mean[0]:>15.4f} {abs(real_mean[0] - fake_mean[0]):>15.4f}")
print(f"{'Mean Y':<20} {real_mean[1]:>15.4f} {fake_mean[1]:>15.4f} {abs(real_mean[1] - fake_mean[1]):>15.4f}")
print(f"{'Std X':<20} {real_std[0]:>15.4f} {fake_std[0]:>15.4f} {abs(real_std[0] - fake_std[0]):>15.4f}")
print(f"{'Std Y':<20} {real_std[1]:>15.4f} {fake_std[1]:>15.4f} {abs(real_std[1] - fake_std[1]):>15.4f}")
print(f"{'Mean Radius':<20} {jnp.mean(real_radius):>15.4f} {jnp.mean(fake_radius):>15.4f} {abs(jnp.mean(real_radius) - jnp.mean(fake_radius)):>15.4f}")
Output:
============================================================
Generator Evaluation
============================================================
Distribution Statistics (1000 samples):
Metric Real Data Generated Difference
-----------------------------------------------------------------
Mean X -0.0227 0.0865 0.1092
Mean Y -0.0261 0.9504 0.9765
Std X 0.7131 0.2550 0.4581
Std Y 0.7096 0.0622 0.6474
Mean Radius 1.0015 0.9861 0.0154
Std Radius 0.1009 0.0852 0.0157
============================================================
What to Look For:
- Mean X/Y should be close to zero (centered circle)
- Std X/Y should be similar (circular, not elliptical)
- Mean Radius should be close to 1.0
- Lower differences indicate better generator quality
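To make these checks repeatable, a small helper (a sketch reusing the arrays computed above; circle_stats is a hypothetical name, not a Workshop API) can bundle the statistics into a dictionary and report the real-versus-generated gaps:
def circle_stats(points):
    # Summary statistics for 2D points expected to lie near a unit circle
    radius = jnp.sqrt(jnp.sum(points**2, axis=1))
    return {
        "mean_x": float(jnp.mean(points[:, 0])),
        "mean_y": float(jnp.mean(points[:, 1])),
        "std_x": float(jnp.std(points[:, 0])),
        "std_y": float(jnp.std(points[:, 1])),
        "mean_radius": float(jnp.mean(radius)),
        "std_radius": float(jnp.std(radius)),
    }
real_stats = circle_stats(eval_real)
fake_stats = circle_stats(eval_fake)
gaps = {k: abs(real_stats[k] - fake_stats[k]) for k in real_stats}
print(gaps)  # smaller gaps mean the generated distribution is closer to the real one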
Summary¤
In this example, you learned:
- ✅ Workshop GAN Components: Using Generator and Discriminator classes from Workshop
- ✅ Workshop Loss Functions: Using vanilla_generator_loss and vanilla_discriminator_loss
- ✅ Training: Alternating updates for discriminator and generator
- ✅ Visualization: Comparing real and generated data distributions
- ✅ Evaluation: Quantifying generator quality with statistical metrics
Key Takeaways:
- Workshop APIs: Use Workshop's pre-built components instead of reimplementing from scratch
- Generator: Configurable MLP with tanh output activation (bounded to [-1, 1])
- Discriminator: Configurable MLP with sigmoid output (probability in [0, 1])
- Adversarial Training: The generator and discriminator are trained in opposition
- Balance is Critical: If one network becomes too strong, training can fail
Experiments to Try¤
- Increase Training Steps: Train for 500-1000 steps to see better convergence
- Adjust Learning Rates: Try different learning rates for G and D
gen_optimizer = nnx.Optimizer(generator, optax.adam(1e-4), wrt=nnx.Param)
disc_optimizer = nnx.Optimizer(discriminator, optax.adam(2e-3), wrt=nnx.Param)
- Modify Architecture: Add more layers or neurons
generator = Generator(
hidden_dims=[64, 64, 32], # Deeper network
output_shape=(1, 2),
latent_dim=10,
activation="relu",
rngs=gen_rngs,
)
- Different Data Distribution: Try a different target distribution (see the sketch after this list)
- Add Batch Normalization: Stabilize training with batch norm
generator = Generator(
hidden_dims=[32, 32],
output_shape=(1, 2),
latent_dim=10,
activation="relu",
batch_norm=True, # Enable batch normalization
rngs=gen_rngs,
)
- Try Different Loss Functions: Use Workshop's other GAN loss variants
from workshop.generative_models.core.losses.adversarial import (
least_squares_generator_loss,
least_squares_discriminator_loss,
)
# Then replace vanilla losses with least_squares losses
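For the Different Data Distribution experiment above, here is one possible drop-in replacement for generate_real_data (a sketch; generate_two_rings is a hypothetical helper, not part of Workshop):
def generate_two_rings(key, batch_size=32):
    # 2D points on two concentric circles (radii 0.5 and 1.0): a harder, multi-modal target
    theta_key, ring_key, noise_key = jax.random.split(key, 3)
    theta = jax.random.uniform(theta_key, (batch_size,)) * 2 * jnp.pi
    ring = jax.random.bernoulli(ring_key, 0.5, (batch_size,))  # pick a ring per sample
    r = jnp.where(ring, 1.0, 0.5) + jax.random.normal(noise_key, (batch_size,)) * 0.05
    return jnp.stack([r * jnp.cos(theta), r * jnp.sin(theta)], axis=-1)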
Troubleshooting¤
Training is unstable¤
- Solution: Reduce learning rates, especially for discriminator
- Explanation: Discriminator can become too strong too quickly
- Fix: Try disc_lr=5e-4 and gen_lr=1e-3
Generated samples don't match real data¤
- Solution: Train for more steps (500-1000)
- Note: 100 steps is minimal for demonstration
- Fix: Increase num_steps and monitor loss curves
Mode collapse (all generated samples are similar)¤
- Solution: Add batch normalization or try different loss function
- Monitoring: Check diversity in generated samples
- Fix: Use batch_norm=True or try the least squares loss
Next Steps¤
After mastering this basic GAN example using Workshop APIs, explore:
Related Examples¤
- VAE Example: Compare GAN with VAE generative approach
- Advanced GANs: DCGAN, Wasserstein GAN, Conditional GAN
- Image GANs: Apply GANs to MNIST and other image datasets
Documentation Resources¤
- GAN Concepts: Deep dive into GAN theory
- GAN User Guide: Advanced usage patterns
- GAN API Reference: Complete API documentation
Papers¤
- Generative Adversarial Networks (Goodfellow et al., 2014)
  - Original GAN paper: https://arxiv.org/abs/1406.2661
- Improved Techniques for Training GANs (Salimans et al., 2016)
  - Training stability: https://arxiv.org/abs/1606.03498
- Wasserstein GAN (Arjovsky et al., 2017)
  - Improved training: https://arxiv.org/abs/1701.07875
Congratulations! You've successfully trained a GAN using Workshop's APIs! 🎉