GAN API Reference¤
Complete API reference for all GAN model classes in Workshop.
Overview¤
The GAN module provides implementations of various Generative Adversarial Network architectures:
- Base GAN: Standard generator and discriminator
- DCGAN: Deep convolutional architecture
- WGAN: Wasserstein distance with gradient penalty
- LSGAN: Least squares loss
- Conditional GAN: Class-conditioned generation
- CycleGAN: Unpaired image-to-image translation
- PatchGAN: Patch-based discrimination
Base Classes¤
Generator¤
workshop.generative_models.models.gan.Generator¤
Generator(config: GeneratorConfig, *, rngs: Rngs)
Bases: Module
Generator network for GAN.
Base generator using fully-connected (Dense) layers. For convolutional generators, use DCGANGenerator or other specialized subclasses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | GeneratorConfig | GeneratorConfig with network architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not GeneratorConfig |
Basic generator network that transforms latent vectors into data samples.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | Required | Hidden layer dimensions |
| output_shape | tuple | Required | Shape of generated samples (batch, C, H, W) |
| latent_dim | int | Required | Dimension of latent space |
| activation | str | "relu" | Activation function name |
| batch_norm | bool | True | Whether to use batch normalization |
| dropout_rate | float | 0.0 | Dropout rate |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(z, training=False)¤
Generate samples from latent vectors.
Parameters:
- z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
- training (bool): Whether in training mode (affects batch norm and dropout)
Returns:
jax.Array: Generated samples of shape (batch_size, *output_shape[1:])
Example:
```python
from workshop.generative_models.models.gan import Generator
from flax import nnx
import jax.numpy as jnp

# Create generator
generator = Generator(
    hidden_dims=[256, 512, 1024],
    output_shape=(1, 1, 28, 28),  # MNIST
    latent_dim=100,
    activation="relu",
    batch_norm=True,
    rngs=nnx.Rngs(params=0),
)

# Generate samples
z = jnp.ones((32, 100))  # Batch of latent vectors
samples = generator(z, training=False)
print(samples.shape)  # (32, 1, 28, 28)
```
Discriminator¤
workshop.generative_models.models.gan.Discriminator¤
Discriminator(config: DiscriminatorConfig, *, rngs: Rngs)
Bases: Module
Discriminator network for GAN.
Base discriminator using fully-connected (Dense) layers. For convolutional discriminators, use DCGANDiscriminator or other specialized subclasses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DiscriminatorConfig | DiscriminatorConfig with network architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not DiscriminatorConfig |
Basic discriminator network that classifies samples as real or fake.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | Required | Hidden layer dimensions |
| activation | str | "leaky_relu" | Activation function name |
| leaky_relu_slope | float | 0.2 | Negative slope for LeakyReLU |
| batch_norm | bool | False | Whether to use batch normalization |
| dropout_rate | float | 0.3 | Dropout rate |
| use_spectral_norm | bool | False | Whether to use spectral normalization |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
initialize_layers(input_shape, rngs=None)¤
Initialize layers based on input shape.
Parameters:
- input_shape (tuple): Shape of input data (batch, C, H, W)
- rngs (nnx.Rngs, optional): Random number generators
__call__(x, training=False)¤
Classify samples as real or fake.
Parameters:
- x (jax.Array): Input samples of shape (batch_size, C, H, W)
- training (bool): Whether in training mode
Returns:
jax.Array: Discrimination scores of shape (batch_size, 1), values in [0, 1]
Example:
```python
from workshop.generative_models.models.gan import Discriminator
from flax import nnx
import jax.numpy as jnp

# Create discriminator
discriminator = Discriminator(
    hidden_dims=[512, 256, 128],
    activation="leaky_relu",
    leaky_relu_slope=0.2,
    batch_norm=False,
    dropout_rate=0.3,
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Classify samples
samples = jnp.ones((32, 1, 28, 28))
scores = discriminator(samples, training=True)
print(scores.shape)  # (32, 1)
print(f"Scores range: [{scores.min():.3f}, {scores.max():.3f}]")
```
GAN¤
workshop.generative_models.models.gan.GAN¤
Bases: GenerativeModel
Generative Adversarial Network (GAN) implementation.
Base GAN using fully-connected Generator and Discriminator. For convolutional GANs, use DCGAN or other specialized subclasses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | GANConfig | GANConfig with nested GeneratorConfig and DiscriminatorConfig | required |
| rngs | Rngs | Random number generators | required |
| precision | Precision \| None | Numerical precision for computations | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None or missing 'sample' stream |
| TypeError | If config is not GANConfig |
generate¤
Generate samples from the generator.
Note: Uses stored self.rngs for sampling. RNG automatically advances each call.
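Because the model holds its own `nnx.Rngs`, repeated calls draw fresh keys from the `sample` stream. A minimal sketch of that behavior (illustrative only; the stream mechanics are standard Flax NNX):

```python
from flax import nnx

rngs = nnx.Rngs(sample=0)
k1 = rngs.sample()  # first key drawn from the "sample" stream
k2 = rngs.sample()  # the stream counter has advanced, so this key differs
```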
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| batch_size | int \| None | Alternative way to specify number of samples (for compatibility) | None |
| **kwargs | Any | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples of shape (n_samples, *output_shape[1:]) |
loss_fn¤
Compute GAN loss for training.
Note: Uses stored self.rngs for sampling. RNG automatically advances each call.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Any] | Input batch containing real data (dict with 'x' or 'data' key, or raw array) | required |
| model_outputs | dict[str, Any] | Model outputs (unused for GAN loss computation) | required |
| **kwargs | Any | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing: loss (total combined loss, generator + discriminator); generator_loss (generator loss component); discriminator_loss (discriminator loss component); real_scores_mean (mean discriminator score on real data); fake_scores_mean (mean discriminator score on fake data) |
Complete GAN model combining generator and discriminator.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| config | object | Required | Model configuration object |
| rngs | nnx.Rngs | Required | Random number generators |
| precision | jax.lax.Precision | None | Numerical precision |
Configuration Object:
The config must have the following structure:
```python
class GANConfig:
    latent_dim: int = 100                 # Latent space dimension
    loss_type: str = "vanilla"            # "vanilla", "wasserstein", "least_squares", "hinge"
    gradient_penalty_weight: float = 0.0  # Weight for gradient penalty (WGAN-GP)

    class generator:
        hidden_dims: list[int]            # Generator hidden dimensions
        output_shape: tuple               # Output shape (batch, C, H, W)
        activation: str = "relu"
        batch_norm: bool = True
        dropout_rate: float = 0.0

    class discriminator:
        hidden_dims: list[int]            # Discriminator hidden dimensions
        activation: str = "leaky_relu"
        leaky_relu_slope: float = 0.2
        batch_norm: bool = False
        dropout_rate: float = 0.3
        use_spectral_norm: bool = False
```
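To make the loss_type options concrete, here is a minimal pure-JAX sketch of the four generator objectives (illustrative only; the library's actual implementations live in the adversarial loss module documented under Loss Functions below):

```python
import jax
import jax.numpy as jnp

def vanilla_g_loss(fake_logits):
    # Non-saturating loss: -E[log D(G(z))]
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))

def wasserstein_g_loss(fake_scores):
    # -E[D(G(z))]; the critic outputs raw, unbounded scores
    return -jnp.mean(fake_scores)

def least_squares_g_loss(fake_scores, target=1.0):
    # 0.5 * E[(D(G(z)) - target)^2]
    return 0.5 * jnp.mean((fake_scores - target) ** 2)

def hinge_g_loss(fake_scores):
    # The hinge generator loss also reduces to -E[D(G(z))]
    return -jnp.mean(fake_scores)
```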
Methods:
__call__(x, rngs=None, training=False, **kwargs)¤
Forward pass through the GAN (runs discriminator only).
Parameters:
- x (jax.Array): Input data
- rngs (nnx.Rngs, optional): Random number generators
- training (bool): Whether in training mode
Returns:
dict: Dictionary with keys:
- "real_scores": Discriminator scores for real data
- "fake_scores": None (computed in loss_fn)
- "fake_samples": None (computed in loss_fn)
generate(n_samples=1, rngs=None, batch_size=None, **kwargs)¤
Generate samples from the generator.
Parameters:
- n_samples (int): Number of samples to generate
- rngs (nnx.Rngs, optional): Random number generators
- batch_size (int, optional): Alternative to n_samples
Returns:
jax.Array: Generated samples
loss_fn(batch, model_outputs, rngs=None, **kwargs)¤
Compute GAN loss for training.
Parameters:
- batch (dict or jax.Array): Input batch (real data)
- model_outputs (dict): Model outputs (unused for GAN)
- rngs (nnx.Rngs, optional): Random number generators
Returns:
dict: Dictionary with losses:
- "loss": Total loss (generator + discriminator)
- "generator_loss": Generator loss
- "discriminator_loss": Discriminator loss
- "real_scores_mean": Mean discriminator score for real samples
- "fake_scores_mean": Mean discriminator score for fake samples
Example:
```python
from workshop.generative_models.models.gan import GAN
from flax import nnx
import jax.numpy as jnp

# Create configuration
class GANConfig:
    latent_dim = 100
    loss_type = "vanilla"

    class generator:
        hidden_dims = [256, 512]
        output_shape = (1, 1, 28, 28)
        activation = "relu"
        batch_norm = True
        dropout_rate = 0.0

    class discriminator:
        hidden_dims = [512, 256]
        activation = "leaky_relu"
        leaky_relu_slope = 0.2
        batch_norm = False
        dropout_rate = 0.3
        use_spectral_norm = False

# Create GAN
gan = GAN(GANConfig(), rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate samples
samples = gan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))

# Compute loss
batch = jnp.ones((32, 1, 28, 28))
losses = gan.loss_fn(batch, None, rngs=nnx.Rngs(sample=0))
print(f"Generator Loss: {losses['generator_loss']:.4f}")
print(f"Discriminator Loss: {losses['discriminator_loss']:.4f}")
```
DCGAN¤
DCGANGenerator¤
workshop.generative_models.models.gan.DCGANGenerator¤
DCGANGenerator(config: ConvGeneratorConfig, *, rngs: Rngs)
Bases: Generator
Deep Convolutional GAN Generator.
Uses transposed convolutions for progressive upsampling from latent vector to output image. All configuration comes from ConvGeneratorConfig.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConvGeneratorConfig | ConvGeneratorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConvGeneratorConfig |
initial_linear (instance-attribute)¤
initial_linear = Linear(in_features=latent_dim, out_features=initial_features, rngs=rngs)
output_conv (instance-attribute)¤
output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)
Deep Convolutional GAN generator using transposed convolutions.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_shape | tuple[int, ...] | Required | Output image shape (C, H, W) |
| latent_dim | int | 100 | Latent space dimension |
| hidden_dims | tuple[int, ...] | (256, 128, 64, 32) | Channel dimensions per layer |
| activation | callable | jax.nn.relu | Activation function |
| batch_norm | bool | True | Use batch normalization |
| dropout_rate | float | 0.0 | Dropout rate |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(z, training=True)¤
Generate images from latent vectors.
Parameters:
- z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
- training (bool): Whether in training mode
Returns:
jax.Array: Generated images of shape (batch_size, C, H, W)
Example:
```python
from workshop.generative_models.models.gan import DCGANGenerator
from flax import nnx
import jax
import jax.numpy as jnp

generator = DCGANGenerator(
    output_shape=(3, 64, 64),  # RGB 64×64 images
    latent_dim=100,
    hidden_dims=(256, 128, 64, 32),
    activation=jax.nn.relu,
    batch_norm=True,
    rngs=nnx.Rngs(params=0),
)

# Generate samples
z = jax.random.normal(jax.random.key(0), (16, 100))
images = generator(z, training=False)
print(images.shape)  # (16, 3, 64, 64)
```
DCGANDiscriminator¤
workshop.generative_models.models.gan.DCGANDiscriminator¤
DCGANDiscriminator(config: ConvDiscriminatorConfig, *, rngs: Rngs)
Bases: Discriminator
Deep Convolutional GAN Discriminator.
Uses strided convolutions for progressive downsampling from input image to binary classification. All configuration comes from ConvDiscriminatorConfig.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConvDiscriminatorConfig | ConvDiscriminatorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConvDiscriminatorConfig |
Deep Convolutional GAN discriminator using strided convolutions.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape | tuple[int, ...] | Required | Input image shape (C, H, W) |
| hidden_dims | tuple[int, ...] | (32, 64, 128, 256) | Channel dimensions per layer |
| activation | callable | jax.nn.leaky_relu | Activation function |
| leaky_relu_slope | float | 0.2 | Negative slope for LeakyReLU |
| batch_norm | bool | False | Use batch normalization |
| dropout_rate | float | 0.3 | Dropout rate |
| use_spectral_norm | bool | True | Use spectral normalization |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(x, training=True)¤
Classify images as real or fake.
Parameters:
- x (jax.Array): Input images of shape (batch_size, C, H, W)
- training (bool): Whether in training mode
Returns:
jax.Array: Discrimination scores of shape (batch_size, 1)
Example:
```python
from workshop.generative_models.models.gan import DCGANDiscriminator
from flax import nnx
import jax  # needed for jax.nn.leaky_relu below
import jax.numpy as jnp

discriminator = DCGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(32, 64, 128, 256),
    activation=jax.nn.leaky_relu,
    leaky_relu_slope=0.2,
    batch_norm=False,
    dropout_rate=0.3,
    use_spectral_norm=True,
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Classify images
images = jnp.ones((16, 3, 64, 64))
scores = discriminator(images, training=True)
print(scores.shape)  # (16, 1)
```
DCGAN¤
workshop.generative_models.models.gan.DCGAN¤
DCGAN(config: DCGANConfig, *, rngs: Rngs)
Bases: GAN
Deep Convolutional GAN (DCGAN) model.
Uses DCGANConfig which contains ConvGeneratorConfig and ConvDiscriminatorConfig for complete architecture specification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DCGANConfig | DCGANConfig with nested ConvGeneratorConfig and ConvDiscriminatorConfig | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not DCGANConfig |
discriminator (instance-attribute)¤
discriminator = DCGANDiscriminator(config=discriminator, rngs=rngs)
Complete Deep Convolutional GAN model.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| config | DCGANConfiguration | Required | DCGAN configuration |
| rngs | nnx.Rngs | Required | Random number generators |
Configuration:
Use DCGANConfiguration from workshop.generative_models.core.configuration.gan:
```python
from workshop.generative_models.core.configuration.gan import DCGANConfiguration

config = DCGANConfiguration(
    image_size=64,                        # Image size (H=W)
    channels=3,                           # Number of channels
    latent_dim=100,                       # Latent dimension
    gen_hidden_dims=(256, 128, 64, 32),   # Generator channels
    disc_hidden_dims=(32, 64, 128, 256),  # Discriminator channels
    loss_type="vanilla",                  # Loss type
    generator_lr=0.0002,                  # Generator learning rate
    discriminator_lr=0.0002,              # Discriminator learning rate
    beta1=0.5,                            # Adam β1
    beta2=0.999,                          # Adam β2
)
```
Example:
```python
from workshop.generative_models.models.gan import DCGAN
from workshop.generative_models.core.configuration.gan import DCGANConfiguration
from flax import nnx

config = DCGANConfiguration(
    image_size=64,
    channels=3,
    latent_dim=100,
    gen_hidden_dims=(256, 128, 64, 32),
    disc_hidden_dims=(32, 64, 128, 256),
    loss_type="vanilla",
)
dcgan = DCGAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate samples
samples = dcgan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))
print(samples.shape)  # (16, 3, 64, 64)
```
WGAN¤
WGANGenerator¤
workshop.generative_models.models.gan.WGANGenerator¤
WGANGenerator(config: ConvGeneratorConfig, *, rngs: Rngs)
Bases: Generator
Wasserstein GAN Generator using convolutional architecture.
Based on the PyTorch WGAN-GP reference implementation:
- Uses ConvTranspose layers like DCGAN
- BatchNorm is typically used in WGAN generators
- Tanh activation at output
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConvGeneratorConfig | ConvGeneratorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConvGeneratorConfig |
initial_linear (instance-attribute)¤
initial_linear = Linear(in_features=latent_dim, out_features=init_h * init_w * hidden_dims_list[0], rngs=rngs)
initial_bn (instance-attribute)¤
initial_bn = BatchNorm(num_features=hidden_dims_list[0], use_running_average=batch_norm_use_running_avg, momentum=batch_norm_momentum, rngs=rngs)
output_conv (instance-attribute)¤
output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)
Wasserstein GAN generator with convolutional architecture.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_shape | tuple[int, ...] | Required | Output image shape (C, H, W) |
| latent_dim | int | 100 | Latent space dimension |
| hidden_dims | tuple[int, ...] | (1024, 512, 256) | Channel dimensions |
| activation | callable | jax.nn.relu | Activation function |
| batch_norm | bool | True | Use batch normalization |
| dropout_rate | float | 0.0 | Dropout rate |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(z, training=True)¤
Generate images from latent vectors.
Parameters:
- z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
- training (bool): Whether in training mode
Returns:
jax.Array: Generated images of shape (batch_size, C, H, W)
Example:
```python
from workshop.generative_models.models.gan import WGANGenerator
from flax import nnx
import jax

generator = WGANGenerator(
    output_shape=(3, 64, 64),
    latent_dim=100,
    hidden_dims=(1024, 512, 256),
    activation=jax.nn.relu,
    batch_norm=True,
    rngs=nnx.Rngs(params=0),
)

z = jax.random.normal(jax.random.key(0), (16, 100))
images = generator(z, training=False)
print(images.shape)  # (16, 3, 64, 64)
```
WGANDiscriminator¤
workshop.generative_models.models.gan.WGANDiscriminator¤
WGANDiscriminator(config: ConvDiscriminatorConfig, *, rngs: Rngs)
Bases: Discriminator
Wasserstein GAN Discriminator (Critic) using convolutional architecture.
Key differences from a standard discriminator:
- Uses InstanceNorm instead of BatchNorm (as per the WGAN-GP paper)
- No sigmoid activation at the end (outputs raw scores)
- LeakyReLU activation
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConvDiscriminatorConfig | ConvDiscriminatorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConvDiscriminatorConfig |
Wasserstein GAN discriminator (critic) with instance normalization.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape | tuple[int, ...] | Required | Input image shape (C, H, W) |
| hidden_dims | tuple[int, ...] | (256, 512, 1024) | Channel dimensions |
| activation | callable | jax.nn.leaky_relu | Activation function |
| leaky_relu_slope | float | 0.2 | Negative slope for LeakyReLU |
| use_instance_norm | bool | True | Use instance normalization |
| dropout_rate | float | 0.0 | Dropout rate |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(x, training=True)¤
Compute critic scores (no sigmoid activation).
Parameters:
- x (jax.Array): Input images of shape (batch_size, C, H, W)
- training (bool): Whether in training mode
Returns:
jax.Array: Raw critic scores (no sigmoid) of shape (batch_size,)
Example:
```python
from workshop.generative_models.models.gan import WGANDiscriminator
from flax import nnx
import jax  # needed for jax.nn.leaky_relu below
import jax.numpy as jnp

discriminator = WGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(256, 512, 1024),
    activation=jax.nn.leaky_relu,
    use_instance_norm=True,
    rngs=nnx.Rngs(params=0),
)

images = jnp.ones((16, 3, 64, 64))
scores = discriminator(images, training=True)
print(scores.shape)  # (16,)
# Note: No sigmoid, scores can be any real number
```
WGAN¤
workshop.generative_models.models.gan.WGAN¤
WGAN(config: WGANConfig, *, rngs: Rngs)
Bases: Module
Wasserstein GAN with Gradient Penalty (WGAN-GP) model.
Based on the PyTorch reference implementation with proper convolutional architecture.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WGANConfig | WGANConfig with nested ConvGeneratorConfig and ConvDiscriminatorConfig | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not WGANConfig |
generate¤
generate(n_samples: int = 1, *, rngs: Rngs | None = None, batch_size: int | None = None, **kwargs) -> Array
Generate samples from the generator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| rngs | Rngs \| None | Random number generator | None |
| batch_size | int \| None | Alternative way to specify number of samples (for compatibility) | None |
| **kwargs | | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
discriminator_loss¤
Compute WGAN-GP discriminator loss with gradient penalty (see Methods below).
Complete Wasserstein GAN with Gradient Penalty (WGAN-GP).
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| config | ModelConfiguration | Required | Model configuration |
| rngs | nnx.Rngs | Required | Random number generators |
| precision | jax.lax.Precision | None | Numerical precision |
Configuration:
```python
from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    input_dim=100,           # Latent dimension
    output_dim=(3, 64, 64),  # Output image shape
    hidden_dims=None,        # Use defaults
    metadata={
        "gan_params": {
            "gen_hidden_dims": (1024, 512, 256),
            "disc_hidden_dims": (256, 512, 1024),
            "gradient_penalty_weight": 10.0,  # Lambda for GP
            "critic_iterations": 5,           # Critic updates per generator update
        }
    },
)
```
Methods:
generate(n_samples=1, rngs=None, batch_size=None, **kwargs)¤
Generate samples from generator.
Parameters:
- n_samples (int): Number of samples
- rngs (nnx.Rngs, optional): Random number generators
- batch_size (int, optional): Alternative to n_samples
Returns:
jax.Array: Generated samples
discriminator_loss(real_samples, fake_samples, rngs)¤
Compute WGAN-GP discriminator loss with gradient penalty.
Parameters:
- real_samples (jax.Array): Real images
- fake_samples (jax.Array): Generated images
- rngs (nnx.Rngs): Random number generators
Returns:
jax.Array: Discriminator loss (scalar)
Loss Formula: $$ \mathcal{L}_D = \mathbb{E}[D(G(z))] - \mathbb{E}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big] $$
generator_loss(fake_samples)¤
Compute WGAN generator loss.
Parameters:
- fake_samples (jax.Array): Generated images
Returns:
jax.Array: Generator loss (scalar)
Loss Formula: $$ \mathcal{L}_G = -\mathbb{E}[D(G(z))] $$
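Read together, the two formulas amount to a few lines of array code. A minimal pure-JAX sketch of both terms given raw critic scores (illustrative only; the gradient penalty term is covered separately under compute_gradient_penalty):

```python
import jax.numpy as jnp

def wgan_discriminator_loss(real_scores, fake_scores):
    # E[D(G(z))] - E[D(x)]: minimizing pushes real scores up and fake scores down
    return jnp.mean(fake_scores) - jnp.mean(real_scores)

def wgan_generator_loss(fake_scores):
    # -E[D(G(z))]: the generator tries to raise the critic's score on fakes
    return -jnp.mean(fake_scores)
```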
Example:
```python
from workshop.generative_models.models.gan import WGAN
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx
import jax

config = ModelConfiguration(
    input_dim=100,
    output_dim=(3, 64, 64),
    metadata={
        "gan_params": {
            "gen_hidden_dims": (1024, 512, 256),
            "disc_hidden_dims": (256, 512, 1024),
            "gradient_penalty_weight": 10.0,
            "critic_iterations": 5,
        }
    },
)
wgan = WGAN(config, rngs=nnx.Rngs(params=0, sample=1))

# Generate samples
samples = wgan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))
print(samples.shape)  # (16, 3, 64, 64)

# Training step
real_samples = jax.random.normal(jax.random.key(0), (32, 3, 64, 64))
z = jax.random.normal(jax.random.key(1), (32, 100))
fake_samples = wgan.generator(z, training=True)
disc_loss = wgan.discriminator_loss(real_samples, fake_samples, rngs=nnx.Rngs(params=2))
gen_loss = wgan.generator_loss(fake_samples)
```
compute_gradient_penalty¤
workshop.generative_models.models.gan.compute_gradient_penalty¤
compute_gradient_penalty(discriminator: WGANDiscriminator, real_samples: Array, fake_samples: Array, rngs: Rngs, lambda_gp: float = 10.0) -> Array
Compute gradient penalty for WGAN-GP.
The gradient penalty enforces the Lipschitz constraint by penalizing the discriminator when the gradient norm deviates from 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| discriminator | WGANDiscriminator | The discriminator network. | required |
| real_samples | Array | Real samples from the dataset. | required |
| fake_samples | Array | Generated fake samples. | required |
| rngs | Rngs | Random number generators for interpolation. | required |
| lambda_gp | float | Gradient penalty weight. | 10.0 |
Returns:
| Type | Description |
|---|---|
| Array | Gradient penalty loss value. |
Compute gradient penalty for WGAN-GP.
Module: workshop.generative_models.models.gan
Function Signature:
```python
def compute_gradient_penalty(
    discriminator: WGANDiscriminator,
    real_samples: jax.Array,
    fake_samples: jax.Array,
    rngs: nnx.Rngs,
    lambda_gp: float = 10.0,
) -> jax.Array
```
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| discriminator | WGANDiscriminator | Required | Discriminator network |
| real_samples | jax.Array | Required | Real images |
| fake_samples | jax.Array | Required | Generated images |
| rngs | nnx.Rngs | Required | Random number generators |
| lambda_gp | float | 10.0 | Gradient penalty weight |
Returns:
jax.Array: Gradient penalty loss (scalar)
Formula: $$ \text{GP} = \lambda \, \mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big] $$
where \(\hat{x} = \epsilon x + (1-\epsilon)G(z)\) is a random interpolation.
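For reference, the computation can be expressed directly in JAX. A minimal sketch under the stated definitions (here `critic` is a hypothetical callable mapping a batch of images to scalar scores, not the library's internal code path):

```python
import jax
import jax.numpy as jnp

def gradient_penalty(critic, real, fake, key, lambda_gp=10.0):
    # Random interpolation x_hat = eps * x + (1 - eps) * G(z), one eps per sample
    eps = jax.random.uniform(key, (real.shape[0], 1, 1, 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # Per-sample gradient of the critic output w.r.t. the interpolated input
    grads = jax.vmap(jax.grad(lambda x: critic(x[None])[0].sum()))(x_hat)
    norms = jnp.sqrt(jnp.sum(grads**2, axis=(1, 2, 3)) + 1e-12)
    return lambda_gp * jnp.mean((norms - 1.0) ** 2)
```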
Example:
```python
from workshop.generative_models.models.gan import compute_gradient_penalty, WGANDiscriminator
from flax import nnx
import jax

discriminator = WGANDiscriminator(
    input_shape=(3, 64, 64),
    rngs=nnx.Rngs(params=0),
)

real_samples = jax.random.normal(jax.random.key(0), (32, 3, 64, 64))
fake_samples = jax.random.normal(jax.random.key(1), (32, 3, 64, 64))

gp = compute_gradient_penalty(
    discriminator,
    real_samples,
    fake_samples,
    rngs=nnx.Rngs(params=2),
    lambda_gp=10.0,
)
print(f"Gradient Penalty: {gp:.4f}")
```
LSGAN¤
LSGANGenerator¤
workshop.generative_models.models.gan.LSGANGenerator¤
LSGANGenerator(config: ConvGeneratorConfig, *, rngs: Rngs)
Bases: Generator
Least Squares GAN Generator using convolutional architecture.
LSGAN uses the same architecture as DCGAN but with least squares loss instead of the standard adversarial loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConvGeneratorConfig | ConvGeneratorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConvGeneratorConfig |
initial_linear (instance-attribute)¤
initial_linear = Linear(in_features=latent_dim, out_features=init_h * init_w * hidden_dims_list[0], rngs=rngs)
initial_bn (instance-attribute)¤
initial_bn = BatchNorm(num_features=hidden_dims_list[0], use_running_average=batch_norm_use_running_avg, momentum=batch_norm_momentum, rngs=rngs)
output_conv (instance-attribute)¤
output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)
Least Squares GAN generator (same architecture as DCGAN).
Module: workshop.generative_models.models.gan
Parameters: Same as DCGANGenerator
Methods: Same as DCGANGenerator
Example:
```python
from workshop.generative_models.models.gan import LSGANGenerator
from flax import nnx
import jax

generator = LSGANGenerator(
    output_shape=(3, 64, 64),
    latent_dim=100,
    hidden_dims=(512, 256, 128, 64),
    rngs=nnx.Rngs(params=0),
)

z = jax.random.normal(jax.random.key(0), (16, 100))
images = generator(z, training=False)
```
LSGANDiscriminator¤
workshop.generative_models.models.gan.LSGANDiscriminator¤
LSGANDiscriminator(config: ConvDiscriminatorConfig, *, rngs: Rngs)
Bases: Discriminator
Least Squares GAN Discriminator using convolutional architecture.
LSGAN discriminator uses the same architecture as DCGAN discriminator but with least squares loss instead of sigmoid cross-entropy loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConvDiscriminatorConfig | ConvDiscriminatorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConvDiscriminatorConfig |
output_layer (instance-attribute)¤
Least Squares GAN discriminator (no sigmoid activation).
Module: workshop.generative_models.models.gan
Parameters: Same as DCGANDiscriminator
Key Difference: Output layer has no sigmoid activation (outputs raw logits for least squares loss).
Example:
```python
from workshop.generative_models.models.gan import LSGANDiscriminator
from flax import nnx
import jax.numpy as jnp

discriminator = LSGANDiscriminator(
    input_shape=(3, 64, 64),
    hidden_dims=(64, 128, 256, 512),
    rngs=nnx.Rngs(params=0, dropout=1),
)

images = jnp.ones((16, 3, 64, 64))
scores = discriminator(images, training=True)
# Note: Scores are raw logits, not in [0, 1]
```
LSGAN¤
workshop.generative_models.models.gan.LSGAN¤
LSGAN(config: LSGANConfig, *, rngs: Rngs)
Bases: Module
Least Squares GAN implementation.
LSGAN replaces the log loss in the original GAN formulation with a least squares loss, which provides more stable training and better quality gradients for the generator.
Reference
Mao et al. "Least Squares Generative Adversarial Networks" (2017)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | LSGANConfig | LSGANConfig with nested ConvGeneratorConfig and ConvDiscriminatorConfig | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not LSGANConfig |
discriminator (instance-attribute)¤
discriminator = LSGANDiscriminator(config=disc_config, rngs=rngs)
generator_loss¤
Compute LSGAN generator loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fake_scores | Array | Discriminator scores for fake samples | required |
| target_real | float | Target value for fake samples (usually 1.0) | 1.0 |
| reduction | str | Reduction method ('mean', 'sum', 'none') | 'mean' |
Returns:
| Type | Description |
|---|---|
| Array | Generator loss |
discriminator_loss¤
discriminator_loss(real_scores: Array, fake_scores: Array, target_real: float = 1.0, target_fake: float = 0.0, reduction: str = 'mean') -> Array
Compute LSGAN discriminator loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| real_scores | Array | Discriminator scores for real samples | required |
| fake_scores | Array | Discriminator scores for fake samples | required |
| target_real | float | Target value for real samples (usually 1.0) | 1.0 |
| target_fake | float | Target value for fake samples (usually 0.0) | 0.0 |
| reduction | str | Reduction method ('mean', 'sum', 'none') | 'mean' |
Returns:
| Type | Description |
|---|---|
| Array | Discriminator loss |
training_step¤
Perform a single training step.
Note: Use model.train() for training mode and model.eval() for evaluation mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| real_images | Array | Batch of real images | required |
| latent_vectors | Array | Batch of latent vectors | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary with loss values and generated images |
Complete Least Squares GAN model.
Module: workshop.generative_models.models.gan
Parameters: Same as base GAN
Methods:
generator_loss(fake_scores, target_real=1.0, reduction="mean")¤
Compute LSGAN generator loss.
Formula: $$ \mathcal{L}_G = \frac{1}{2}\mathbb{E}[(D(G(z)) - c)^2] $$
where \(c\) is the target value for fake samples (usually 1.0).
discriminator_loss(real_scores, fake_scores, target_real=1.0, target_fake=0.0, reduction="mean")¤
Compute LSGAN discriminator loss.
Formula: $$ \mathcal{L}_D = \frac{1}{2}\mathbb{E}[(D(x) - b)^2] + \frac{1}{2}\mathbb{E}[D(G(z))^2] $$
where \(b\) is the target for real samples (usually 1.0).
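Both formulas are straightforward to write down in JAX. A minimal sketch with the mean reduction (illustrative only, not the class's internal code):

```python
import jax.numpy as jnp

def lsgan_generator_loss(fake_scores, target_real=1.0):
    # 0.5 * E[(D(G(z)) - c)^2], pushing fake scores toward the real target c
    return 0.5 * jnp.mean((fake_scores - target_real) ** 2)

def lsgan_discriminator_loss(real_scores, fake_scores, target_real=1.0, target_fake=0.0):
    # 0.5 * E[(D(x) - b)^2] + 0.5 * E[(D(G(z)) - a)^2]
    real_term = 0.5 * jnp.mean((real_scores - target_real) ** 2)
    fake_term = 0.5 * jnp.mean((fake_scores - target_fake) ** 2)
    return real_term + fake_term
```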
Example:
```python
from workshop.generative_models.models.gan import LSGAN
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx
import jax

config = ModelConfiguration(
    input_dim=100,
    output_dim=(3, 64, 64),
    hidden_dims=[512, 256, 128, 64],
)
lsgan = LSGAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate samples
samples = lsgan.generate(n_samples=16, rngs=nnx.Rngs(sample=0))

# Compute losses
real_images = jax.random.normal(jax.random.key(0), (32, 3, 64, 64))
z = jax.random.normal(jax.random.key(1), (32, 100))
fake_images = lsgan.generator(z, training=True)
real_scores = lsgan.discriminator(real_images, training=True)
fake_scores = lsgan.discriminator(fake_images, training=True)
gen_loss = lsgan.generator_loss(fake_scores)
disc_loss = lsgan.discriminator_loss(real_scores, fake_scores)
```
Conditional GAN¤
ConditionalGenerator¤
workshop.generative_models.models.gan.ConditionalGenerator¤
ConditionalGenerator(config: ConditionalGeneratorConfig, *, rngs: Rngs)
Bases: Generator
Conditional GAN Generator using convolutional architecture.
The generator is conditioned on class labels by concatenating the label embedding with the noise vector before passing through the network.
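The conditioning step itself is a simple concatenation. A minimal sketch of just that step (names here are illustrative, not the class's internal attributes):

```python
import jax.numpy as jnp

def condition_latent(z, labels_one_hot, embedding_matrix):
    # z: (batch, latent_dim); labels_one_hot: (batch, num_classes)
    # embedding_matrix: (num_classes, embedding_dim)
    label_emb = labels_one_hot @ embedding_matrix    # (batch, embedding_dim)
    return jnp.concatenate([z, label_emb], axis=-1)  # (batch, latent_dim + embedding_dim)
```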
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConditionalGeneratorConfig | ConditionalGeneratorConfig with network architecture and conditional parameters (num_classes, embedding_dim via config.conditional) | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConditionalGeneratorConfig |
label_embedding (instance-attribute)¤
label_embedding = Linear(in_features=num_classes, out_features=embedding_dim, rngs=rngs)
initial_projection (instance-attribute)¤
initial_projection = Linear(in_features=combined_input_dim, out_features=init_h * init_w * hidden_dims_list[0], rngs=rngs)
initial_bn (instance-attribute)¤
initial_bn = BatchNorm(num_features=hidden_dims_list[0], use_running_average=batch_norm_use_running_avg, momentum=batch_norm_momentum, rngs=rngs)
output_conv (instance-attribute)¤
output_conv = ConvTranspose(in_features=hidden_dims_list[-1], out_features=channels, kernel_size=kernel_size, strides=stride, padding=padding, rngs=rngs)
Conditional GAN generator that takes class labels as input.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_shape | tuple[int, ...] | Required | Output image shape (C, H, W) |
| num_classes | int | Required | Number of classes |
| latent_dim | int | 100 | Latent space dimension |
| hidden_dims | tuple[int, ...] | (512, 256, 128, 64) | Channel dimensions |
| activation | callable | jax.nn.relu | Activation function |
| batch_norm | bool | True | Use batch normalization |
| dropout_rate | float | 0.0 | Dropout rate |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(z, labels, training=True)¤
Generate samples conditioned on labels.
Parameters:
- z (jax.Array): Latent vectors of shape (batch_size, latent_dim)
- labels (jax.Array): One-hot encoded labels of shape (batch_size, num_classes)
- training (bool): Whether in training mode
Returns:
jax.Array: Generated images of shape (batch_size, C, H, W)
Example:
```python
from workshop.generative_models.models.gan import ConditionalGenerator
from flax import nnx
import jax
import jax.numpy as jnp

generator = ConditionalGenerator(
    output_shape=(1, 28, 28),  # MNIST
    num_classes=10,
    latent_dim=100,
    hidden_dims=(512, 256, 128, 64),
    rngs=nnx.Rngs(params=0),
)

# Generate specific digits
z = jax.random.normal(jax.random.key(0), (10, 100))
labels = jax.nn.one_hot(jnp.arange(10), 10)  # One of each digit
images = generator(z, labels, training=False)
print(images.shape)  # (10, 1, 28, 28)
```
ConditionalDiscriminator¤
workshop.generative_models.models.gan.ConditionalDiscriminator¤
ConditionalDiscriminator(config: ConditionalDiscriminatorConfig, *, rngs: Rngs)
Bases: Discriminator
Conditional GAN Discriminator using convolutional architecture.
The discriminator is conditioned on class labels by concatenating the label embedding with the input image before passing through the network.
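One common way to realize this image-side conditioning is to broadcast the label embedding into a spatial map and append it as extra channels; a minimal sketch of that idea (illustrative only, the class's exact mechanism is determined by its config):

```python
import jax.numpy as jnp

def condition_image(x, labels_one_hot, embedding_matrix):
    # x: (batch, C, H, W); embedding_matrix: (num_classes, embed_dim)
    b, _, h, w = x.shape
    emb = labels_one_hot @ embedding_matrix  # (batch, embed_dim)
    emb_map = jnp.broadcast_to(emb[:, :, None, None], (b, emb.shape[1], h, w))
    return jnp.concatenate([x, emb_map], axis=1)  # label info as extra channels
```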
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConditionalDiscriminatorConfig | ConditionalDiscriminatorConfig with network architecture and conditional parameters (num_classes, embedding_dim via config.conditional) | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConditionalDiscriminatorConfig |
Conditional GAN discriminator that takes class labels as input.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape | tuple[int, ...] | Required | Input image shape (C, H, W) |
| num_classes | int | Required | Number of classes |
| hidden_dims | tuple[int, ...] | (64, 128, 256, 512) | Channel dimensions |
| activation | callable | jax.nn.leaky_relu | Activation function |
| leaky_relu_slope | float | 0.2 | Negative slope for LeakyReLU |
| batch_norm | bool | False | Use batch normalization |
| dropout_rate | float | 0.0 | Dropout rate |
| rngs | nnx.Rngs | Required | Random number generators |
Methods:
__call__(x, labels, training=True)¤
Classify samples conditioned on labels.
Parameters:
- x (jax.Array): Input images of shape (batch_size, C, H, W)
- labels (jax.Array): One-hot encoded labels of shape (batch_size, num_classes)
- training (bool): Whether in training mode
Returns:
jax.Array: Discrimination scores of shape (batch_size,)
Example:
```python
from workshop.generative_models.models.gan import ConditionalDiscriminator
from flax import nnx
import jax
import jax.numpy as jnp

discriminator = ConditionalDiscriminator(
    input_shape=(1, 28, 28),
    num_classes=10,
    hidden_dims=(64, 128, 256, 512),
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Classify samples with labels
images = jax.random.normal(jax.random.key(0), (32, 1, 28, 28))
labels = jax.nn.one_hot(jnp.zeros(32, dtype=int), 10)  # All zeros
scores = discriminator(images, labels, training=True)
print(scores.shape)  # (32,)
```
ConditionalGAN¤
workshop.generative_models.models.gan.ConditionalGAN¤
ConditionalGAN(config: ConditionalGANConfig, *, rngs: Rngs)
Bases: Module
Conditional Generative Adversarial Network (CGAN).
Based on "Conditional Generative Adversarial Nets" by Mirza & Osindero (2014). The generator and discriminator are both conditioned on class labels.
Uses composition pattern: conditional parameters (num_classes, embedding_dim) are embedded in the nested ConditionalGeneratorConfig and ConditionalDiscriminatorConfig via ConditionalParams.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConditionalGANConfig | ConditionalGANConfig with nested ConditionalGeneratorConfig and ConditionalDiscriminatorConfig. All parameters are specified in the config objects. | required |
| rngs | Rngs | Random number generators. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not ConditionalGANConfig |
discriminator (instance-attribute)¤
discriminator = ConditionalDiscriminator(config=disc_config, rngs=rngs)
generate¤
generate(n_samples: int = 1, labels: Array | None = None, *, rngs: Rngs | None = None, batch_size: int | None = None, **kwargs) -> Array
Generate conditional samples from the generator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| labels | Array \| None | One-hot encoded labels of shape (n_samples, num_classes) | None |
| rngs | Rngs \| None | Random number generator | None |
| batch_size | int \| None | Alternative way to specify number of samples (for compatibility) | None |
| **kwargs | | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
discriminator_loss¤
discriminator_loss(real_samples: Array, fake_samples: Array, real_labels: Array, fake_labels: Array) -> Array
Compute conditional discriminator loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| real_samples | Array | Real samples from dataset. | required |
| fake_samples | Array | Generated fake samples. | required |
| real_labels | Array | Labels for real samples. | required |
| fake_labels | Array | Labels for fake samples. | required |
Returns:
| Type | Description |
|---|---|
| Array | Discriminator loss value. |
generator_loss¤
Compute conditional generator loss.
Complete Conditional GAN model.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| config | ModelConfiguration | Required | Model configuration |
| rngs | nnx.Rngs | Required | Random number generators |
| precision | jax.lax.Precision | None | Numerical precision |
Configuration:
```python
from workshop.generative_models.core.configuration import ModelConfiguration

config = ModelConfiguration(
    input_dim=100,           # Latent dimension
    output_dim=(1, 28, 28),  # MNIST shape
    metadata={
        "gan_params": {
            "num_classes": 10,  # Number of classes
            "gen_hidden_dims": (512, 256, 128, 64),
            "discriminator_features": [64, 128, 256, 512],
        }
    },
)
```
Methods:
generate(n_samples=1, labels=None, rngs=None, batch_size=None, **kwargs)¤
Generate conditional samples.
Parameters:
- n_samples (int): Number of samples
- labels (jax.Array, optional): One-hot encoded labels. If None, random labels are used.
- rngs (nnx.Rngs, optional): Random number generators
- batch_size (int, optional): Alternative to n_samples
Returns:
jax.Array: Generated samples
Example:
```python
from workshop.generative_models.models.gan import ConditionalGAN
from workshop.generative_models.core.configuration import ModelConfiguration
from flax import nnx
import jax
import jax.numpy as jnp

config = ModelConfiguration(
    input_dim=100,
    output_dim=(1, 28, 28),
    metadata={
        "gan_params": {
            "num_classes": 10,
            "gen_hidden_dims": (512, 256, 128, 64),
            "discriminator_features": [64, 128, 256, 512],
        }
    },
)
cgan = ConditionalGAN(config, rngs=nnx.Rngs(params=0, dropout=1, sample=2))

# Generate specific digits
labels = jax.nn.one_hot(jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 10)
samples = cgan.generate(n_samples=10, labels=labels, rngs=nnx.Rngs(sample=0))
print(samples.shape)  # (10, 1, 28, 28)
```
CycleGAN¤
CycleGANGenerator¤
workshop.generative_models.models.gan.CycleGANGenerator¤
CycleGANGenerator(config: CycleGANGeneratorConfig, *, rngs: Rngs)
Bases: Module
CycleGAN Generator for image-to-image translation.
Uses a ResNet-based architecture with reflection padding as described in the original CycleGAN paper, closely following the PyTorch reference implementation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | CycleGANGeneratorConfig | CycleGANGeneratorConfig with network architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not CycleGANGeneratorConfig |
initial_conv (instance-attribute)¤
initial_conv = Conv(in_features=input_channels, out_features=hidden_dims[0], kernel_size=(7, 7), strides=(1, 1), padding='SAME', rngs=rngs)
CycleGAN generator for image-to-image translation.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_shape | tuple[int, ...] | Required | Output image shape (C, H, W) |
| hidden_dims | tuple[int, ...] | (64, 128, 256) | Channel dimensions |
| num_residual_blocks | int | 9 | Number of ResNet blocks |
| rngs | nnx.Rngs | Required | Random number generators |
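Example (a sketch following the keyword style of the other examples on this page; the authoritative constructor takes a CycleGANGeneratorConfig, so treat these keyword arguments as illustrative assumptions):

```python
from workshop.generative_models.models.gan import CycleGANGenerator
from flax import nnx
import jax

generator = CycleGANGenerator(
    output_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256),
    num_residual_blocks=9,
    rngs=nnx.Rngs(params=0),
)

# Image-to-image translation: input and output are both images
images = jax.random.normal(jax.random.key(0), (4, 3, 256, 256))
translated = generator(images, training=False)
print(translated.shape)  # (4, 3, 256, 256)
```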
CycleGANDiscriminator¤
workshop.generative_models.models.gan.CycleGANDiscriminator¤
CycleGANDiscriminator(config: PatchGANDiscriminatorConfig, *, rngs: Rngs)
Bases: Module
CycleGAN Discriminator (PatchGAN-style).
Uses a PatchGAN discriminator that classifies patches of the input as real or fake, rather than the entire image.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | PatchGANDiscriminatorConfig | PatchGANDiscriminatorConfig with network architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not PatchGANDiscriminatorConfig |
activation (instance-attribute)¤
activation = lambda x: leaky_relu(x, negative_slope=leaky_relu_slope)
CycleGAN PatchGAN discriminator.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape | tuple[int, ...] | Required | Input image shape (C, H, W) |
| hidden_dims | tuple[int, ...] | (64, 128, 256, 512) | Channel dimensions |
| rngs | nnx.Rngs | Required | Random number generators |
CycleGAN¤
workshop.generative_models.models.gan.CycleGAN¤
CycleGAN(config: CycleGANConfig, *, rngs: Rngs)
Bases: GenerativeModel
CycleGAN for unpaired image-to-image translation.
Implements the complete CycleGAN architecture with two generators and two discriminators for bidirectional domain translation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | CycleGANConfig | CycleGANConfig with nested CycleGANGeneratorConfig and PatchGANDiscriminatorConfig objects | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not CycleGANConfig |
generator_a_to_b (instance-attribute)¤
generator_a_to_b = CycleGANGenerator(config=gen_a_to_b_config, rngs=rngs)
generator_b_to_a (instance-attribute)¤
generator_b_to_a = CycleGANGenerator(config=gen_b_to_a_config, rngs=rngs)
discriminator_a (instance-attribute)¤
discriminator_a = CycleGANDiscriminator(config=disc_a_config, rngs=rngs)
discriminator_b (instance-attribute)¤
discriminator_b = CycleGANDiscriminator(config=disc_b_config, rngs=rngs)
generate¤
generate(n_samples: int = 1, *, rngs: Rngs | None = None, batch_size: int | None = None, domain: str = 'a_to_b', input_images: Array | None = None, **kwargs) -> Array
Generate translated images.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate (ignored if input_images provided). | 1 |
| rngs | Rngs \| None | Random number generators. | None |
| batch_size | int \| None | Alternative way to specify number of samples (for compatibility). | None |
| domain | str | Translation direction ("a_to_b" or "b_to_a"). | 'a_to_b' |
| input_images | Array \| None | Input images to translate. If None, generates random input. | None |
| **kwargs | | Additional keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| Array | Translated images. |
compute_cycle_loss¤
Compute cycle consistency losses.
compute_identity_loss¤
Compute identity losses.
Identity loss encourages generators to preserve color composition when translating images that already belong to the target domain.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| real_a | Array | Real images from domain A. | required |
| real_b | Array | Real images from domain B. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Tuple of (identity_loss_a, identity_loss_b). |
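Both the cycle and identity terms are L1 reconstruction errors. A minimal pure-JAX sketch of the two losses (illustrative; g_ab and g_ba stand in for the two generators above):

```python
import jax.numpy as jnp

def cycle_loss(real_a, real_b, g_ab, g_ba, weight=10.0):
    # x -> G(x) -> F(G(x)) should come back to x, and symmetrically for y
    loss_a = jnp.mean(jnp.abs(g_ba(g_ab(real_a)) - real_a))
    loss_b = jnp.mean(jnp.abs(g_ab(g_ba(real_b)) - real_b))
    return weight * (loss_a + loss_b)

def identity_loss(real_a, real_b, g_ab, g_ba, weight=0.5):
    # Translating an image already in the target domain should be a no-op
    loss_a = jnp.mean(jnp.abs(g_ba(real_a) - real_a))  # g_ba targets domain A
    loss_b = jnp.mean(jnp.abs(g_ab(real_b) - real_b))  # g_ab targets domain B
    return weight * (loss_a + loss_b)
```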
loss_fn¤
loss_fn(batch: dict[str, Any], model_outputs: dict[str, Any], *, rngs: Rngs | None = None, **kwargs) -> dict[str, Any]
Compute total CycleGAN loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Any] | Batch containing real images from both domains. | required |
| model_outputs | dict[str, Any] | Model outputs (not used in basic implementation). | required |
| rngs | Rngs \| None | Random number generators. | None |
| **kwargs | | Additional keyword arguments. | {} |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing loss and metrics. |
Complete CycleGAN model for unpaired image-to-image translation.
Module: workshop.generative_models.models.gan
Key Features:
- Two generators: G: X → Y and F: Y → X
- Two discriminators: D_X and D_Y
- Cycle consistency loss: x → G(x) → F(G(x)) ≈ x
- Identity loss (optional): F(x) ≈ x if x ∈ X
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape_x | tuple | Required | Domain X image shape (C, H, W) |
| input_shape_y | tuple | Required | Domain Y image shape (C, H, W) |
| gen_hidden_dims | tuple | (64, 128, 256) | Generator channels |
| disc_hidden_dims | tuple | (64, 128, 256) | Discriminator channels |
| cycle_weight | float | 10.0 | Cycle consistency weight |
| identity_weight | float | 0.5 | Identity loss weight |
| rngs | nnx.Rngs | Required | Random number generators |
Example:
```python
from workshop.generative_models.models.gan import CycleGAN
from flax import nnx
import jax

cyclegan = CycleGAN(
    input_shape_x=(3, 256, 256),  # Horses
    input_shape_y=(3, 256, 256),  # Zebras
    gen_hidden_dims=(64, 128, 256),
    disc_hidden_dims=(64, 128, 256),
    cycle_weight=10.0,
    identity_weight=0.5,
    rngs=nnx.Rngs(params=0, dropout=1),
)

# Translate horse to zebra
horse_images = jax.random.normal(jax.random.key(0), (4, 3, 256, 256))
zebra_images = cyclegan.generator_g(horse_images, training=False)
print(zebra_images.shape)  # (4, 3, 256, 256)

# Translate zebra back to horse (cycle consistency)
reconstructed_horses = cyclegan.generator_f(zebra_images, training=False)
print(reconstructed_horses.shape)  # (4, 3, 256, 256)
```
PatchGAN¤
PatchGANDiscriminator¤
workshop.generative_models.models.gan.PatchGANDiscriminator¤
PatchGANDiscriminator(config: PatchGANDiscriminatorConfig, *, rngs: Rngs)
Bases: Discriminator
PatchGAN Discriminator for image-to-image translation.
The PatchGAN discriminator classifies whether N×N patches in an image are real or fake, rather than classifying the entire image. This is particularly effective for image translation tasks where local texture and structure are important.
Reference
Isola et al. "Image-to-Image Translation with Conditional Adversarial Networks" (2017)
Wang et al. "High-Resolution Image Synthesis and Semantic Manipulation" (2018)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | PatchGANDiscriminatorConfig | PatchGANDiscriminatorConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None |
| TypeError | If config is not PatchGANDiscriminatorConfig |
patchgan_activation_fn (instance-attribute)¤
patchgan_activation_fn = lambda x: leaky_relu(x, negative_slope=leaky_relu_slope)
initial_conv (instance-attribute)¤
initial_conv = Conv(in_features=channels, out_features=num_filters, kernel_size=kernel_size, strides=stride, padding='SAME', use_bias=True, rngs=rngs)
PatchGAN discriminator that outputs N×N array of patch predictions.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape | tuple[int, ...] | Required | Input image shape (C, H, W) |
| hidden_dims | tuple[int, ...] | (64, 128, 256, 512) | Channel dimensions |
| kernel_size | int | 4 | Convolution kernel size |
| stride | int | 2 | Convolution stride |
| rngs | nnx.Rngs | Required | Random number generators |
Returns: N×N array of patch classifications instead of single scalar
Example:
```python
from workshop.generative_models.models.gan import PatchGANDiscriminator
from flax import nnx
import jax.numpy as jnp

discriminator = PatchGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256, 512),
    kernel_size=4,
    stride=2,
    rngs=nnx.Rngs(params=0, dropout=1),
)

images = jnp.ones((16, 3, 256, 256))
patch_scores = discriminator(images, training=True)
print(patch_scores.shape)  # (16, H', W', 1) - array of patch predictions
```
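Because the output is a spatial map of per-patch scores rather than a single scalar, adversarial losses are simply averaged over every patch. A minimal sketch using a least-squares objective (illustrative only):

```python
import jax.numpy as jnp

def patch_lsgan_discriminator_loss(real_patch_scores, fake_patch_scores):
    # Each score map has shape (batch, H', W', 1); the mean runs over all patches
    return 0.5 * (jnp.mean((real_patch_scores - 1.0) ** 2)
                  + jnp.mean(fake_patch_scores ** 2))
```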
MultiScalePatchGANDiscriminator¤
workshop.generative_models.models.gan.MultiScalePatchGANDiscriminator¤
MultiScalePatchGANDiscriminator(config: MultiScalePatchGANConfig, *, rngs: Rngs)
Bases: Module
Multi-scale PatchGAN discriminator.
Processes images at multiple scales using several PatchGAN discriminators. This allows the discriminator to capture both fine-grained and coarse-grained features at different resolutions.
Reference
Wang et al. "High-Resolution Image Synthesis and Semantic Manipulation" (2018)
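The multi-scale idea can be pictured as an image pyramid with one PatchGAN critic per level. A pure-JAX sketch of that structure (illustrative, not the class internals):

```python
import jax

def multi_scale_scores(discriminators, x):
    # discriminators: one callable per scale; x: (batch, C, H, W)
    preds = []
    for i, disc in enumerate(discriminators):
        factor = 2**i
        b, c, h, w = x.shape
        x_scaled = jax.image.resize(
            x, (b, c, h // factor, w // factor), method="bilinear"
        )
        preds.append(disc(x_scaled))  # one patch-score map per scale
    return preds
```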
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | MultiScalePatchGANConfig | MultiScalePatchGANConfig with all architecture parameters | required |
| rngs | Rngs | Random number generators | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If rngs is None or configuration is invalid |
| TypeError | If config is not MultiScalePatchGANConfig |
downsample_image¤
Downsample an input image for processing at a coarser scale.
Multi-scale PatchGAN discriminator operating at multiple resolutions.
Module: workshop.generative_models.models.gan
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| input_shape | tuple[int, ...] | Required | Input image shape (C, H, W) |
| hidden_dims | tuple[int, ...] | (64, 128, 256) | Channel dimensions |
| num_scales | int | 3 | Number of scales |
| rngs | nnx.Rngs | Required | Random number generators |
Returns: List of predictions at different scales
Example:
```python
from workshop.generative_models.models.gan import MultiScalePatchGANDiscriminator
from flax import nnx
import jax.numpy as jnp

discriminator = MultiScalePatchGANDiscriminator(
    input_shape=(3, 256, 256),
    hidden_dims=(64, 128, 256),
    num_scales=3,
    rngs=nnx.Rngs(params=0, dropout=1),
)

images = jnp.ones((16, 3, 256, 256))
predictions = discriminator(images, training=True)
# predictions is a list of 3 arrays at different scales
for i, pred in enumerate(predictions):
    print(f"Scale {i}: {pred.shape}")
```
Loss Functions¤
See Adversarial Loss Functions for detailed documentation of:
- vanilla_generator_loss
- vanilla_discriminator_loss
- least_squares_generator_loss
- least_squares_discriminator_loss
- wasserstein_generator_loss
- wasserstein_discriminator_loss
- hinge_generator_loss
- hinge_discriminator_loss
Summary¤
This API reference covered all GAN model classes:
- Base Classes: Generator, Discriminator, GAN
- DCGAN: DCGANGenerator, DCGANDiscriminator, DCGAN
- WGAN: WGANGenerator, WGANDiscriminator, WGAN, compute_gradient_penalty
- LSGAN: LSGANGenerator, LSGANDiscriminator, LSGAN
- Conditional GAN: ConditionalGenerator, ConditionalDiscriminator, ConditionalGAN
- CycleGAN: CycleGANGenerator, CycleGANDiscriminator, CycleGAN
- PatchGAN: PatchGANDiscriminator, MultiScalePatchGANDiscriminator
See Also¤
- GAN Concepts: Theory and mathematical foundations
- GAN User Guide: Practical usage examples
- GAN MNIST Example: Complete training tutorial
- Adversarial Losses: Loss function reference