GAN Trainer¤
Module: artifex.generative_models.training.trainers.gan_trainer
The GAN Trainer provides specialized training utilities for Generative Adversarial Networks, including multiple loss variants, gradient penalty regularization, and R1 regularization for stable training.
Overview¤
GAN training requires careful balancing between generator and discriminator. The GAN Trainer handles this through:
- Multiple Loss Types: Vanilla, Wasserstein, Hinge, and Least Squares GAN
- Gradient Penalty: WGAN-GP regularization for stable Wasserstein training
- R1 Regularization: Gradient penalty on real data for improved stability
- Label Smoothing: One-sided smoothing to prevent overconfidence
Quick Start¤
from artifex.generative_models.training.trainers import (
    GANTrainer,
    GANTrainingConfig,
)
from flax import nnx
import optax
import jax
# Create models and optimizers
generator = create_generator(rngs=nnx.Rngs(0))
discriminator = create_discriminator(rngs=nnx.Rngs(1))
g_optimizer = nnx.Optimizer(generator, optax.adam(1e-4, b1=0.5), wrt=nnx.Param)
d_optimizer = nnx.Optimizer(discriminator, optax.adam(1e-4, b1=0.5), wrt=nnx.Param)

# Configure GAN training
config = GANTrainingConfig(
    loss_type="wasserstein",
    n_critic=5,
    gp_weight=10.0,
)
trainer = GANTrainer(config)

# Training loop
key = jax.random.key(0)
latent_dim = 128
for step, batch in enumerate(train_loader):
    key, d_key, g_key, z_key = jax.random.split(key, 4)
    real = batch["image"]
    z = jax.random.normal(z_key, (real.shape[0], latent_dim))

    # Train discriminator
    d_loss, d_metrics = trainer.discriminator_step(
        generator, discriminator, d_optimizer, real, z, d_key
    )

    # Train generator every n_critic steps
    if step % config.n_critic == 0:
        z = jax.random.normal(g_key, (real.shape[0], latent_dim))
        g_loss, g_metrics = trainer.generator_step(
            generator, discriminator, g_optimizer, z
        )
Configuration¤
artifex.generative_models.training.trainers.gan_trainer.GANTrainingConfig
dataclass
¤
GANTrainingConfig(
    loss_type: Literal[
        "vanilla", "wasserstein", "hinge", "lsgan"
    ] = "vanilla",
    n_critic: int = 1,
    gp_weight: float = 10.0,
    gp_target: float = 1.0,
    r1_weight: float = 0.0,
    label_smoothing: float = 0.0,
)
Configuration for GAN-specific training.
Attributes:

| Name | Type | Description |
|---|---|---|
| `loss_type` | `Literal['vanilla', 'wasserstein', 'hinge', 'lsgan']` | GAN loss variant: `"vanilla"` = standard GAN loss (BCE), `"wasserstein"` = Wasserstein distance (requires gradient penalty), `"hinge"` = hinge loss (used in BigGAN, StyleGAN2), `"lsgan"` = least squares GAN. |
| `n_critic` | `int` | Discriminator updates per generator update. |
| `gp_weight` | `float` | Gradient penalty weight (for WGAN-GP). |
| `gp_target` | `float` | Target gradient norm (usually 1.0). |
| `r1_weight` | `float` | R1 regularization weight. |
| `label_smoothing` | `float` | Smooth real labels to `[1 - smoothing, 1]`. |
Configuration Options¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| `loss_type` | `str` | `"vanilla"` | Loss variant: `"vanilla"`, `"wasserstein"`, `"hinge"`, `"lsgan"` |
| `n_critic` | `int` | `1` | Discriminator updates per generator update |
| `gp_weight` | `float` | `10.0` | Gradient penalty weight (WGAN-GP) |
| `gp_target` | `float` | `1.0` | Target gradient norm for GP |
| `r1_weight` | `float` | `0.0` | R1 regularization weight |
| `label_smoothing` | `float` | `0.0` | One-sided label smoothing for real labels |
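Several options can be combined in a single config. The values below are illustrative, not recommendations:

config = GANTrainingConfig(
    loss_type="hinge",
    n_critic=2,       # two discriminator updates per generator update
    r1_weight=10.0,   # enable R1 regularization on real samples
)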
Loss Types¤
Vanilla GAN (Non-Saturating)¤
Standard GAN with non-saturating generator loss for numerical stability:
config = GANTrainingConfig(loss_type="vanilla")
# Generator uses the non-saturating -log(sigmoid(D(fake))) form, computed via softplus for stability
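In logit form these losses amount to the following. This is a sketch of the math, not the library's ns_vanilla_* implementations:

import jax.numpy as jnp

def vanilla_d_loss(d_real, d_fake):
    # -log(sigmoid(d_real)) - log(1 - sigmoid(d_fake)), written with softplus (logaddexp)
    return jnp.mean(jnp.logaddexp(0.0, -d_real) + jnp.logaddexp(0.0, d_fake))

def vanilla_g_loss(d_fake):
    # Non-saturating generator loss: -log(sigmoid(d_fake))
    return jnp.mean(jnp.logaddexp(0.0, -d_fake))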
Wasserstein GAN¤
Earth Mover's distance with gradient penalty:
config = GANTrainingConfig(
    loss_type="wasserstein",
    gp_weight=10.0,  # Required for WGAN-GP
    n_critic=5,      # More D updates per G update
)
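Conceptually, the critic maximizes D(real) - D(fake) and the generator maximizes D(fake). Written as minimization objectives, a sketch (not the library's wasserstein_* implementations) looks like:

import jax.numpy as jnp

def wasserstein_d_loss(d_real, d_fake):
    # Negated critic objective: minimize E[D(fake)] - E[D(real)]
    return jnp.mean(d_fake) - jnp.mean(d_real)

def wasserstein_g_loss(d_fake):
    return -jnp.mean(d_fake)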
Hinge Loss¤
Hinge loss used in BigGAN and StyleGAN2:
config = GANTrainingConfig(loss_type="hinge")
# D loss: relu(1 - D(real)) + relu(1 + D(fake))
# G loss: -D(fake)
Least Squares GAN¤
Mean squared error between predictions and targets:
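Following the pattern of the other variants, a minimal configuration looks like this. The comments describe the usual LSGAN targets (real → 1, fake → 0) and may differ in detail from the library's least_squares_* implementations:

config = GANTrainingConfig(loss_type="lsgan")
# D loss: MSE(D(real), 1) + MSE(D(fake), 0)
# G loss: MSE(D(fake), 1)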
Regularization Techniques¤
Gradient Penalty (WGAN-GP)¤
Enforces 1-Lipschitz constraint via gradient penalty on interpolated samples:
config = GANTrainingConfig(
    loss_type="wasserstein",
    gp_weight=10.0,
    gp_target=1.0,  # Target gradient norm
)
The gradient penalty is computed as:

$$
\mathcal{L}_{\text{GP}} = \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - \gamma\right)^2\right]
$$

where \(\hat{x}\) is interpolated between real and fake samples, \(\gamma\) is `gp_target`, and the penalty is scaled by `gp_weight`.
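A minimal standalone sketch of this computation, assuming the discriminator scores each sample independently (e.g. no batch statistics). It is not the trainer's compute_gradient_penalty implementation:

import jax
import jax.numpy as jnp

def gradient_penalty(discriminator, real, fake, key, gp_target=1.0):
    # Per-example interpolation coefficient in [0, 1], broadcast over feature dims
    eps = jax.random.uniform(key, (real.shape[0],) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake

    # Gradient of the summed critic output w.r.t. the interpolated inputs
    grads = jax.grad(lambda x: discriminator(x).sum())(x_hat)

    # Per-sample gradient norm, penalized toward gp_target
    grad_norm = jnp.sqrt(jnp.sum(grads**2, axis=tuple(range(1, grads.ndim))) + 1e-12)
    return jnp.mean((grad_norm - gp_target) ** 2)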
R1 Regularization¤
Gradient penalty on real data only, as used in StyleGAN. The R1 penalty is computed as:

$$
R_1 = \mathbb{E}_{x \sim p_{\text{data}}}\left[\lVert \nabla_x D(x) \rVert_2^2\right]
$$

and is scaled by `r1_weight`.
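A minimal configuration enabling R1 (the weight shown is illustrative; StyleGAN-style setups often use values around 10):

config = GANTrainingConfig(
    loss_type="vanilla",
    r1_weight=10.0,
)

The penalty itself amounts to the following sketch, under the same per-sample independence assumption as the gradient-penalty example above:

import jax
import jax.numpy as jnp

def r1_penalty(discriminator, real):
    grads = jax.grad(lambda x: discriminator(x).sum())(real)
    return jnp.mean(jnp.sum(grads**2, axis=tuple(range(1, grads.ndim))))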
Label Smoothing¤
One-sided label smoothing to prevent discriminator overconfidence:
config = GANTrainingConfig(
    loss_type="vanilla",
    label_smoothing=0.1,  # Real labels: 0.9 instead of 1.0
)
API Reference¤
artifex.generative_models.training.trainers.gan_trainer.GANTrainer
¤
GANTrainer(config: GANTrainingConfig | None = None)
GAN-specific trainer with multiple loss variants.
This trainer provides a JIT-compatible interface for adversarial training with support for multiple loss functions and regularization techniques. The step methods take models and optimizers as explicit arguments, allowing them to be wrapped with nnx.jit for performance.
Features
- Multiple loss types (vanilla, wasserstein, hinge, lsgan)
- Configurable discriminator/generator update ratio
- WGAN-GP gradient penalty
- R1 regularization for discriminator
- Label smoothing
Example (non-JIT):
from artifex.generative_models.training.trainers import (
    GANTrainer,
    GANTrainingConfig,
)

config = GANTrainingConfig(
    loss_type="wasserstein",
    n_critic=5,
    gp_weight=10.0,
)
trainer = GANTrainer(config)

# Create models and optimizers separately
generator = Generator(rngs=nnx.Rngs(0))
discriminator = Discriminator(rngs=nnx.Rngs(1))
g_optimizer = nnx.Optimizer(generator, optax.adam(1e-4))
d_optimizer = nnx.Optimizer(discriminator, optax.adam(1e-4))

# Training loop
for step in range(num_steps):
    rng, d_key, g_key = jax.random.split(rng, 3)
    d_loss, d_metrics = trainer.discriminator_step(
        generator, discriminator, d_optimizer, real_batch, z, d_key
    )
    if step % config.n_critic == 0:
        g_loss, g_metrics = trainer.generator_step(
            generator, discriminator, g_optimizer, z
        )
Example (JIT-compiled):
trainer = GANTrainer(config)
jit_d_step = nnx.jit(trainer.discriminator_step)
jit_g_step = nnx.jit(trainer.generator_step)

for step in range(num_steps):
    d_loss, d_metrics = jit_d_step(
        generator, discriminator, d_optimizer, real_batch, z, d_key
    )
    if step % config.n_critic == 0:
        g_loss, g_metrics = jit_g_step(
            generator, discriminator, g_optimizer, z
        )
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `GANTrainingConfig \| None` | GAN training configuration. | `None` |
compute_d_loss_vanilla
¤
Compute vanilla GAN discriminator loss.
Uses non-saturating loss from core/losses for numerical stability.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `d_real` | `Array` | Discriminator output for real samples (logits). | *required* |
| `d_fake` | `Array` | Discriminator output for fake samples (logits). | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Discriminator loss. |
compute_d_loss_wasserstein
¤
Compute Wasserstein discriminator loss.
Uses wasserstein_discriminator_loss from core/losses.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `d_real` | `Array` | Discriminator output for real samples. | *required* |
| `d_fake` | `Array` | Discriminator output for fake samples. | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Discriminator loss (negative critic loss). |
compute_d_loss_hinge
¤
Compute hinge GAN discriminator loss.
compute_d_loss_lsgan
¤
Compute least squares GAN discriminator loss.
Uses least_squares_discriminator_loss from core/losses.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `d_real` | `Array` | Discriminator output for real samples. | *required* |
| `d_fake` | `Array` | Discriminator output for fake samples. | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Discriminator loss. |
compute_discriminator_loss
¤
Compute the discriminator loss for the configured loss_type.
compute_g_loss_vanilla
¤
Compute vanilla (non-saturating) GAN generator loss.
compute_g_loss_wasserstein
¤
Compute Wasserstein generator loss.
compute_g_loss_hinge
¤
Compute hinge GAN generator loss.
compute_g_loss_lsgan
¤
Compute least squares GAN generator loss.
compute_generator_loss
¤
Compute the generator loss for the configured loss_type.
compute_gradient_penalty
¤
Compute WGAN-GP gradient penalty.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `discriminator` | `Module` | Discriminator model. | *required* |
| `real` | `Array` | Real samples. | *required* |
| `fake` | `Array` | Fake samples (must have same shape as real). | *required* |
| `key` | `Array` | PRNG key for interpolation. | *required* |

Returns:

| Type | Description |
|---|---|
| `Array` | Gradient penalty loss. |
compute_r1_penalty
¤
Compute the R1 gradient penalty on real samples.
discriminator_step
¤
discriminator_step(
    generator: Module,
    discriminator: Module,
    d_optimizer: Optimizer,
    real: Array,
    z: Array,
    key: Array,
) -> tuple[Array, dict[str, Any]]
Execute a discriminator training step.
This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.discriminator_step)
loss, metrics = jit_step(generator, discriminator, d_optimizer, real, z, key)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `generator` | `Module` | Generator model. | *required* |
| `discriminator` | `Module` | Discriminator model. | *required* |
| `d_optimizer` | `Optimizer` | Optimizer for discriminator. | *required* |
| `real` | `Array` | Real samples. | *required* |
| `z` | `Array` | Latent vectors for generator. | *required* |
| `key` | `Array` | PRNG key for gradient penalty. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of (loss, metrics_dict). |
generator_step
¤
generator_step(
    generator: Module,
    discriminator: Module,
    g_optimizer: Optimizer,
    z: Array,
) -> tuple[Array, dict[str, Any]]
Execute a generator training step.
This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.generator_step)
loss, metrics = jit_step(generator, discriminator, g_optimizer, z)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `generator` | `Module` | Generator model. | *required* |
| `discriminator` | `Module` | Discriminator model. | *required* |
| `g_optimizer` | `Optimizer` | Optimizer for generator. | *required* |
| `z` | `Array` | Latent vectors for generator. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple[Array, dict[str, Any]]` | Tuple of (loss, metrics_dict). |
Training Patterns¤
Standard GAN Training¤
for step, batch in enumerate(train_loader):
    key, z_key, d_key = jax.random.split(key, 3)
    real = batch["image"]
    z = jax.random.normal(z_key, (batch_size, latent_dim))

    # Train discriminator
    d_loss, d_metrics = trainer.discriminator_step(
        generator, discriminator, d_optimizer, real, z, d_key
    )

    # Train generator (every step for vanilla/hinge/lsgan)
    g_loss, g_metrics = trainer.generator_step(
        generator, discriminator, g_optimizer, z
    )
WGAN Training (Multiple D Updates)¤
for step, batch in enumerate(train_loader):
    key, g_key = jax.random.split(key)
    real = batch["image"]

    # Multiple discriminator updates, each with fresh noise and a fresh PRNG key
    for _ in range(config.n_critic):
        key, z_key, d_key = jax.random.split(key, 3)
        z = jax.random.normal(z_key, (batch_size, latent_dim))
        d_loss, d_metrics = trainer.discriminator_step(
            generator, discriminator, d_optimizer, real, z, d_key
        )

    # Single generator update
    z = jax.random.normal(g_key, (batch_size, latent_dim))
    g_loss, g_metrics = trainer.generator_step(
        generator, discriminator, g_optimizer, z
    )
Progressive Training¤
For high-resolution generation, progressively grow resolution:
resolutions = [4, 8, 16, 32, 64, 128]

for resolution in resolutions:
    # Update model for this resolution
    generator.grow_layer()
    discriminator.grow_layer()

    # Train at this resolution
    for step in range(steps_per_resolution):
        # ... training step ...
        ...
Model Requirements¤
Generator Interface¤
class Generator(nnx.Module):
    def __call__(self, z: jax.Array) -> jax.Array:
        """Generate images from latent vectors.

        Args:
            z: Latent vectors, shape (batch, latent_dim).

        Returns:
            Generated images, shape (batch, H, W, C).
        """
        ...
Discriminator Interface¤
class Discriminator(nnx.Module):
    def __call__(self, x: jax.Array) -> jax.Array:
        """Classify real/fake images.

        Args:
            x: Images, shape (batch, H, W, C).

        Returns:
            Logits (unbounded scores), shape (batch,) or (batch, 1).
        """
        ...
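For illustration, a minimal pair of toy models satisfying these interfaces might look like the following. The class names, layer sizes, and image shape are made up for this sketch:

from flax import nnx
import jax

class TinyGenerator(nnx.Module):
    def __init__(self, latent_dim: int, rngs: nnx.Rngs):
        self.dense = nnx.Linear(latent_dim, 28 * 28, rngs=rngs)

    def __call__(self, z: jax.Array) -> jax.Array:
        x = jax.nn.tanh(self.dense(z))
        return x.reshape(z.shape[0], 28, 28, 1)


class TinyDiscriminator(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.dense = nnx.Linear(28 * 28, 1, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Return unbounded logits of shape (batch,)
        return self.dense(x.reshape(x.shape[0], -1)).squeeze(-1)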
Training Metrics¤
Discriminator Metrics¤
| Metric | Description |
|---|---|
| `d_loss` | Base discriminator loss |
| `d_loss_total` | Total loss including regularization |
| `d_real` | Mean discriminator output on real samples |
| `d_fake` | Mean discriminator output on fake samples |
| `gp_loss` | Gradient penalty loss (if enabled) |
| `r1_loss` | R1 regularization loss (if enabled) |
Generator Metrics¤
| Metric | Description |
|---|---|
| `g_loss` | Generator loss |
| `d_fake_g` | Mean discriminator output on generated samples |
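These metrics come back in the dictionaries returned by discriminator_step and generator_step. A simple way to log a few of them inside the Quick Start loop:

print(
    f"step {step}: "
    f"d_loss={float(d_metrics['d_loss_total']):.4f} "
    f"d_real={float(d_metrics['d_real']):.3f} "
    f"d_fake={float(d_metrics['d_fake']):.3f} "
    f"g_loss={float(g_metrics['g_loss']):.4f}"
)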
Loss Functions¤
The GAN Trainer uses loss functions from artifex.generative_models.core.losses.adversarial:
from artifex.generative_models.core.losses import (
    # Vanilla GAN (non-saturating)
    ns_vanilla_generator_loss,
    ns_vanilla_discriminator_loss,
    # Wasserstein
    wasserstein_generator_loss,
    wasserstein_discriminator_loss,
    # Hinge
    hinge_generator_loss,
    hinge_discriminator_loss,
    # Least Squares
    least_squares_generator_loss,
    least_squares_discriminator_loss,
)
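Assuming these functions follow the same (d_real, d_fake) logit convention as the trainer's compute_* methods, a direct call might look like this. The logit values are dummies for illustration:

import jax.numpy as jnp

d_real = jnp.array([1.2, 0.8])    # critic scores on real samples
d_fake = jnp.array([-0.5, -1.1])  # critic scores on generated samples

d_loss = hinge_discriminator_loss(d_real, d_fake)
g_loss = hinge_generator_loss(d_fake)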