GAN API Reference¤

Complete API reference for all GAN model classes in Workshop.

Overview¤

The GAN module provides implementations of various Generative Adversarial Network architectures:

  • Base GAN: Standard generator and discriminator
  • DCGAN: Deep convolutional architecture
  • WGAN: Wasserstein distance with gradient penalty
  • LSGAN: Least squares loss
  • Conditional GAN: Class-conditioned generation
  • CycleGAN: Unpaired image-to-image translation
  • PatchGAN: Patch-based discrimination
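
All of these classes are exposed from the same module. A minimal import sketch (see the per-class sections below for constructors and configuration objects):

from workshop.generative_models.models.gan import (
    GAN,
    Generator,
    Discriminator,
    DCGAN,
    WGAN,
    LSGAN,
    ConditionalGAN,
    CycleGAN,
)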

Base Classes¤

Generator¤

workshop.generative_models.models.gan.Generator ¤

Generator(config: GeneratorConfig, *, rngs: Rngs)

Bases: Module

Generator network for GAN.

Base generator using fully-connected (Dense) layers. For convolutional generators, use DCGANGenerator or other specialized subclasses.

Parameters:

Name Type Description Default
config GeneratorConfig

GeneratorConfig with network architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not GeneratorConfig

config instance-attribute ¤

config = config

hidden_dims instance-attribute ¤

hidden_dims = hidden_dims

output_shape instance-attribute ¤

output_shape = output_shape

latent_dim instance-attribute ¤

latent_dim = latent_dim

batch_norm instance-attribute ¤

batch_norm = batch_norm

dropout_rate instance-attribute ¤

dropout_rate = dropout_rate

activation_fn instance-attribute ¤

activation_fn = _get_activation_fn(activation)

layers instance-attribute ¤

layers = List(layers_list)

output_layer instance-attribute ¤

output_layer = Linear(in_features=last_dim, out_features=output_size, rngs=rngs)

bn_layers instance-attribute ¤

bn_layers = List([(BatchNorm(dim, rngs=rngs)) for dim in hidden_dims])

dropout instance-attribute ¤

dropout = Dropout(rate=dropout_rate, rngs=rngs)

Basic generator network that transforms latent vectors into data samples.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
hidden_dims list[int] Required Hidden layer dimensions
output_shape tuple Required Shape of generated samples (batch, C, H, W)
latent_dim int Required Dimension of latent space
activation str "relu" Activation function name
batch_norm bool True Whether to use batch normalization
dropout_rate float 0.0 Dropout rate
rngs nnx.Rngs Required Random number generators

Methods:

__call__(z, training=False)¤

Generate samples from latent vectors.

Parameters:

  • z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
  • training (bool): Whether in training mode (affects batch norm and dropout)

Returns:

  • jax.Array: Generated samples of shape (batch_size, *output_shape[1:])

Example:

from workshop.generative_models.models.gan import Generator
from flax import nnx
import jax.numpy as jnp

# Create generator
generator = Generator(
    hidden_dims=[256, 512, 1024],
    output_shape=(1, 1, 28, 28),  # MNIST
    latent_dim=100,
    activation="relu",
    batch_norm=True,
    rngs=nnx.Rngs(params=0),
)

# Generate samples
z = jnp.ones((32, 100))  # Batch of latent vectors
samples = generator(z, training=False)
print(samples.shape)  # (32, 1, 28, 28)

Discriminator¤

workshop.generative_models.models.gan.Discriminator ¤

Discriminator(config: DiscriminatorConfig, *, rngs: Rngs)

Bases: Module

Discriminator network for GAN.

Base discriminator using fully-connected (Dense) layers. For convolutional discriminators, use DCGANDiscriminator or other specialized subclasses.

Parameters:

Name Type Description Default
config DiscriminatorConfig

DiscriminatorConfig with network architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not DiscriminatorConfig

config instance-attribute ¤

config = config

hidden_dims instance-attribute ¤

hidden_dims = hidden_dims

input_shape instance-attribute ¤

input_shape = input_shape

activation instance-attribute ¤

activation = activation

leaky_relu_slope instance-attribute ¤

leaky_relu_slope = leaky_relu_slope

batch_norm instance-attribute ¤

batch_norm = batch_norm

dropout_rate instance-attribute ¤

dropout_rate = dropout_rate

use_spectral_norm instance-attribute ¤

use_spectral_norm = use_spectral_norm

activation_fn instance-attribute ¤

activation_fn = _get_activation_fn(activation, leaky_relu_slope)

layers instance-attribute ¤

layers = List(layers_list)

output_layer instance-attribute ¤

output_layer = Linear(in_features=curr_dim, out_features=1, rngs=rngs)

bn_layers instance-attribute ¤

bn_layers = List([(BatchNorm(dim, rngs=rngs)) for dim in hidden_dims])

dropout instance-attribute ¤

dropout = Dropout(rate=dropout_rate, rngs=rngs)

Basic discriminator network that classifies samples as real or fake.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
hidden_dims list[int] Required Hidden layer dimensions
activation str "leaky_relu" Activation function name
leaky_relu_slope float 0.2 Negative slope for LeakyReLU
batch_norm bool False Whether to use batch normalization
dropout_rate float 0.3 Dropout rate
use_spectral_norm bool False Whether to use spectral normalization
rngs nnx.Rngs Required Random number generators

Methods:

initialize_layers(input_shape, rngs=None)¤

Initialize layers based on input shape.

Parameters:

  • input_shape (tuple): Shape of input data (batch, C, H, W)
  • rngs (nnx.Rngs, optional): Random number generators

__call__(x, training=False)¤

Classify samples as real or fake.

Parameters:

  • x (jax.Array): Input samples of shape (batch_size, C, H, W)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Discrimination scores of shape (batch_size, 1), values in [0, 1]

Example:

from workshop.generative_models.models.gan import Discriminator
from flax import nnx
import jax.numpy as jnp

# Create discriminator
discriminator = Discriminator(
    hidden_dims=[512, 256, 128],
    activation="leaky_relu",
    leaky_relu_slope=0.2,
    batch_norm=False,
    dropout_rate=0.3,
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Classify samples
samples = jnp.ones((32, 1, 28, 28))
scores = discriminator(samples, training=True)
print(scores.shape)  # (32, 1)
print(f"Scores range: [{scores.min():.3f}, {scores.max():.3f}]")

GAN¤

workshop.generative_models.models.gan.GAN ¤

GAN(config: GANConfig, *, rngs: Rngs, precision: Precision | None = None)

Bases: GenerativeModel

Generative Adversarial Network (GAN) implementation.

Base GAN using fully-connected Generator and Discriminator. For convolutional GANs, use DCGAN or other specialized subclasses.

Parameters:

Name Type Description Default
config GANConfig

GANConfig with nested GeneratorConfig and DiscriminatorConfig

required
rngs Rngs

Random number generators

required
precision Precision | None

Numerical precision for computations

None

Raises:

Type Description
ValueError

If rngs is None or missing 'sample' stream

TypeError

If config is not GANConfig

config instance-attribute ¤

config = config

rngs instance-attribute ¤

rngs = rngs

latent_dim instance-attribute ¤

latent_dim = latent_dim

loss_type instance-attribute ¤

loss_type = loss_type

gradient_penalty_weight instance-attribute ¤

gradient_penalty_weight = gradient_penalty_weight

generator instance-attribute ¤

generator = Generator(config=generator, rngs=rngs)

discriminator instance-attribute ¤

discriminator = Discriminator(config=discriminator, rngs=rngs)

generate ¤

generate(n_samples: int = 1, *, batch_size: int | None = None, **kwargs: Any) -> Array

Generate samples from the generator.

Note: Uses stored self.rngs for sampling. RNG automatically advances each call.

Parameters:

Name Type Description Default
n_samples int

Number of samples to generate

1
batch_size int | None

Alternative way to specify number of samples (for compatibility)

None
**kwargs Any

Additional keyword arguments

{}

Returns:

Type Description
Array

Generated samples of shape (n_samples, *output_shape[1:])

loss_fn ¤

loss_fn(batch: dict[str, Any], model_outputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]

Compute GAN loss for training.

Note: Uses stored self.rngs for sampling. RNG automatically advances each call.

Parameters:

Name Type Description Default
batch dict[str, Any]

Input batch containing real data (dict with 'x' or 'data' key, or raw array)

required
model_outputs dict[str, Any]

Model outputs (unused for GAN loss computation)

required
**kwargs Any

Additional keyword arguments

{}

Returns:

Type Description
dict[str, Any]

Dictionary containing:

  • loss: Total combined loss (generator + discriminator)
  • generator_loss: Generator loss component
  • discriminator_loss: Discriminator loss component
  • real_scores_mean: Mean discriminator score on real data
  • fake_scores_mean: Mean discriminator score on fake data

Complete GAN model combining generator and discriminator.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
config object Required Model configuration object
rngs nnx.Rngs Required Random number generators
precision jax.lax.Precision None Numerical precision

Configuration Object:

The config must have the following structure:

class GANConfig:
    latent_dim: int = 100                  # Latent space dimension
    loss_type: str = "vanilla"             # Loss type: "vanilla", "wasserstein", "least_squares", "hinge"
    gradient_penalty_weight: float = 0.0   # Weight for gradient penalty (WGAN-GP)

    class generator:
        hidden_dims: list[int]             # Generator hidden dimensions
        output_shape: tuple                # Output shape (batch, C, H, W)
        activation: str = "relu"
        batch_norm: bool = True
        dropout_rate: float = 0.0

    class discriminator:
        hidden_dims: list[int]             # Discriminator hidden dimensions
        activation: str = "leaky_relu"
        leaky_relu_slope: float = 0.2
        batch_norm: bool = False
        dropout_rate: float = 0.3
        use_spectral_norm: bool = False

Methods:

__call__(x, rngs=None, training=False, **kwargs)¤

Forward pass through the GAN (runs discriminator only).

Parameters:

  • x (jax.Array): Input data
  • rngs (nnx.Rngs, optional): Random number generators
  • training (bool): Whether in training mode

Returns:

  • dict: Dictionary with keys:
  • "real_scores": Discriminator scores for real data
  • "fake_scores": None (computed in loss_fn)
  • "fake_samples": None (computed in loss_fn)

generate(n_samples=1, rngs=None, batch_size=None, **kwargs)¤

Generate samples from the generator.

Parameters:

  • n_samples (int): Number of samples to generate
  • rngs (nnx.Rngs, optional): Random number generators
  • batch_size (int, optional): Alternative to n_samples

Returns:

  • jax.Array: Generated samples

loss_fn(batch, model_outputs, rngs=None, **kwargs)¤

Compute GAN loss for training.

Parameters:

  • batch (dict or jax.Array): Input batch (real data)
  • model_outputs (dict): Model outputs (unused for GAN)
  • rngs (nnx.Rngs, optional): Random number generators

Returns:

  • dict: Dictionary with losses:
  • "loss": Total loss (generator + discriminator)
  • "generator_loss": Generator loss
  • "discriminator_loss": Discriminator loss
  • "real_scores_mean": Mean discriminator score for real samples
  • "fake_scores_mean": Mean discriminator score for fake samples

Example:

from workshop.generative_models.models.gan import GAN
from flax import nnx

# Create configuration
class GANConfig:
    latent_dim = 100
    loss_type = "vanilla"

    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 1, 28, 28)
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

# Create GAN
gan = GAN(GANConfig(), rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate samples
samples = gan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))

# Compute loss
import jax.numpy as jnp
batch = jnp.ones((32, 1, 28, 28))
losses = gan.loss_fn(batch, None, rngs=nnx.Rngs(sample=0))
print(f"Generator Loss: {losses['generator_loss']:.4f}")
print(f"Discriminator Loss: {losses['discriminator_loss']:.4f}")

DCGAN¤

DCGANGenerator¤

workshop.generative_models.models.gan.DCGANGenerator ¤

DCGANGenerator(config: ConvGeneratorConfig, *, rngs: Rngs)

Bases: Generator

Deep Convolutional GAN Generator.

Uses transposed convolutions for progressive upsampling from latent vector to output image. All configuration comes from ConvGeneratorConfig.

Parameters:

Name Type Description Default
config ConvGeneratorConfig

ConvGeneratorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConvGeneratorConfig

init_h instance-attribute ¤

init_h = max(1, height // stride_factor ** num_upsample_layers)

init_w instance-attribute ¤

init_w = max(1, width // stride_factor ** num_upsample_layers)

initial_linear instance-attribute ¤

initial_linear = Linear(in_features=latent_dim, out_features=initial_features, rngs=rngs)

conv_transpose_layers instance-attribute ¤

conv_transpose_layers = List([])

dcgan_batch_norms instance-attribute ¤

dcgan_batch_norms = List([])

output_conv instance-attribute ¤

output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)

Deep Convolutional GAN generator using transposed convolutions.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
output_shape tuple[int, ...] Required Output image shape (C, H, W)
latent_dim int 100 Latent space dimension
hidden_dims tuple[int, ...] (256, 128, 64, 32) Channel dimensions per layer
activation callable jax.nn.relu Activation function
batch_norm bool True Use batch normalization
dropout_rate float 0.0 Dropout rate
rngs nnx.Rngs Required Random number generators

Methods:

__call__(z, training=True)¤

Generate images from latent vectors.

Parameters:

  • z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Generated images of shape (batch_size, C, H, W)

Example:

from workshop.generative_models.models.gan import DCGANGenerator
from flax import nnx
import jax
import jax.numpy as jnp

generator = DCGANGenerator(
    output_shape=(3, 64, 64),            # RGB 64×64 images
    latent_dim=100,
    hidden_dims=(256, 128, 64, 32),
    activation=jax.nn.relu,
    batch_norm=True,
    rngs=nnx.Rngs(params=0),
)

# Generate samples
z = jax.random.normal(jax.random.key(0), (16, 100))
images = generator(z, training=False)
print(images.shape)  # (16, 3, 64, 64)

DCGANDiscriminator¤

workshop.generative_models.models.gan.DCGANDiscriminator ¤

DCGANDiscriminator(config: ConvDiscriminatorConfig, *, rngs: Rngs)

Bases: Discriminator

Deep Convolutional GAN Discriminator.

Uses strided convolutions for progressive downsampling from input image to binary classification. All configuration comes from ConvDiscriminatorConfig.

Parameters:

Name Type Description Default
config ConvDiscriminatorConfig

ConvDiscriminatorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConvDiscriminatorConfig

conv_layers instance-attribute ¤

conv_layers = List([])

dcgan_batch_norms instance-attribute ¤

dcgan_batch_norms = List([])

final_linear instance-attribute ¤

final_linear = Linear(in_features=final_features, out_features=output_dim, rngs=rngs)

Deep Convolutional GAN discriminator using strided convolutions.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
input_shape tuple[int, ...] Required Input image shape (C, H, W)
hidden_dims tuple[int, ...] (32, 64, 128, 256) Channel dimensions per layer
activation callable jax.nn.leaky_relu Activation function
leaky_relu_slope float 0.2 Negative slope for LeakyReLU
batch_norm bool False Use batch normalization
dropout_rate float 0.3 Dropout rate
use_spectral_norm bool True Use spectral normalization
rngs nnx.Rngs Required Random number generators

Methods:

__call__(x, training=True)¤

Classify images as real or fake.

Parameters:

  • x (jax.Array): Input images of shape (batch_size, C, H, W)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Discrimination scores of shape (batch_size, 1)

Example:

from workshop.generative_models.models.gan import DCGANDiscriminator
from flax import nnx
import jax
import jax.numpy as jnp

discriminator = DCGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(32, 64, 128, 256),
    activation=jax.nn.leaky_relu,
    leaky_relu_slope=0.2,
    batch_norm=False,
    dropout_rate=0.3,
    use_spectral_norm=True,
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Classify images
images = jnp.ones((16, 3, 64, 64))
scores = discriminator(images, training=True)
print(scores.shape)  # (16, 1)

DCGAN¤

workshop.generative_models.models.gan.DCGAN ¤

DCGAN(config: DCGANConfig, *, rngs: Rngs)

Bases: GAN

Deep Convolutional GAN (DCGAN) model.

Uses DCGANConfig which contains ConvGeneratorConfig and ConvDiscriminatorConfig for complete architecture specification.

Parameters:

Name Type Description Default
config DCGANConfig

DCGANConfig with nested ConvGeneratorConfig and ConvDiscriminatorConfig

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not DCGANConfig

config instance-attribute ¤

config = config

rngs instance-attribute ¤

rngs = rngs

generator instance-attribute ¤

generator = DCGANGenerator(config=generator, rngs=rngs)

discriminator instance-attribute ¤

discriminator = DCGANDiscriminator(config=discriminator, rngs=rngs)

latent_dim instance-attribute ¤

latent_dim = latent_dim

loss_type instance-attribute ¤

loss_type = loss_type

gradient_penalty_weight instance-attribute ¤

gradient_penalty_weight = gradient_penalty_weight

generator_lr instance-attribute ¤

generator_lr = generator_lr

discriminator_lr instance-attribute ¤

discriminator_lr = discriminator_lr

beta1 instance-attribute ¤

beta1 = beta1

beta2 instance-attribute ¤

beta2 = beta2

Complete Deep Convolutional GAN model.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
config DCGANConfiguration Required DCGAN configuration
rngs nnx.Rngs Required Random number generators

Configuration:

Use DCGANConfiguration from workshop.generative_models.core.configuration.gan:

from workshop.generative_models.core.configuration.gan import DCGANConfiguration

config = DCGANConfiguration(
    image_size=64,                        # Image size (H=W)
    channels=3,                           # Number of channels
    latent_dim=100,                       # Latent dimension
    gen_hidden_dims=(256, 128, 64, 32),  # Generator channels
    disc_hidden_dims=(32, 64, 128, 256), # Discriminator channels
    loss_type="vanilla",                  # Loss type
    generator_lr=0.0002,                  # Generator learning rate
    discriminator_lr=0.0002,              # Discriminator learning rate
    beta1=0.5,                            # Adam β1
    beta2=0.999,                          # Adam β2
)

Example:

from workshop.generative_models.models.gan import DCGAN
from workshop.generative_models.core.configuration.gan import DCGANConfiguration
from flax import nnx

config = DCGANConfiguration(
    image_size=64,
    channels=3,
    latent_dim=100,
    gen_hidden_dims=(256, 128, 64, 32),
    disc_hidden_dims=(32, 64, 128, 256),
    loss_type="vanilla",
)

dcgan = DCGAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate samples
samples = dcgan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))
print(samples.shape)  # (16, 3, 64, 64)

WGAN¤

WGANGenerator¤

workshop.generative_models.models.gan.WGANGenerator ¤

WGANGenerator(config: ConvGeneratorConfig, *, rngs: Rngs)

Bases: Generator

Wasserstein GAN Generator using convolutional architecture.

Based on the PyTorch WGAN-GP reference implementation:

  • Uses ConvTranspose layers like DCGAN
  • BatchNorm is typically used in WGAN generators
  • Tanh activation at the output

Parameters:

Name Type Description Default
config ConvGeneratorConfig

ConvGeneratorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConvGeneratorConfig

wgan_activation_fn instance-attribute ¤

wgan_activation_fn = _get_activation_fn(activation)

initial_linear instance-attribute ¤

initial_linear = Linear(in_features=latent_dim, out_features=init_h * init_w * hidden_dims_list[0], rngs=rngs)

initial_bn instance-attribute ¤

initial_bn = BatchNorm(num_features=hidden_dims_list[0], use_running_average=batch_norm_use_running_avg, momentum=batch_norm_momentum, rngs=rngs)

conv_transpose_layers instance-attribute ¤

conv_transpose_layers = List([])

wgan_batch_norm_layers instance-attribute ¤

wgan_batch_norm_layers = List([])

output_conv instance-attribute ¤

output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)

Wasserstein GAN generator with convolutional architecture.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
output_shape tuple[int, ...] Required Output image shape (C, H, W)
latent_dim int 100 Latent space dimension
hidden_dims tuple[int, ...] (1024, 512, 256) Channel dimensions
activation callable jax.nn.relu Activation function
batch_norm bool True Use batch normalization
dropout_rate float 0.0 Dropout rate
rngs nnx.Rngs Required Random number generators

Methods:

__call__(z, training=True)¤

Generate images from latent vectors.

Parameters:

  • z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Generated images of shape (batch_size, C, H, W)

Example:

from workshop.generative_models.models.gan import WGANGenerator
from flax import nnx
import jax

generator = WGANGenerator(
    output_shape=(3, 64, 64),
    latent_dim=100,
    hidden_dims=(1024, 512, 256),
    activation=jax.nn.relu,
    batch_norm=True,
    rngs=nnx.Rngs(params=0),
)

z = jax.random.normal(jax.random.key(0), (16, 100))
images = generator(z, training=False)
print(images.shape)  # (16, 3, 64, 64)

WGANDiscriminator¤

workshop.generative_models.models.gan.WGANDiscriminator ¤

WGANDiscriminator(config: ConvDiscriminatorConfig, *, rngs: Rngs)

Bases: Discriminator

Wasserstein GAN Discriminator (Critic) using convolutional architecture.

Key differences from the standard discriminator:

  • Uses InstanceNorm instead of BatchNorm (as per the WGAN-GP paper)
  • No sigmoid activation at the end (outputs raw scores)
  • LeakyReLU activation

Parameters:

Name Type Description Default
config ConvDiscriminatorConfig

ConvDiscriminatorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConvDiscriminatorConfig

conv_layers instance-attribute ¤

conv_layers = List([])

norm_layers instance-attribute ¤

norm_layers = List([])

output_conv instance-attribute ¤

output_conv = Conv(in_features=hidden_dims_list[-1], out_features=1, kernel_size=(4, 4), strides=(1, 1), padding='VALID', rngs=rngs)

wgan_use_instance_norm instance-attribute ¤

wgan_use_instance_norm = use_instance_norm

Wasserstein GAN discriminator (critic) with instance normalization.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
input_shape tuple[int, ...] Required Input image shape (C, H, W)
hidden_dims tuple[int, ...] (256, 512, 1024) Channel dimensions
activation callable jax.nn.leaky_relu Activation function
leaky_relu_slope float 0.2 Negative slope for LeakyReLU
use_instance_norm bool True Use instance normalization
dropout_rate float 0.0 Dropout rate
rngs nnx.Rngs Required Random number generators

Methods:

__call__(x, training=True)¤

Compute critic scores (no sigmoid activation).

Parameters:

  • x (jax.Array): Input images of shape (batch_size, C, H, W)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Raw critic scores (no sigmoid) of shape (batch_size,)

Example:

from workshop.generative_models.models.gan import WGANDiscriminator
from flax import nnx
import jax
import jax.numpy as jnp

discriminator = WGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(256, 512, 1024),
    activation=jax.nn.leaky_relu,
    use_instance_norm=True,
    rngs=nnx.Rngs(params=0),
)

images = jnp.ones((16, 3, 64, 64))
scores = discriminator(images, training=True)
print(scores.shape)  # (16,)
# Note: No sigmoid, scores can be any real number

WGAN¤

workshop.generative_models.models.gan.WGAN ¤

WGAN(config: WGANConfig, *, rngs: Rngs)

Bases: Module

Wasserstein GAN with Gradient Penalty (WGAN-GP) model.

Based on the PyTorch reference implementation with proper convolutional architecture.

Parameters:

Name Type Description Default
config WGANConfig

WGANConfig with nested ConvGeneratorConfig and ConvDiscriminatorConfig

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not WGANConfig

generator instance-attribute ¤

generator = WGANGenerator(config=gen_config, rngs=rngs)

discriminator instance-attribute ¤

discriminator = WGANDiscriminator(config=disc_config, rngs=rngs)

lambda_gp instance-attribute ¤

lambda_gp = gradient_penalty_weight

n_critic instance-attribute ¤

n_critic = critic_iterations

latent_dim instance-attribute ¤

latent_dim = latent_dim

config instance-attribute ¤

config = config

generate ¤

generate(n_samples: int = 1, *, rngs: Rngs | None = None, batch_size: int | None = None, **kwargs) -> Array

Generate samples from the generator.

Parameters:

Name Type Description Default
n_samples int

Number of samples to generate

1
rngs Rngs | None

Random number generator

None
batch_size int | None

Alternative way to specify number of samples (for compatibility)

None
**kwargs

Additional keyword arguments

{}

Returns:

Type Description
Array

Generated samples

discriminator_loss ¤

discriminator_loss(real_samples: Array, fake_samples: Array, rngs: Rngs) -> Array

Compute WGAN-GP discriminator loss.

Parameters:

Name Type Description Default
real_samples Array

Real samples from dataset.

required
fake_samples Array

Generated fake samples.

required
rngs Rngs

Random number generators.

required

Returns:

Type Description
Array

Discriminator loss value.

generator_loss ¤

generator_loss(fake_samples: Array) -> Array

Compute WGAN generator loss.

Parameters:

Name Type Description Default
fake_samples Array

Generated fake samples.

required

Returns:

Type Description
Array

Generator loss value.

Complete Wasserstein GAN with Gradient Penalty (WGAN-GP).

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
config ModelConfiguration Required Model configuration
rngs nnx.Rngs Required Random number generators
precision jax.lax.Precision None Numerical precision

Configuration:

from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    input_dim=100,                        # Latent dimension
    output_dim=(3, 64, 64),               # Output image shape
    hidden_dims=None,                     # Use defaults
    metadata={
        "gan_params": {
            "gen_hidden_dims": (1024, 512, 256),
            "disc_hidden_dims": (256, 512, 1024),
            "gradient_penalty_weight": 10.0,     # Lambda for GP
            "critic_iterations": 5,               # Critic updates per generator
        }
    }
)

Methods:

generate(n_samples=1, rngs=None, batch_size=None, **kwargs)¤

Generate samples from generator.

Parameters:

  • n_samples (int): Number of samples
  • rngs (nnx.Rngs, optional): Random number generators
  • batch_size (int, optional): Alternative to n_samples

Returns:

  • jax.Array: Generated samples

discriminator_loss(real_samples, fake_samples, rngs)¤

Compute WGAN-GP discriminator loss with gradient penalty.

Parameters:

  • real_samples (jax.Array): Real images
  • fake_samples (jax.Array): Generated images
  • rngs (nnx.Rngs): Random number generators

Returns:

  • jax.Array: Discriminator loss (scalar)

Loss Formula: $$ \mathcal{L}_D = \mathbb{E}[D(G(z))] - \mathbb{E}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\left[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\right] $$
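
Written out with critic outputs (a sketch; real_scores and fake_scores are assumed to be critic scores on real and generated batches, and gp the gradient penalty term defined below):

import jax.numpy as jnp

# Critic pushes real scores up and fake scores down; gp enforces the Lipschitz constraint
d_loss = jnp.mean(fake_scores) - jnp.mean(real_scores) + gp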

generator_loss(fake_samples)¤

Compute WGAN generator loss.

Parameters:

  • fake_samples (jax.Array): Generated images

Returns:

  • jax.Array: Generator loss (scalar)

Loss Formula: $$ \mathcal{L}_G = -\mathbb{E}[D(G(z))] $$

Example:

from workshop.generative_models.models.gan import WGAN
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    input_dim=100,
    output_dim=(3, 64, 64),
    metadata={
        "gan_params": {
            "gen_hidden_dims": (1024, 512, 256),
            "disc_hidden_dims": (256, 512, 1024),
            "gradient_penalty_weight": 10.0,
            "critic_iterations": 5,
        }
    }
)

wgan = WGAN(config, rngs=nnx.Rngs(params=0, sample=1))

# Generate samples
samples = wgan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))
print(samples.shape)  # (16, 3, 64, 64)

# Training step
import jax
real_samples = jax.random.normal(jax.random.key(0), (32, 3, 64, 64))
z = jax.random.normal(jax.random.key(1), (32, 100))
fake_samples = wgan.generator(z, training=True)

disc_loss = wgan.discriminator_loss(real_samples, fake_samples, rngs=nnx.Rngs(params=2))
gen_loss = wgan.generator_loss(fake_samples)

compute_gradient_penalty¤

workshop.generative_models.models.gan.compute_gradient_penalty ¤

compute_gradient_penalty(discriminator: WGANDiscriminator, real_samples: Array, fake_samples: Array, rngs: Rngs, lambda_gp: float = 10.0) -> Array

Compute gradient penalty for WGAN-GP.

The gradient penalty enforces the Lipschitz constraint by penalizing the discriminator when the gradient norm deviates from 1.

Parameters:

Name Type Description Default
discriminator WGANDiscriminator

The discriminator network.

required
real_samples Array

Real samples from the dataset.

required
fake_samples Array

Generated fake samples.

required
rngs Rngs

Random number generators for interpolation.

required
lambda_gp float

Gradient penalty weight.

10.0

Returns:

Type Description
Array

Gradient penalty loss value.

Compute gradient penalty for WGAN-GP.

Module: workshop.generative_models.models.gan

Function Signature:

def compute_gradient_penalty(
    discriminator: WGANDiscriminator,
    real_samples: jax.Array,
    fake_samples: jax.Array,
    rngs: nnx.Rngs,
    lambda_gp: float = 10.0,
) -> jax.Array

Parameters:

Parameter Type Default Description
discriminator WGANDiscriminator Required Discriminator network
real_samples jax.Array Required Real images
fake_samples jax.Array Required Generated images
rngs nnx.Rngs Required Random number generators
lambda_gp float 10.0 Gradient penalty weight

Returns:

  • jax.Array: Gradient penalty loss (scalar)

Formula: $$ \text{GP} = \lambda \, \mathbb{E}_{\hat{x}}\left[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\right] $$

where \(\hat{x} = \epsilon x + (1-\epsilon)G(z)\) is a random interpolation.
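
For intuition, the computation behind this formula can be sketched directly in JAX. This is an illustrative sketch (it assumes a critic_fn mapping (batch, C, H, W) images to per-sample scores), not the library's internal implementation:

import jax
import jax.numpy as jnp

def gradient_penalty_sketch(critic_fn, real, fake, key, lambda_gp=10.0):
    # Per-sample interpolation coefficients epsilon in [0, 1]
    eps = jax.random.uniform(key, (real.shape[0], 1, 1, 1))
    x_hat = eps * real + (1.0 - eps) * fake

    # Gradient of the scalar critic score w.r.t. each interpolated sample
    per_sample_grad = jax.vmap(jax.grad(lambda x: critic_fn(x[None])[0]))
    grads = per_sample_grad(x_hat)

    # Penalize deviation of the per-sample gradient norm from 1 (Lipschitz constraint)
    norms = jnp.sqrt(jnp.sum(grads**2, axis=(1, 2, 3)) + 1e-12)
    return lambda_gp * jnp.mean((norms - 1.0) ** 2)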

Example:

from workshop.generative_models.models.gan import compute_gradient_penalty, WGANDiscriminator
from flax import nnx
import jax

discriminator = WGANDiscriminator(
    input_shape=(3, 64, 64),
    rngs=nnx.Rngs(params=0),
)

real_samples = jax.random.normal(jax.random.key(0), (32, 3, 64, 64))
fake_samples = jax.random.normal(jax.random.key(1), (32, 3, 64, 64))

gp = compute_gradient_penalty(
    discriminator,
    real_samples,
    fake_samples,
    rngs=nnx.Rngs(params=2),
    lambda_gp=10.0,
)
print(f"Gradient Penalty: {gp:.4f}")

LSGAN¤

LSGANGenerator¤

workshop.generative_models.models.gan.LSGANGenerator ¤

LSGANGenerator(config: ConvGeneratorConfig, *, rngs: Rngs)

Bases: Generator

Least Squares GAN Generator using convolutional architecture.

LSGAN uses the same architecture as DCGAN but with least squares loss instead of the standard adversarial loss.

Parameters:

Name Type Description Default
config ConvGeneratorConfig

ConvGeneratorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConvGeneratorConfig

lsgan_activation_fn instance-attribute ¤

lsgan_activation_fn = _get_activation(activation)

initial_linear instance-attribute ¤

initial_linear = Linear(in_features=latent_dim, out_features=init_h * init_w * hidden_dims_list[0], rngs=rngs)

initial_bn instance-attribute ¤

initial_bn = BatchNorm(num_features=hidden_dims_list[0], use_running_average=batch_norm_use_running_avg, momentum=batch_norm_momentum, rngs=rngs)

conv_transpose_layers instance-attribute ¤

conv_transpose_layers = List([])

lsgan_batch_norm_layers instance-attribute ¤

lsgan_batch_norm_layers = List([])

output_conv instance-attribute ¤

output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)

Least Squares GAN generator (same architecture as DCGAN).

Module: workshop.generative_models.models.gan

Parameters: Same as DCGANGenerator

Methods: Same as DCGANGenerator

Example:

from workshop.generative_models.models.gan import LSGANGenerator
from flax import nnx
import jax

generator = LSGANGenerator(
    output_shape=(3, 64, 64),
    latent_dim=100,
    hidden_dims=(512, 256, 128, 64),
    rngs=nnx.Rngs(params=0),
)

z = jax.random.normal(jax.random.key(0), (16, 100))
images = generator(z, training=False)

LSGANDiscriminator¤

workshop.generative_models.models.gan.LSGANDiscriminator ¤

LSGANDiscriminator(config: ConvDiscriminatorConfig, *, rngs: Rngs)

Bases: Discriminator

Least Squares GAN Discriminator using convolutional architecture.

LSGAN discriminator uses the same architecture as DCGAN discriminator but with least squares loss instead of sigmoid cross-entropy loss.

Parameters:

Name Type Description Default
config ConvDiscriminatorConfig

ConvDiscriminatorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConvDiscriminatorConfig

conv_layers instance-attribute ¤

conv_layers = List([])

lsgan_batch_norm_layers instance-attribute ¤

lsgan_batch_norm_layers = List([])

output_layer instance-attribute ¤

output_layer: Linear = Linear(in_features=final_features, out_features=1, rngs=rngs)

lsgan_dropout instance-attribute ¤

lsgan_dropout = Dropout(rate=dropout_rate, rngs=rngs)

Least Squares GAN discriminator (no sigmoid activation).

Module: workshop.generative_models.models.gan

Parameters: Same as DCGANDiscriminator

Key Difference: Output layer has no sigmoid activation (outputs raw logits for least squares loss).

Example:

from workshop.generative_models.models.gan import LSGANDiscriminator
from flax import nnx
import jax.numpy as jnp

discriminator = LSGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(64, 128, 256, 512),
    rngs=nnx.Rngs(params=0, dropout=1),
)

images = jnp.ones((16, 3, 64, 64))
scores = discriminator(images, training=True)
# Note: Scores are raw logits, not in [0, 1]

LSGAN¤

workshop.generative_models.models.gan.LSGAN ¤

LSGAN(config: LSGANConfig, *, rngs: Rngs)

Bases: Module

Least Squares GAN implementation.

LSGAN replaces the log loss in the original GAN formulation with a least squares loss, which provides more stable training and better quality gradients for the generator.

Reference

Mao et al. "Least Squares Generative Adversarial Networks" (2017)

Parameters:

Name Type Description Default
config LSGANConfig

LSGANConfig with nested ConvGeneratorConfig and ConvDiscriminatorConfig

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not LSGANConfig

loss_type instance-attribute ¤

loss_type = 'least_squares'

generator instance-attribute ¤

generator = LSGANGenerator(config=gen_config, rngs=rngs)

discriminator instance-attribute ¤

discriminator = LSGANDiscriminator(config=disc_config, rngs=rngs)

a instance-attribute ¤

a = a

b instance-attribute ¤

b = b

c instance-attribute ¤

c = c

latent_dim instance-attribute ¤

latent_dim = latent_dim

config instance-attribute ¤

config = config

generator_loss ¤

generator_loss(fake_scores: Array, target_real: float = 1.0, reduction: str = 'mean') -> Array

Compute LSGAN generator loss.

Parameters:

Name Type Description Default
fake_scores Array

Discriminator scores for fake samples

required
target_real float

Target value for fake samples (usually 1.0)

1.0
reduction str

Reduction method ('mean', 'sum', 'none')

'mean'

Returns:

Type Description
Array

Generator loss

discriminator_loss ¤

discriminator_loss(real_scores: Array, fake_scores: Array, target_real: float = 1.0, target_fake: float = 0.0, reduction: str = 'mean') -> Array

Compute LSGAN discriminator loss.

Parameters:

Name Type Description Default
real_scores Array

Discriminator scores for real samples

required
fake_scores Array

Discriminator scores for fake samples

required
target_real float

Target value for real samples (usually 1.0)

1.0
target_fake float

Target value for fake samples (usually 0.0)

0.0
reduction str

Reduction method ('mean', 'sum', 'none')

'mean'

Returns:

Type Description
Array

Discriminator loss

training_step ¤

training_step(real_images: Array, latent_vectors: Array) -> dict[str, Array]

Perform a single training step.

Note: Use model.train() for training mode and model.eval() for evaluation mode.

Parameters:

Name Type Description Default
real_images Array

Batch of real images

required
latent_vectors Array

Batch of latent vectors

required

Returns:

Type Description
dict[str, Array]

Dictionary with loss values and generated images

Complete Least Squares GAN model.

Module: workshop.generative_models.models.gan

Parameters: Same as base GAN

Methods:

generator_loss(fake_scores, target_real=1.0, reduction="mean")¤

Compute LSGAN generator loss.

Formula: $$ \mathcal{L}_G = \frac{1}{2}\mathbb{E}[(D(G(z)) - c)^2] $$

where \(c\) is the value the generator wants the discriminator to assign to fake samples (usually 1.0).

discriminator_loss(real_scores, fake_scores, target_real=1.0, target_fake=0.0, reduction="mean")¤

Compute LSGAN discriminator loss.

Formula: $$ \mathcal{L}_D = \frac{1}{2}\mathbb{E}[(D(x) - b)^2] + \frac{1}{2}\mathbb{E}[(D(G(z)) - a)^2] $$

where \(b\) is the target for real samples (usually 1.0) and \(a\) is the target for fake samples (usually 0.0).
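
Written out with discriminator scores (a sketch using the default targets a = 0, b = 1, c = 1; real_scores and fake_scores are assumed discriminator outputs):

import jax.numpy as jnp

g_loss = 0.5 * jnp.mean((fake_scores - 1.0) ** 2)                                   # target c = 1
d_loss = 0.5 * jnp.mean((real_scores - 1.0) ** 2) + 0.5 * jnp.mean(fake_scores**2)  # b = 1, a = 0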

Example:

from workshop.generative_models.models.gan import LSGAN
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx

config = ModelConfiguration(
    input_dim=100,
    output_dim=(3, 64, 64),
    hidden_dims=[512, 256, 128, 64],
)

lsgan = LSGAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate samples
samples = lsgan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))

# Compute losses
import jax
real_images = jax.random.normal(jax.random.key(0), (32, 3, 64, 64))
z = jax.random.normal(jax.random.key(1), (32, 100))
fake_images = lsgan.generator(z, training=True)

real_scores = lsgan.discriminator(real_images, training=True)
fake_scores = lsgan.discriminator(fake_images, training=True)

gen_loss = lsgan.generator_loss(fake_scores)
disc_loss = lsgan.discriminator_loss(real_scores, fake_scores)
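
The combined update helper can be exercised in the same way (a sketch; per the training_step description above, it returns a dictionary of loss values and generated images):

metrics = lsgan.training_step(real_images, z)
print(sorted(metrics.keys()))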

Conditional GAN¤

ConditionalGenerator¤

workshop.generative_models.models.gan.ConditionalGenerator ¤

ConditionalGenerator(config: ConditionalGeneratorConfig, *, rngs: Rngs)

Bases: Generator

Conditional GAN Generator using convolutional architecture.

The generator is conditioned on class labels by concatenating the label embedding with the noise vector before passing through the network.

Parameters:

Name Type Description Default
config ConditionalGeneratorConfig

ConditionalGeneratorConfig with network architecture and conditional parameters (num_classes, embedding_dim via config.conditional)

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConditionalGeneratorConfig

num_classes instance-attribute ¤

num_classes = num_classes

embedding_dim instance-attribute ¤

embedding_dim = embedding_dim

activation_fn instance-attribute ¤

activation_fn = _get_activation_fn(activation)

label_embedding instance-attribute ¤

label_embedding = Linear(in_features=num_classes, out_features=embedding_dim, rngs=rngs)

init_h instance-attribute ¤

init_h = target_size // 2 ** total_upsample_layers

init_w instance-attribute ¤

init_w = target_size // 2 ** total_upsample_layers

initial_projection instance-attribute ¤

initial_projection = Linear(in_features=combined_input_dim, out_features=init_h * init_w * hidden_dims_list[0], rngs=rngs)

initial_bn instance-attribute ¤

initial_bn = BatchNorm(num_features=hidden_dims_list[0], use_running_average=batch_norm_use_running_avg, momentum=batch_norm_momentum, rngs=rngs)

conv_transpose_layers instance-attribute ¤

conv_transpose_layers = List([])

batch_norm_layers instance-attribute ¤

batch_norm_layers = List([])

output_conv instance-attribute ¤

output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)

Conditional GAN generator that takes class labels as input.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
output_shape tuple[int, ...] Required Output image shape (C, H, W)
num_classes int Required Number of classes
latent_dim int 100 Latent space dimension
hidden_dims tuple[int, ...] (512, 256, 128, 64) Channel dimensions
activation callable jax.nn.relu Activation function
batch_norm bool True Use batch normalization
dropout_rate float 0.0 Dropout rate
rngs nnx.Rngs Required Random number generators

Methods:

__call__(z, labels, training=True)¤

Generate samples conditioned on labels.

Parameters:

  • z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
  • labels (jax.Array): One-hot encoded labels of shape (batch_size, num_classes)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Generated images of shape (batch_size, C, H, W)

Example:

from workshop.generative_models.models.gan import ConditionalGenerator
from flax import nnx
import jax
import jax.numpy as jnp

generator = ConditionalGenerator(
    output_shape=(1, 28, 28),           # MNIST
    num_classes=10,
    latent_dim=100,
    hidden_dims=(512, 256, 128, 64),
    rngs=nnx.Rngs(params=0),
)

# Generate specific digits
z = jax.random.normal(jax.random.key(0), (10, 100))
labels = jax.nn.one_hot(jnp.arange(10), 10)  # One of each digit
images = generator(z, labels, training=False)
print(images.shape)  # (10, 1, 28, 28)

ConditionalDiscriminator¤

workshop.generative_models.models.gan.ConditionalDiscriminator ¤

ConditionalDiscriminator(config: ConditionalDiscriminatorConfig, *, rngs: Rngs)

Bases: Discriminator

Conditional GAN Discriminator using convolutional architecture.

The discriminator is conditioned on class labels by concatenating the label embedding with the input image before passing through the network.

Parameters:

Name Type Description Default
config ConditionalDiscriminatorConfig

ConditionalDiscriminatorConfig with network architecture and conditional parameters (num_classes, embedding_dim via config.conditional)

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConditionalDiscriminatorConfig

input_shape instance-attribute ¤

input_shape = input_shape

num_classes instance-attribute ¤

num_classes = num_classes

embedding_dim instance-attribute ¤

embedding_dim = embedding_dim

leaky_relu_slope instance-attribute ¤

leaky_relu_slope = leaky_relu_slope

expected_height instance-attribute ¤

expected_height = height

expected_width instance-attribute ¤

expected_width = width

expected_channels instance-attribute ¤

expected_channels = channels

label_embedding instance-attribute ¤

label_embedding = Linear(in_features=num_classes, out_features=height * width, rngs=rngs)

conv_layers instance-attribute ¤

conv_layers = List([])

output_conv instance-attribute ¤

output_conv = Conv(in_features=hidden_dims_list[-1], out_features=1, kernel_size=kernel_size, strides=stride_first, padding=padding, rngs=rngs)

Conditional GAN discriminator that takes class labels as input.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
input_shape tuple[int, ...] Required Input image shape (C, H, W)
num_classes int Required Number of classes
hidden_dims tuple[int, ...] (64, 128, 256, 512) Channel dimensions
activation callable jax.nn.leaky_relu Activation function
leaky_relu_slope float 0.2 Negative slope for LeakyReLU
batch_norm bool False Use batch normalization
dropout_rate float 0.0 Dropout rate
rngs nnx.Rngs Required Random number generators

Methods:

__call__(x, labels, training=True)¤

Classify samples conditioned on labels.

Parameters:

  • x (jax.Array): Input images of shape (batch_size, C, H, W)
  • labels (jax.Array): One-hot encoded labels of shape (batch_size, num_classes)
  • training (bool): Whether in training mode

Returns:

  • jax.Array: Discrimination scores of shape (batch_size,)

Example:

from workshop.generative_models.models.gan import ConditionalDiscriminator
from flax import nnx
import jax
import jax.numpy as jnp

discriminator = ConditionalDiscriminator(
    input_shape=(1, 28, 28),
    num_classes=10,
    hidden_dims=(64, 128, 256, 512),
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Classify samples with labels
images = jax.random.normal(jax.random.key(0), (32, 1, 28, 28))
labels = jax.nn.one_hot(jnp.zeros(32, dtype=int), 10)  # All zeros
scores = discriminator(images, labels, training=True)
print(scores.shape)  # (32,)

ConditionalGAN¤

workshop.generative_models.models.gan.ConditionalGAN ¤

ConditionalGAN(config: ConditionalGANConfig, *, rngs: Rngs)

Bases: Module

Conditional Generative Adversarial Network (CGAN).

Based on "Conditional Generative Adversarial Nets" by Mirza & Osindero (2014). The generator and discriminator are both conditioned on class labels.

Uses composition pattern: conditional parameters (num_classes, embedding_dim) are embedded in the nested ConditionalGeneratorConfig and ConditionalDiscriminatorConfig via ConditionalParams.

Parameters:

Name Type Description Default
config ConditionalGANConfig

ConditionalGANConfig with nested ConditionalGeneratorConfig and ConditionalDiscriminatorConfig. All parameters are specified in the config objects.

required
rngs Rngs

Random number generators.

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not ConditionalGANConfig

generator instance-attribute ¤

generator = ConditionalGenerator(config=gen_config, rngs=rngs)

discriminator instance-attribute ¤

discriminator = ConditionalDiscriminator(config=disc_config, rngs=rngs)

num_classes instance-attribute ¤

num_classes = num_classes

latent_dim instance-attribute ¤

latent_dim = latent_dim

config instance-attribute ¤

config = config

generate ¤

generate(n_samples: int = 1, labels: Array | None = None, *, rngs: Rngs | None = None, batch_size: int | None = None, **kwargs) -> Array

Generate conditional samples from the generator.

Parameters:

Name Type Description Default
n_samples int

Number of samples to generate

1
labels Array | None

One-hot encoded labels of shape (n_samples, num_classes)

None
rngs Rngs | None

Random number generator

None
batch_size int | None

Alternative way to specify number of samples (for compatibility)

None
**kwargs

Additional keyword arguments

{}

Returns:

Type Description
Array

Generated samples

discriminator_loss ¤

discriminator_loss(real_samples: Array, fake_samples: Array, real_labels: Array, fake_labels: Array) -> Array

Compute conditional discriminator loss.

Parameters:

Name Type Description Default
real_samples Array

Real samples from dataset.

required
fake_samples Array

Generated fake samples.

required
real_labels Array

Labels for real samples.

required
fake_labels Array

Labels for fake samples.

required

Returns:

Type Description
Array

Discriminator loss value.

generator_loss ¤

generator_loss(fake_samples: Array, fake_labels: Array) -> Array

Compute conditional generator loss.

Parameters:

Name Type Description Default
fake_samples Array

Generated fake samples.

required
fake_labels Array

Labels for fake samples.

required

Returns:

Type Description
Array

Generator loss value.

Complete Conditional GAN model.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
config ModelConfiguration Required Model configuration
rngs nnx.Rngs Required Random number generators
precision jax.lax.Precision None Numerical precision

Configuration:

from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    input_dim=100,                        # Latent dimension
    output_dim=(1, 28, 28),               # MNIST shape
    metadata={
        "gan_params": {
            "num_classes": 10,                      # Number of classes
            "gen_hidden_dims": (512, 256, 128, 64),
            "discriminator_features": [64, 128, 256, 512],
        }
    }
)

Methods:

generate(n_samples=1, labels=None, rngs=None, batch_size=None, **kwargs)¤

Generate conditional samples.

Parameters:

  • n_samples (int): Number of samples
  • labels (jax.Array, optional): One-hot encoded labels. If None, random labels are used.
  • rngs (nnx.Rngs, optional): Random number generators
  • batch_size (int, optional): Alternative to n_samples

Returns:

  • jax.Array: Generated samples

Example:

from workshop.generative_models.models.gan import ConditionalGAN
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx
import jax
import jax.numpy as jnp

config = ModelConfiguration(
    input_dim=100,
    output_dim=(1, 28, 28),
    metadata={
        "gan_params": {
            "num_classes": 10,
            "gen_hidden_dims": (512, 256, 128, 64),
            "discriminator_features": [64, 128, 256, 512],
        }
    }
)

cgan = ConditionalGAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate specific digits
labels = jax.nn.one_hot(jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 10)
samples = cgan.generate(n_samples=10, labels=labels, rngs=nnx.Rngs(sample=0))
print(samples.shape)  # (10, 1, 28, 28)

CycleGAN¤

CycleGANGenerator¤

workshop.generative_models.models.gan.CycleGANGenerator ¤

CycleGANGenerator(config: CycleGANGeneratorConfig, *, rngs: Rngs)

Bases: Module

CycleGAN Generator for image-to-image translation.

Uses a ResNet-based architecture with reflection padding as described in the original CycleGAN paper, closely following the PyTorch reference implementation.

Parameters:

Name Type Description Default
config CycleGANGeneratorConfig

CycleGANGeneratorConfig with network architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not CycleGANGeneratorConfig

config instance-attribute ¤

config = config

input_shape instance-attribute ¤

input_shape = input_shape

output_shape instance-attribute ¤

output_shape = output_shape

hidden_dims instance-attribute ¤

hidden_dims = list(hidden_dims)

n_residual_blocks instance-attribute ¤

n_residual_blocks = n_residual_blocks

batch_norm instance-attribute ¤

batch_norm = batch_norm

dropout_rate instance-attribute ¤

dropout_rate = dropout_rate

use_skip_connections instance-attribute ¤

use_skip_connections = use_skip_connections

activation instance-attribute ¤

activation = getattr(nnx, activation_name)

initial_conv instance-attribute ¤

initial_conv = Conv(in_features=input_channels, out_features=hidden_dims[0], kernel_size=(7, 7), strides=(1, 1), padding='SAME', rngs=rngs)

initial_norm instance-attribute ¤

initial_norm = BatchNorm(num_features=hidden_dims[0], rngs=rngs)

downsample_layers instance-attribute ¤

downsample_layers = List([])

downsample_norms instance-attribute ¤

downsample_norms = List([])

residual_blocks instance-attribute ¤

residual_blocks = List([])

upsample_layers instance-attribute ¤

upsample_layers = List([])

upsample_norms instance-attribute ¤

upsample_norms = List([])

output_conv instance-attribute ¤

output_conv = Conv(in_features=hidden_dims[0], out_features=output_channels, kernel_size=(7, 7), strides=(1, 1), padding='SAME', rngs=rngs)

dropout instance-attribute ¤

dropout = None

CycleGAN generator for image-to-image translation.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
output_shape tuple[int, ...] Required Output image shape (C, H, W)
hidden_dims tuple[int, ...] (64, 128, 256) Channel dimensions
num_residual_blocks int 9 Number of ResNet blocks
rngs nnx.Rngs Required Random number generators
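
Example (a hedged sketch in the keyword-argument style of the other examples on this page; the actual constructor takes a CycleGANGeneratorConfig, and the call signature is assumed to map a batch of images to a batch of translated images):

from workshop.generative_models.models.gan import CycleGANGenerator
from flax import nnx
import jax

generator = CycleGANGenerator(
    output_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256),
    num_residual_blocks=9,
    rngs=nnx.Rngs(params=0),
)

images = jax.random.normal(jax.random.key(0), (4, 3, 256, 256))
translated = generator(images, training=False)
print(translated.shape)  # (4, 3, 256, 256)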

CycleGANDiscriminator¤

workshop.generative_models.models.gan.CycleGANDiscriminator ¤

CycleGANDiscriminator(config: PatchGANDiscriminatorConfig, *, rngs: Rngs)

Bases: Module

CycleGAN Discriminator (PatchGAN-style).

Uses a PatchGAN discriminator that classifies patches of the input as real or fake, rather than the entire image.

Parameters:

Name Type Description Default
config PatchGANDiscriminatorConfig

PatchGANDiscriminatorConfig with network architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not PatchGANDiscriminatorConfig

config instance-attribute ¤

config = config

input_shape instance-attribute ¤

input_shape = input_shape

hidden_dims instance-attribute ¤

hidden_dims = list(hidden_dims)

batch_norm instance-attribute ¤

batch_norm = batch_norm

dropout_rate instance-attribute ¤

dropout_rate = dropout_rate

activation instance-attribute ¤

activation = lambda x: leaky_relu(x, negative_slope=leaky_relu_slope)

conv_layers instance-attribute ¤

conv_layers = List([])

norm_layers instance-attribute ¤

norm_layers = List([])

final_conv instance-attribute ¤

final_conv = Conv(in_features=hidden_dims[-1], out_features=1, kernel_size=kernel_size, strides=(1, 1), padding=padding, rngs=rngs)

dropout instance-attribute ¤

dropout = None

CycleGAN PatchGAN discriminator.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
input_shape tuple[int, ...] Required Input image shape (C, H, W)
hidden_dims tuple[int, ...] (64, 128, 256, 512) Channel dimensions
rngs nnx.Rngs Required Random number generators
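
Example (a hedged sketch in the same style; the actual constructor takes a PatchGANDiscriminatorConfig, and the output is a map of per-patch scores whose spatial size depends on the input size and architecture):

from workshop.generative_models.models.gan import CycleGANDiscriminator
from flax import nnx
import jax

discriminator = CycleGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256, 512),
    rngs=nnx.Rngs(params=0),
)

images = jax.random.normal(jax.random.key(0), (4, 3, 256, 256))
patch_scores = discriminator(images, training=False)
print(patch_scores.shape)  # per-patch scores; spatial size depends on the config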

CycleGAN¤

workshop.generative_models.models.gan.CycleGAN ¤

CycleGAN(config: CycleGANConfig, *, rngs: Rngs)

Bases: GenerativeModel

CycleGAN for unpaired image-to-image translation.

Implements the complete CycleGAN architecture with two generators and two discriminators for bidirectional domain translation.

Parameters:

Name Type Description Default
config CycleGANConfig

CycleGANConfig with nested CycleGANGeneratorConfig and PatchGANDiscriminatorConfig objects

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not CycleGANConfig

config instance-attribute ¤

config = config

input_shape_a instance-attribute ¤

input_shape_a = input_shape_a

input_shape_b instance-attribute ¤

input_shape_b = input_shape_b

lambda_cycle instance-attribute ¤

lambda_cycle = lambda_cycle

lambda_identity instance-attribute ¤

lambda_identity = lambda_identity

generator_a_to_b instance-attribute ¤

generator_a_to_b = CycleGANGenerator(config=gen_a_to_b_config, rngs=rngs)

generator_b_to_a instance-attribute ¤

generator_b_to_a = CycleGANGenerator(config=gen_b_to_a_config, rngs=rngs)

discriminator_a instance-attribute ¤

discriminator_a = CycleGANDiscriminator(config=disc_a_config, rngs=rngs)

discriminator_b instance-attribute ¤

discriminator_b = CycleGANDiscriminator(config=disc_b_config, rngs=rngs)

generate ¤

generate(n_samples: int = 1, *, rngs: Rngs | None = None, batch_size: int | None = None, domain: str = 'a_to_b', input_images: Array | None = None, **kwargs) -> Array

Generate translated images.

Parameters:

Name Type Description Default
n_samples int

Number of samples to generate (ignored if input_images provided).

1
rngs Rngs | None

Random number generators.

None
batch_size int | None

Alternative way to specify number of samples (for compatibility).

None
domain str

Translation direction ("a_to_b" or "b_to_a").

'a_to_b'
input_images Array | None

Input images to translate. If None, generates random input.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Array

Translated images.
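
A short usage sketch for generate, assuming a constructed cyclegan instance like the one in the example at the end of this section:

import jax.numpy as jnp
from flax import nnx

real_a = jnp.ones((4, 3, 256, 256))  # batch of domain-A images

# Translate A -> B, then back B -> A
fake_b = cyclegan.generate(domain="a_to_b", input_images=real_a)
cycled_a = cyclegan.generate(domain="b_to_a", input_images=fake_b)

# Without input_images, random inputs are generated internally
random_b = cyclegan.generate(n_samples=4, domain="a_to_b", rngs=nnx.Rngs(0))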

compute_cycle_loss ¤

compute_cycle_loss(real_a: Array, real_b: Array) -> tuple[Array, Array]

Compute cycle consistency losses.

Parameters:

Name Type Description Default
real_a Array

Real images from domain A.

required
real_b Array

Real images from domain B.

required

Returns:

Type Description
tuple[Array, Array]

Tuple of (cycle_loss_a, cycle_loss_b).
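
The cycle terms penalise the round-trip reconstruction error in both directions. A minimal sketch of the computation this method performs, assuming an L1 penalty (as in the original CycleGAN paper) and a cyclegan instance with batches real_a and real_b:

import jax.numpy as jnp

# Round trip A -> B -> A
fake_b = cyclegan.generator_a_to_b(real_a, training=True)
cycled_a = cyclegan.generator_b_to_a(fake_b, training=True)
cycle_loss_a = jnp.mean(jnp.abs(cycled_a - real_a))

# Round trip B -> A -> B
fake_a = cyclegan.generator_b_to_a(real_b, training=True)
cycled_b = cyclegan.generator_a_to_b(fake_a, training=True)
cycle_loss_b = jnp.mean(jnp.abs(cycled_b - real_b))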

compute_identity_loss ¤

compute_identity_loss(real_a: Array, real_b: Array) -> tuple[Array, Array]

Compute identity losses.

Identity loss encourages generators to preserve color composition when translating images that already belong to the target domain.

Parameters:

Name Type Description Default
real_a Array

Real images from domain A.

required
real_b Array

Real images from domain B.

required

Returns:

Type Description
tuple[Array, Array]

Tuple of (identity_loss_a, identity_loss_b).
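
A minimal sketch of the identity terms under the same L1 assumption: each generator is fed images that already belong to its output domain and is penalised for changing them.

import jax.numpy as jnp

# generator_b_to_a outputs domain-A images, so real A images should pass
# through approximately unchanged; likewise for generator_a_to_b on B.
identity_a = cyclegan.generator_b_to_a(real_a, training=True)
identity_loss_a = jnp.mean(jnp.abs(identity_a - real_a))

identity_b = cyclegan.generator_a_to_b(real_b, training=True)
identity_loss_b = jnp.mean(jnp.abs(identity_b - real_b))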

loss_fn ¤

loss_fn(batch: dict[str, Any], model_outputs: dict[str, Any], *, rngs: Rngs | None = None, **kwargs) -> dict[str, Any]

Compute total CycleGAN loss.

Parameters:

Name Type Description Default
batch dict[str, Any]

Batch containing real images from both domains.

required
model_outputs dict[str, Any]

Model outputs (not used in basic implementation).

required
rngs Rngs | None

Random number generators.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
dict[str, Any]

Dictionary containing loss and metrics.
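
The total objective combines the adversarial, cycle, and identity terms, weighted by lambda_cycle and lambda_identity. A hedged call sketch; the batch key names and the returned "loss" key are illustrative assumptions, not part of the documented API:

from flax import nnx

batch = {
    "real_a": horse_images,   # hypothetical key names
    "real_b": zebra_images,
}
outputs = cyclegan.loss_fn(batch, model_outputs={}, rngs=nnx.Rngs(0))
print(outputs["loss"])        # assumed total-loss key; see the returned dict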

Complete CycleGAN model for unpaired image-to-image translation.

Module: workshop.generative_models.models.gan

Key Features:

  • Two generators: G: A → B (generator_a_to_b) and F: B → A (generator_b_to_a)
  • Two discriminators: D_A (discriminator_a) and D_B (discriminator_b)
  • Cycle consistency loss: a → G(a) → F(G(a)) ≈ a and b → F(b) → G(F(b)) ≈ b
  • Identity loss (optional): F(a) ≈ a for a ∈ A and G(b) ≈ b for b ∈ B

Parameters:

Parameter Type Default Description
input_shape_a tuple Required Domain A image shape (C, H, W)
input_shape_b tuple Required Domain B image shape (C, H, W)
gen_hidden_dims tuple (64, 128, 256) Generator channels
disc_hidden_dims tuple (64, 128, 256) Discriminator channels
lambda_cycle float 10.0 Cycle consistency weight
lambda_identity float 0.5 Identity loss weight
rngs nnx.Rngs Required Random number generators

Example:

from workshop.generative_models.models.gan import CycleGAN
from flax import nnx
import jax

cyclegan = CycleGAN(
    input_shape_a=(3, 256, 256),         # Domain A: horses
    input_shape_b=(3, 256, 256),         # Domain B: zebras
    gen_hidden_dims=(64, 128, 256),
    disc_hidden_dims=(64, 128, 256),
    lambda_cycle=10.0,
    lambda_identity=0.5,
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Translate horse (domain A) to zebra (domain B)
horse_images = jax.random.normal(jax.random.key(0), (4, 3, 256, 256))
zebra_images = cyclegan.generator_a_to_b(horse_images, training=False)
print(zebra_images.shape)  # (4, 3, 256, 256)

# Translate zebra back to horse (cycle consistency)
reconstructed_horses = cyclegan.generator_b_to_a(zebra_images, training=False)
print(reconstructed_horses.shape)  # (4, 3, 256, 256)

PatchGAN¤

PatchGANDiscriminator¤

workshop.generative_models.models.gan.PatchGANDiscriminator ¤

PatchGANDiscriminator(config: PatchGANDiscriminatorConfig, *, rngs: Rngs)

Bases: Discriminator

PatchGAN Discriminator for image-to-image translation.

The PatchGAN discriminator classifies whether N×N patches in an image are real or fake, rather than classifying the entire image. This is particularly effective for image translation tasks where local texture and structure are important.

Reference

Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks" (2017).
Wang et al., "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs" (2018).

Parameters:

Name Type Description Default
config PatchGANDiscriminatorConfig

PatchGANDiscriminatorConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None

TypeError

If config is not PatchGANDiscriminatorConfig

patchgan_num_filters instance-attribute ¤

patchgan_num_filters = num_filters

patchgan_num_layers instance-attribute ¤

patchgan_num_layers = num_layers

patchgan_use_bias instance-attribute ¤

patchgan_use_bias = use_bias

patchgan_kernel_size instance-attribute ¤

patchgan_kernel_size = kernel_size

patchgan_stride instance-attribute ¤

patchgan_stride = stride

patchgan_last_kernel_size instance-attribute ¤

patchgan_last_kernel_size = last_kernel_size

patchgan_activation_fn instance-attribute ¤

patchgan_activation_fn = lambda x: leaky_relu(x, negative_slope=leaky_relu_slope)

patchgan_conv_layers instance-attribute ¤

patchgan_conv_layers = List([])

patchgan_batch_norm_layers instance-attribute ¤

patchgan_batch_norm_layers = List([])

initial_conv instance-attribute ¤

initial_conv = Conv(in_features=channels, out_features=num_filters, kernel_size=kernel_size, strides=stride, padding='SAME', use_bias=True, rngs=rngs)

final_conv instance-attribute ¤

final_conv = Conv(in_features=in_channels, out_features=1, kernel_size=last_kernel_size, strides=(1, 1), padding='SAME', use_bias=True, rngs=rngs)

patchgan_dropout instance-attribute ¤

patchgan_dropout = Dropout(rate=dropout_rate, rngs=rngs)

PatchGAN discriminator that outputs an N×N array of patch predictions.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
input_shape tuple[int, ...] Required Input image shape (C, H, W)
hidden_dims tuple[int, ...] (64, 128, 256, 512) Channel dimensions
kernel_size int 4 Convolution kernel size
stride int 2 Convolution stride
rngs nnx.Rngs Required Random number generators

Returns: N×N array of patch classifications instead of a single scalar

Example:

from workshop.generative_models.models.gan import PatchGANDiscriminator
from flax import nnx
import jax.numpy as jnp

discriminator = PatchGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256, 512),
    kernel_size=4,
    stride=2,
    rngs=nnx.Rngs(params=0, dropout=1),
)

images = jnp.ones((16, 3, 256, 256))
patch_scores = discriminator(images, training=True)
print(patch_scores.shape)  # (16, H', W', 1) - array of patch predictions

MultiScalePatchGANDiscriminator¤

workshop.generative_models.models.gan.MultiScalePatchGANDiscriminator ¤

MultiScalePatchGANDiscriminator(config: MultiScalePatchGANConfig, *, rngs: Rngs)

Bases: Module

Multi-scale PatchGAN discriminator.

Processes images at multiple scales using several PatchGAN discriminators. This allows the discriminator to capture both fine-grained and coarse-grained features at different resolutions.

Reference

Wang et al., "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs" (2018).

Parameters:

Name Type Description Default
config MultiScalePatchGANConfig

MultiScalePatchGANConfig with all architecture parameters

required
rngs Rngs

Random number generators

required

Raises:

Type Description
ValueError

If rngs is None or configuration is invalid

TypeError

If config is not MultiScalePatchGANConfig

config instance-attribute ¤

config = config

input_shape instance-attribute ¤

input_shape = input_shape

num_discriminators instance-attribute ¤

num_discriminators = num_discriminators

num_layers_per_disc instance-attribute ¤

num_layers_per_disc = list(num_layers_per_disc)

use_avg_pool instance-attribute ¤

use_avg_pool = True

discriminators instance-attribute ¤

discriminators = List([])

downsample_image ¤

downsample_image(x: Array, factor: int) -> Array

Downsample image by given factor using average pooling.

Parameters:

Name Type Description Default
x Array

Input image tensor (B, H, W, C)

required
factor int

Downsampling factor

required

Returns:

Type Description
Array

Downsampled image
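
A short usage sketch for the helper, assuming a constructed discriminator like the one in the example below; note the channels-last (B, H, W, C) layout stated above:

import jax.numpy as jnp

x = jnp.ones((8, 256, 256, 3))                      # (B, H, W, C)
x_half = discriminator.downsample_image(x, factor=2)
print(x_half.shape)  # (8, 128, 128, 3)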

Multi-scale PatchGAN discriminator operating at multiple resolutions.

Module: workshop.generative_models.models.gan

Parameters:

Parameter Type Default Description
input_shape tuple[int, ...] Required Input image shape (C, H, W)
hidden_dims tuple[int, ...] (64, 128, 256) Channel dimensions
num_scales int 3 Number of scales
rngs nnx.Rngs Required Random number generators

Returns: List of predictions at different scales

Example:

from workshop.generative_models.models.gan import MultiScalePatchGANDiscriminator
from flax import nnx
import jax.numpy as jnp

discriminator = MultiScalePatchGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256),
    num_scales=3,
    rngs=nnx.Rngs(params=0, dropout=1),
)

images = jnp.ones((16, 3, 256, 256))
predictions = discriminator(images, training=True)
# predictions is a list of 3 arrays at different scales
for i, pred in enumerate(predictions):
    print(f"Scale {i}: {pred.shape}")

Loss Functions¤

See Adversarial Loss Functions for detailed documentation of:

  • vanilla_generator_loss
  • vanilla_discriminator_loss
  • least_squares_generator_loss
  • least_squares_discriminator_loss
  • wasserstein_generator_loss
  • wasserstein_discriminator_loss
  • hinge_generator_loss
  • hinge_discriminator_loss
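
As a reminder of what the vanilla (non-saturating) variants compute, a generic standalone sketch in plain JAX; the actual Workshop signatures are documented on the linked page:

import jax.numpy as jnp
from jax.nn import log_sigmoid

def vanilla_discriminator_loss_sketch(real_logits, fake_logits):
    # -E[log D(x)] - E[log(1 - D(G(z)))], applying the sigmoid to raw logits
    return -jnp.mean(log_sigmoid(real_logits)) - jnp.mean(log_sigmoid(-fake_logits))

def vanilla_generator_loss_sketch(fake_logits):
    # Non-saturating generator loss: -E[log D(G(z))]
    return -jnp.mean(log_sigmoid(fake_logits))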

Summary¤

This API reference covered all GAN model classes:

  • Base Classes: Generator, Discriminator, GAN
  • DCGAN: DCGANGenerator, DCGANDiscriminator, DCGAN
  • WGAN: WGANGenerator, WGANDiscriminator, WGAN, compute_gradient_penalty
  • LSGAN: LSGANGenerator, LSGANDiscriminator, LSGAN
  • Conditional GAN: ConditionalGenerator, ConditionalDiscriminator, ConditionalGAN
  • CycleGAN: CycleGANGenerator, CycleGANDiscriminator, CycleGAN
  • PatchGAN: PatchGANDiscriminator, MultiScalePatchGANDiscriminator

See Also¤