
GAN Trainer¤

Module: artifex.generative_models.training.trainers.gan_trainer

The GAN Trainer provides specialized training utilities for Generative Adversarial Networks, including multiple loss variants, gradient penalty regularization, and R1 regularization for stable training.

Overview¤

GAN training requires carefully balancing the generator and discriminator. The GAN Trainer handles this through:

  • Multiple Loss Types: Vanilla, Wasserstein, Hinge, and Least Squares GAN
  • Gradient Penalty: WGAN-GP regularization for stable Wasserstein training
  • R1 Regularization: Gradient penalty on real data for improved stability
  • Label Smoothing: One-sided smoothing to prevent overconfidence

Quick Start¤

from artifex.generative_models.training.trainers import (
    GANTrainer,
    GANTrainingConfig,
)
from flax import nnx
import optax
import jax

# Create models and optimizers
generator = create_generator(rngs=nnx.Rngs(0))
discriminator = create_discriminator(rngs=nnx.Rngs(1))

g_optimizer = nnx.Optimizer(generator, optax.adam(1e-4, b1=0.5), wrt=nnx.Param)
d_optimizer = nnx.Optimizer(discriminator, optax.adam(1e-4, b1=0.5), wrt=nnx.Param)

# Configure GAN training
config = GANTrainingConfig(
    loss_type="wasserstein",
    n_critic=5,
    gp_weight=10.0,
)

trainer = GANTrainer(config)

# Training loop
key = jax.random.key(0)
latent_dim = 128

for step, batch in enumerate(train_loader):
    key, d_key, g_key, z_key = jax.random.split(key, 4)
    real = batch["image"]
    z = jax.random.normal(z_key, (real.shape[0], latent_dim))

    # Train discriminator
    d_loss, d_metrics = trainer.discriminator_step(
        generator, discriminator, d_optimizer, real, z, d_key
    )

    # Train generator every n_critic steps
    if step % config.n_critic == 0:
        z = jax.random.normal(g_key, (real.shape[0], latent_dim))
        g_loss, g_metrics = trainer.generator_step(
            generator, discriminator, g_optimizer, z
        )

Configuration¤

artifex.generative_models.training.trainers.gan_trainer.GANTrainingConfig dataclass ¤

GANTrainingConfig(
    loss_type: Literal[
        "vanilla", "wasserstein", "hinge", "lsgan"
    ] = "vanilla",
    n_critic: int = 1,
    gp_weight: float = 10.0,
    gp_target: float = 1.0,
    r1_weight: float = 0.0,
    label_smoothing: float = 0.0,
)

Configuration for GAN-specific training.

Attributes:

  • loss_type (Literal['vanilla', 'wasserstein', 'hinge', 'lsgan']): GAN loss variant.
      - "vanilla": Standard GAN loss (BCE)
      - "wasserstein": Wasserstein distance (requires gradient penalty)
      - "hinge": Hinge loss (used in BigGAN, StyleGAN2)
      - "lsgan": Least squares GAN
  • n_critic (int): Discriminator updates per generator update.
  • gp_weight (float): Gradient penalty weight (for WGAN-GP).
  • gp_target (float): Target gradient norm (usually 1.0).
  • r1_weight (float): R1 regularization weight.
  • label_smoothing (float): Smooth real labels to [1 - smoothing, 1].

loss_type class-attribute instance-attribute ¤

loss_type: Literal[
    "vanilla", "wasserstein", "hinge", "lsgan"
] = "vanilla"

n_critic class-attribute instance-attribute ¤

n_critic: int = 1

gp_weight class-attribute instance-attribute ¤

gp_weight: float = 10.0

gp_target class-attribute instance-attribute ¤

gp_target: float = 1.0

r1_weight class-attribute instance-attribute ¤

r1_weight: float = 0.0

label_smoothing class-attribute instance-attribute ¤

label_smoothing: float = 0.0

Configuration Options¤

Parameter        Type   Default    Description
loss_type        str    "vanilla"  Loss variant: "vanilla", "wasserstein", "hinge", "lsgan"
n_critic         int    1          Discriminator updates per generator update
gp_weight        float  10.0       Gradient penalty weight (WGAN-GP)
gp_target        float  1.0        Target gradient norm for GP
r1_weight        float  0.0        R1 regularization weight
label_smoothing  float  0.0        One-sided label smoothing for real labels

Loss Types¤

Vanilla GAN (Non-Saturating)¤

The standard GAN (BCE) objective with the non-saturating generator loss:

config = GANTrainingConfig(loss_type="vanilla")
# Generator uses the non-saturating -log D(G(z)) objective, computed with softplus for numerical stability
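
Concretely, with \(D(\cdot)\) producing logits and \(\operatorname{softplus}(t) = \log(1 + e^{t})\), the discriminator and non-saturating generator objectives are:

\[\mathcal{L}_D = \mathbb{E}_{x \sim p_{data}}[\operatorname{softplus}(-D(x))] + \mathbb{E}_{z}[\operatorname{softplus}(D(G(z)))], \qquad \mathcal{L}_G = \mathbb{E}_{z}[\operatorname{softplus}(-D(G(z)))]\]

Minimizing \(\mathcal{L}_G\) maximizes \(\log \sigma(D(G(z)))\), which gives the generator stronger gradients early in training than minimizing \(\log(1 - \sigma(D(G(z))))\).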

Wasserstein GAN¤

Earth Mover's distance with gradient penalty:

config = GANTrainingConfig(
    loss_type="wasserstein",
    gp_weight=10.0,  # Required for WGAN-GP
    n_critic=5,      # More D updates per G update
)
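
The critic and generator objectives being approximated are (with \(\lambda\) corresponding to gp_weight and \(\text{GP}\) the gradient penalty described under Regularization Techniques):

\[\mathcal{L}_D = \mathbb{E}_{z}[D(G(z))] - \mathbb{E}_{x \sim p_{data}}[D(x)] + \lambda \, \text{GP}, \qquad \mathcal{L}_G = -\mathbb{E}_{z}[D(G(z))]\]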

Hinge Loss¤

Hinge loss used in BigGAN and StyleGAN2:

config = GANTrainingConfig(loss_type="hinge")
# D loss: relu(1 - D(real)) + relu(1 + D(fake))
# G loss: -D(fake)

Least Squares GAN¤

Mean squared error between predictions and targets:

config = GANTrainingConfig(loss_type="lsgan")
# More stable gradients than vanilla GAN
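
Assuming the common 0/1 target coding (real samples labeled 1, fake samples 0, generator target 1), the objectives are:

\[\mathcal{L}_D = \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{data}}[(D(x) - 1)^2] + \tfrac{1}{2}\,\mathbb{E}_{z}[D(G(z))^2], \qquad \mathcal{L}_G = \tfrac{1}{2}\,\mathbb{E}_{z}[(D(G(z)) - 1)^2]\]

The exact target coding used by the library's least-squares losses may differ, but the quadratic form is what gives LSGAN its smoother gradients.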

Regularization Techniques¤

Gradient Penalty (WGAN-GP)¤

Enforces 1-Lipschitz constraint via gradient penalty on interpolated samples:

config = GANTrainingConfig(
    loss_type="wasserstein",
    gp_weight=10.0,
    gp_target=1.0,  # Target gradient norm
)

The gradient penalty is computed as:

\[\lambda \mathbb{E}_{\hat{x}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]\]

where \(\hat{x}\) is interpolated between real and fake samples.
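
As a rough JAX sketch of this computation (assuming d is a callable that scores each sample in a batch independently; the trainer's compute_gradient_penalty is the actual implementation and may differ in detail):

import jax
import jax.numpy as jnp

def gradient_penalty(d, real, fake, key, target=1.0):
    # Random per-sample interpolation between real and fake samples
    eps = jax.random.uniform(key, (real.shape[0],) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake

    # Gradient of the summed critic output w.r.t. the interpolates; because the
    # critic scores samples independently, this yields per-sample gradients
    grads = jax.grad(lambda x: d(x).sum())(x_hat)

    # Penalize deviation of each per-sample gradient norm from the target
    norms = jnp.sqrt(jnp.sum(grads.reshape(grads.shape[0], -1) ** 2, axis=-1) + 1e-12)
    return jnp.mean((norms - target) ** 2)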

R1 Regularization¤

Gradient penalty on real data only, used in StyleGAN:

config = GANTrainingConfig(
    loss_type="hinge",
    r1_weight=10.0,  # R1 penalty weight
)

R1 penalty is computed as:

\[\frac{\gamma}{2} \mathbb{E}_{x \sim p_{data}}[||\nabla_x D(x)||_2^2]\]
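
A minimal JAX sketch of this penalty (again assuming a per-sample discriminator callable d; compute_r1_penalty is the trainer's own implementation):

import jax
import jax.numpy as jnp

def r1_penalty(d, real, gamma=10.0):
    # Gradient of the summed discriminator output w.r.t. the real images
    grads = jax.grad(lambda x: d(x).sum())(real)

    # Mean squared per-sample gradient norm, scaled by gamma / 2
    # (r1_weight plays the role of gamma)
    sq_norms = jnp.sum(grads.reshape(grads.shape[0], -1) ** 2, axis=-1)
    return 0.5 * gamma * jnp.mean(sq_norms)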

Label Smoothing¤

One-sided label smoothing to prevent discriminator overconfidence:

config = GANTrainingConfig(
    loss_type="vanilla",
    label_smoothing=0.1,  # Real labels: 0.9 instead of 1.0
)
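
As an illustration of what "one-sided" means for the real-sample term of the BCE loss (this helper is a sketch, not the trainer's internal code):

import jax.numpy as jnp
import optax

def real_term_with_smoothing(d_real_logits, smoothing=0.1):
    # One-sided smoothing: targets for real samples become 1 - smoothing
    # (e.g. 0.9); targets for fake samples stay at 0 and are not smoothed
    targets = jnp.full_like(d_real_logits, 1.0 - smoothing)
    return optax.sigmoid_binary_cross_entropy(d_real_logits, targets).mean()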

API Reference¤

artifex.generative_models.training.trainers.gan_trainer.GANTrainer ¤

GANTrainer(config: GANTrainingConfig | None = None)

GAN-specific trainer with multiple loss variants.

This trainer provides a JIT-compatible interface for adversarial training with support for multiple loss functions and regularization techniques. The step methods take models and optimizers as explicit arguments, allowing them to be wrapped with nnx.jit for performance.

Features
  • Multiple loss types (vanilla, wasserstein, hinge, lsgan)
  • Configurable discriminator/generator update ratio
  • WGAN-GP gradient penalty
  • R1 regularization for discriminator
  • Label smoothing

Example (non-JIT):

from artifex.generative_models.training.trainers import (
    GANTrainer,
    GANTrainingConfig,
)

config = GANTrainingConfig(
    loss_type="wasserstein",
    n_critic=5,
    gp_weight=10.0,
)
trainer = GANTrainer(config)

# Create models and optimizers separately
generator = Generator(rngs=nnx.Rngs(0))
discriminator = Discriminator(rngs=nnx.Rngs(1))
g_optimizer = nnx.Optimizer(generator, optax.adam(1e-4), wrt=nnx.Param)
d_optimizer = nnx.Optimizer(discriminator, optax.adam(1e-4), wrt=nnx.Param)

# Training loop
rng = jax.random.key(0)
for step in range(num_steps):
    rng, z_key, d_key = jax.random.split(rng, 3)
    z = jax.random.normal(z_key, (batch_size, latent_dim))
    d_loss, d_metrics = trainer.discriminator_step(
        generator, discriminator, d_optimizer, real_batch, z, d_key
    )
    if step % config.n_critic == 0:
        g_loss, g_metrics = trainer.generator_step(
            generator, discriminator, g_optimizer, z
        )

Example (JIT-compiled):

trainer = GANTrainer(config)
jit_d_step = nnx.jit(trainer.discriminator_step)
jit_g_step = nnx.jit(trainer.generator_step)

for step in range(num_steps):
    rng, z_key, d_key = jax.random.split(rng, 3)
    z = jax.random.normal(z_key, (batch_size, latent_dim))
    d_loss, d_metrics = jit_d_step(
        generator, discriminator, d_optimizer, real_batch, z, d_key
    )
    if step % config.n_critic == 0:
        g_loss, g_metrics = jit_g_step(
            generator, discriminator, g_optimizer, z
        )

Parameters:

  • config (GANTrainingConfig | None): GAN training configuration. Default: None.

config instance-attribute ¤

config = config or GANTrainingConfig()

compute_d_loss_vanilla ¤

compute_d_loss_vanilla(
    d_real: Array, d_fake: Array
) -> Array

Compute vanilla GAN discriminator loss.

Uses non-saturating loss from core/losses for numerical stability.

Parameters:

  • d_real (Array): Discriminator output for real samples (logits). Required.
  • d_fake (Array): Discriminator output for fake samples (logits). Required.

Returns:

  • Array: Discriminator loss.

compute_d_loss_wasserstein ¤

compute_d_loss_wasserstein(
    d_real: Array, d_fake: Array
) -> Array

Compute Wasserstein discriminator loss.

Uses wasserstein_discriminator_loss from core/losses.

Parameters:

  • d_real (Array): Discriminator output for real samples. Required.
  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Discriminator loss (negative critic loss).

compute_d_loss_hinge ¤

compute_d_loss_hinge(d_real: Array, d_fake: Array) -> Array

Compute hinge discriminator loss.

Uses hinge_discriminator_loss from core/losses.

Parameters:

  • d_real (Array): Discriminator output for real samples. Required.
  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Discriminator loss.

compute_d_loss_lsgan ¤

compute_d_loss_lsgan(d_real: Array, d_fake: Array) -> Array

Compute least squares GAN discriminator loss.

Uses least_squares_discriminator_loss from core/losses.

Parameters:

  • d_real (Array): Discriminator output for real samples. Required.
  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Discriminator loss.

compute_discriminator_loss ¤

compute_discriminator_loss(
    d_real: Array, d_fake: Array
) -> Array

Compute discriminator loss based on configured loss type.

Parameters:

  • d_real (Array): Discriminator output for real samples. Required.
  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Discriminator loss.

compute_g_loss_vanilla ¤

compute_g_loss_vanilla(d_fake: Array) -> Array

Compute vanilla GAN generator loss.

Uses ns_vanilla_generator_loss from core/losses.

Parameters:

  • d_fake (Array): Discriminator output for fake samples (logits). Required.

Returns:

  • Array: Generator loss.

compute_g_loss_wasserstein ¤

compute_g_loss_wasserstein(d_fake: Array) -> Array

Compute Wasserstein generator loss.

Uses wasserstein_generator_loss from core/losses.

Parameters:

  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Generator loss.

compute_g_loss_hinge ¤

compute_g_loss_hinge(d_fake: Array) -> Array

Compute hinge generator loss.

Uses hinge_generator_loss from core/losses.

Parameters:

  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Generator loss.

compute_g_loss_lsgan ¤

compute_g_loss_lsgan(d_fake: Array) -> Array

Compute least squares GAN generator loss.

Uses least_squares_generator_loss from core/losses.

Parameters:

  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Generator loss.

compute_generator_loss ¤

compute_generator_loss(d_fake: Array) -> Array

Compute generator loss based on configured loss type.

Parameters:

  • d_fake (Array): Discriminator output for fake samples. Required.

Returns:

  • Array: Generator loss.

compute_gradient_penalty ¤

compute_gradient_penalty(
    discriminator: Module,
    real: Array,
    fake: Array,
    key: Array,
) -> Array

Compute WGAN-GP gradient penalty.

Parameters:

  • discriminator (Module): Discriminator model. Required.
  • real (Array): Real samples. Required.
  • fake (Array): Fake samples (must have same shape as real). Required.
  • key (Array): PRNG key for interpolation. Required.

Returns:

  • Array: Gradient penalty loss.

compute_r1_penalty ¤

compute_r1_penalty(
    discriminator: Module, real: Array
) -> Array

Compute R1 regularization penalty.

Parameters:

  • discriminator (Module): Discriminator model. Required.
  • real (Array): Real samples. Required.

Returns:

  • Array: R1 penalty.

discriminator_step ¤

discriminator_step(
    generator: Module,
    discriminator: Module,
    d_optimizer: Optimizer,
    real: Array,
    z: Array,
    key: Array,
) -> tuple[Array, dict[str, Any]]

Execute a discriminator training step.

This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.discriminator_step)
loss, metrics = jit_step(generator, discriminator, d_optimizer, real, z, key)

Parameters:

  • generator (Module): Generator model. Required.
  • discriminator (Module): Discriminator model. Required.
  • d_optimizer (Optimizer): Optimizer for discriminator. Required.
  • real (Array): Real samples. Required.
  • z (Array): Latent vectors for generator. Required.
  • key (Array): PRNG key for gradient penalty. Required.

Returns:

  • tuple[Array, dict[str, Any]]: Tuple of (loss, metrics_dict).

generator_step ¤

generator_step(
    generator: Module,
    discriminator: Module,
    g_optimizer: Optimizer,
    z: Array,
) -> tuple[Array, dict[str, Any]]

Execute a generator training step.

This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.generator_step)
loss, metrics = jit_step(generator, discriminator, g_optimizer, z)

Parameters:

  • generator (Module): Generator model. Required.
  • discriminator (Module): Discriminator model. Required.
  • g_optimizer (Optimizer): Optimizer for generator. Required.
  • z (Array): Latent vectors for generator. Required.

Returns:

  • tuple[Array, dict[str, Any]]: Tuple of (loss, metrics_dict).

Training Patterns¤

Standard GAN Training¤

for step, batch in enumerate(train_loader):
    key, z_key, d_key = jax.random.split(key, 3)
    real = batch["image"]
    z = jax.random.normal(z_key, (batch_size, latent_dim))

    # Train discriminator
    d_loss, d_metrics = trainer.discriminator_step(
        generator, discriminator, d_optimizer, real, z, d_key
    )

    # Train generator (every step for vanilla/hinge/lsgan)
    g_loss, g_metrics = trainer.generator_step(
        generator, discriminator, g_optimizer, z
    )

WGAN Training (Multiple D Updates)¤

for step, batch in enumerate(train_loader):
    key, g_key, *d_keys = jax.random.split(key, config.n_critic + 2)
    real = batch["image"]

    # Multiple discriminator updates, each with fresh latents
    for d_key in d_keys:
        z_key, step_key = jax.random.split(d_key)
        z = jax.random.normal(z_key, (batch_size, latent_dim))
        d_loss, d_metrics = trainer.discriminator_step(
            generator, discriminator, d_optimizer, real, z, step_key
        )

    # Single generator update
    z = jax.random.normal(g_key, (batch_size, latent_dim))
    g_loss, g_metrics = trainer.generator_step(
        generator, discriminator, g_optimizer, z
    )

Progressive Training¤

For high-resolution generation, progressively grow resolution:

resolutions = [4, 8, 16, 32, 64, 128]

for resolution in resolutions:
    # Update model for this resolution
    generator.grow_layer()
    discriminator.grow_layer()

    # Train at this resolution
    for step in range(steps_per_resolution):
        ...  # training step at this resolution (see patterns above)

Model Requirements¤

Generator Interface¤

class Generator(nnx.Module):
    def __call__(self, z: jax.Array) -> jax.Array:
        """Generate images from latent vectors.

        Args:
            z: Latent vectors, shape (batch, latent_dim).

        Returns:
            Generated images, shape (batch, H, W, C).
        """
        ...

Discriminator Interface¤

class Discriminator(nnx.Module):
    def __call__(self, x: jax.Array) -> jax.Array:
        """Classify real/fake images.

        Args:
            x: Images, shape (batch, H, W, C).

        Returns:
            Logits (unbounded scores), shape (batch,) or (batch, 1).
        """
        ...
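
For reference, a minimal illustrative pair that satisfies both interfaces; the MLP layers and 28x28 image shape are placeholders, not part of the library:

from flax import nnx
import jax

class MLPGenerator(nnx.Module):
    """Toy generator: latent vector -> 28x28 grayscale image."""

    def __init__(self, latent_dim: int, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(latent_dim, 256, rngs=rngs)
        self.fc2 = nnx.Linear(256, 28 * 28, rngs=rngs)

    def __call__(self, z: jax.Array) -> jax.Array:
        h = nnx.relu(self.fc1(z))
        x = nnx.tanh(self.fc2(h))
        return x.reshape(-1, 28, 28, 1)

class MLPDiscriminator(nnx.Module):
    """Toy discriminator: image -> one unbounded logit per sample."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(28 * 28, 256, rngs=rngs)
        self.fc2 = nnx.Linear(256, 1, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        h = nnx.relu(self.fc1(x.reshape(x.shape[0], -1)))
        return self.fc2(h)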

Training Metrics¤

Discriminator Metrics¤

Metric        Description
d_loss        Base discriminator loss
d_loss_total  Total loss including regularization
d_real        Mean discriminator output on real samples
d_fake        Mean discriminator output on fake samples
gp_loss       Gradient penalty loss (if enabled)
r1_loss       R1 regularization loss (if enabled)

Generator Metrics¤

Metric    Description
g_loss    Generator loss
d_fake_g  Mean discriminator output on generated samples
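
Both step methods return a (loss, metrics_dict) tuple; assuming the metric names above are the keys of that dict, logging is plain dictionary access:

d_loss, d_metrics = trainer.discriminator_step(
    generator, discriminator, d_optimizer, real, z, d_key
)
if step % 100 == 0:
    print(
        f"step {step}: d_loss={float(d_loss):.4f} "
        f"d_real={float(d_metrics['d_real']):.3f} "
        f"d_fake={float(d_metrics['d_fake']):.3f}"
    )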

Loss Functions¤

The GAN Trainer uses loss functions from artifex.generative_models.core.losses.adversarial:

from artifex.generative_models.core.losses import (
    # Vanilla GAN (non-saturating)
    ns_vanilla_generator_loss,
    ns_vanilla_discriminator_loss,
    # Wasserstein
    wasserstein_generator_loss,
    wasserstein_discriminator_loss,
    # Hinge
    hinge_generator_loss,
    hinge_discriminator_loss,
    # Least Squares
    least_squares_generator_loss,
    least_squares_discriminator_loss,
)
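
If you only need loss values (for logging or evaluation) rather than full update steps, the trainer's documented compute_* helpers dispatch to these functions and can be called directly on raw discriminator outputs:

from artifex.generative_models.training.trainers import (
    GANTrainer,
    GANTrainingConfig,
)
import jax.numpy as jnp

trainer = GANTrainer(GANTrainingConfig(loss_type="hinge"))

# Raw discriminator outputs (logits) for a toy batch
d_real = jnp.array([1.5, 0.8, 2.1])
d_fake = jnp.array([-0.7, 0.2, -1.3])

d_loss = trainer.compute_discriminator_loss(d_real, d_fake)  # hinge variant
g_loss = trainer.compute_generator_loss(d_fake)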

References¤