
VAE API Reference¤

Complete API documentation for Variational Autoencoder models in Workshop.

Module Overview¤

from workshop.generative_models.models.vae import (
    VAE,                    # Base VAE class
    BetaVAE,               # β-VAE with disentanglement
    BetaVAEWithCapacity,   # β-VAE with capacity control
    ConditionalVAE,        # Conditional VAE
    VQVAE,                 # Vector Quantized VAE
)

from workshop.generative_models.models.vae.encoders import (
    MLPEncoder,            # Fully-connected encoder
    CNNEncoder,            # Convolutional encoder
    ResNetEncoder,         # ResNet-based encoder
    ConditionalEncoder,    # Conditional wrapper
)

from workshop.generative_models.models.vae.decoders import (
    MLPDecoder,            # Fully-connected decoder
    CNNDecoder,            # Transposed convolutional decoder
    ResNetDecoder,         # ResNet-based decoder
    ConditionalDecoder,    # Conditional wrapper
)

Base Classes¤

VAE¤

workshop.generative_models.models.vae.base.VAE ¤

VAE(config: VAEConfig, *, rngs: Rngs, precision: Precision | None = None)

Bases: GenerativeModel

Base class for Variational Autoencoders.

This class provides a foundation for implementing various VAE models using Flax NNX. All VAE models should inherit from this class and implement the required methods.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | VAEConfig | VAEConfig with encoder, decoder, encoder_type, and kl_weight settings | required |
| rngs | Rngs | Random number generators | required |
| precision | Precision \| None | Numerical precision for computations | None |

latent_dim instance-attribute ¤

latent_dim = latent_dim

kl_weight instance-attribute ¤

kl_weight = kl_weight

encoder instance-attribute ¤

encoder = create_encoder(encoder, encoder_type, rngs=rngs)

decoder instance-attribute ¤

decoder = create_decoder(decoder, encoder_type, rngs=rngs)

encode ¤

encode(x: Array) -> tuple[Array, Array]

Encode input to the latent space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array] | Tuple of (mean, log_var) of the latent distribution |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If encoder output format is invalid |

decode ¤

decode(z: Array) -> Array

Decode latent vectors to the output space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | Array | Latent vectors | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Reconstructed outputs |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If decoder output format is invalid |

reparameterize ¤

reparameterize(mean: Array, log_var: Array) -> Array

Apply the reparameterization trick.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mean | Array | Mean vectors of the latent distribution | required |
| log_var | Array | Log variance vectors of the latent distribution | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Sampled latent vectors |

loss_fn ¤

loss_fn(params: dict | None = None, batch: dict | None = None, rng: Array | None = None, x: Array | None = None, outputs: dict[str, Array] | None = None, beta: float | None = None, reconstruction_loss_fn: Callable | None = None, **kwargs) -> dict[str, Array]

Calculate loss for VAE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | dict \| None | Model parameters (optional, for compatibility with Trainer) | None |
| batch | dict \| None | Input batch (optional, for compatibility with Trainer) | None |
| rng | Array \| None | Random number generator (optional, for compatibility with Trainer) | None |
| x | Array \| None | Input data (if not provided in batch) | None |
| outputs | dict[str, Array] \| None | Dictionary of model outputs from forward pass | None |
| beta | float \| None | Weight for KL divergence term | None |
| reconstruction_loss_fn | Callable \| None | Optional custom reconstruction loss function | None |
| **kwargs |  | Additional arguments | {} |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Array] | Dictionary of loss components |

sample ¤

sample(n_samples: int = 1, *, temperature: float = 1.0) -> Array

Sample from the prior distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_samples | int | Number of samples to generate. Must be a Python int (static value). When JIT-compiling, mark this as a static argument. | 1 |
| temperature | float | Scaling factor for the standard deviation (higher = more diverse) | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Array | Generated samples |

Note

When JIT-compiling functions that call this method, mark n_samples as static, for example @nnx.jit(static_argnums=(1,)) or @jax.jit(static_argnums=(1,)).
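
For concreteness, a minimal sketch of this pattern (the sample_jit wrapper and the vae instance are illustrative assumptions, not part of the API):

from flax import nnx

# Hypothetical wrapper: n_samples must be static because it determines
# the output shape. Counting the model itself at index 0, n_samples is
# the argument at positional index 1.
@nnx.jit(static_argnums=(1,))
def sample_jit(model, n_samples: int, temperature: float):
    return model.sample(n_samples=n_samples, temperature=temperature)

samples = sample_jit(vae, 16, 1.0)  # retraces only when n_samples changes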

reconstruct ¤

reconstruct(x: Array, *, deterministic: bool = False) -> Array

Reconstruct inputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data | required |
| deterministic | bool | If True, use mean of the latent distribution instead of sampling | False |

Returns:

| Type | Description |
| --- | --- |
| Array | Reconstructed outputs |

generate ¤

generate(n_samples: int = 1, *, temperature: float = 1.0, **kwargs) -> Array

Generate samples from the model.

This is an alias for the sample method to maintain consistency with the GenerativeModel interface.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Scaling factor for the standard deviation (higher = more diverse) | 1.0 |
| **kwargs |  | Additional arguments (unused, for compatibility) | {} |

Returns:

| Type | Description |
| --- | --- |
| Array | Generated samples |

interpolate ¤

interpolate(x1: Array, x2: Array, steps: int = 10) -> Array

Interpolate between two inputs in latent space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x1 | Array | First input | required |
| x2 | Array | Second input | required |
| steps | int | Number of interpolation steps (including endpoints). Must be a Python int (static value) when JIT-compiling. | 10 |

Returns:

| Type | Description |
| --- | --- |
| Array | Interpolated outputs |

Note

When JIT-compiling functions that call this method, mark steps as static, for example @nnx.jit(static_argnums=(3,)); counting self at index 0, steps is the argument at positional index 3.

latent_traversal ¤

latent_traversal(x: Array, dim: int, range_vals: tuple[float, float] = (-3.0, 3.0), steps: int = 10) -> Array

Traverse a single dimension of the latent space.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data | required |
| dim | int | Dimension to traverse | required |
| range_vals | tuple[float, float] | Range of values for traversal | (-3.0, 3.0) |
| steps | int | Number of steps in the traversal | 10 |

Returns:

| Type | Description |
| --- | --- |
| Array | Decoded outputs from the traversal |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If dim is out of range |

The base VAE class implementing standard Variational Autoencoder functionality.

Class Definition¤

class VAE(GenerativeModel):
    """Base class for Variational Autoencoders."""

    def __init__(
        self,
        encoder: nnx.Module,
        decoder: nnx.Module,
        latent_dim: int,
        *,
        rngs: nnx.Rngs,
        kl_weight: float = 1.0,
        precision: jax.lax.Precision | None = None,
    ) -> None:
        """Initialize a VAE."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder | nnx.Module | Required | Encoder network mapping inputs to latent distributions |
| decoder | nnx.Module | Required | Decoder network mapping latent codes to reconstructions |
| latent_dim | int | Required | Dimensionality of the latent space (must be positive) |
| rngs | nnx.Rngs | Required | Random number generators for initialization and sampling |
| kl_weight | float | 1.0 | Weight for KL divergence term (β parameter) |
| precision | jax.lax.Precision \| None | None | Numerical precision for computations |

Methods¤

encode¤
def encode(
    self,
    x: jax.Array,
    *,
    rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, jax.Array]:
    """Encode input to latent distribution parameters."""

Parameters:

  • x (Array): Input data with shape (batch_size, *input_shape)
  • rngs (Rngs, optional): Random number generators

Returns:

  • tuple[Array, Array]: Mean and log-variance of latent distribution
      • mean: Shape (batch_size, latent_dim)
      • log_var: Shape (batch_size, latent_dim)

Raises:

  • ValueError: If encoder output format is invalid

Example:

mean, log_var = vae.encode(x, rngs=rngs)
print(f"Mean shape: {mean.shape}")        # (32, 20)
print(f"Log-var shape: {log_var.shape}")  # (32, 20)
decode¤
def decode(
    self,
    z: jax.Array,
    *,
    rngs: nnx.Rngs | None = None
) -> jax.Array:
    """Decode latent vectors to reconstructions."""

Parameters:

  • z (Array): Latent vectors with shape (batch_size, latent_dim)
  • rngs (Rngs, optional): Random number generators

Returns:

  • Array: Reconstructed outputs with shape (batch_size, *output_shape)

Example:

z = jax.random.normal(key, (32, 20))
reconstructed = vae.decode(z, rngs=rngs)
print(f"Reconstruction shape: {reconstructed.shape}")  # (32, 784)
reparameterize¤
@nnx.jit
def reparameterize(
    self,
    mean: jax.Array,
    log_var: jax.Array,
    *,
    rngs: nnx.Rngs | None = None
) -> jax.Array:
    """Apply the reparameterization trick."""

Parameters:

  • mean (Array): Mean vectors with shape (batch_size, latent_dim)
  • log_var (Array): Log-variance vectors with shape (batch_size, latent_dim)
  • rngs (Rngs, optional): Random number generators

Returns:

  • Array: Sampled latent vectors with shape (batch_size, latent_dim)

Details:

Implements the reparameterization trick: \(z = \mu + \sigma \odot \epsilon\) where \(\epsilon \sim \mathcal{N}(0, I)\)

Includes numerical stability via log-variance clipping to [-20, 20].
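
Equivalently, a hand-rolled plain-JAX sketch of the same computation (illustrative, not the class's exact implementation):

import jax
import jax.numpy as jnp

def reparameterize_sketch(key, mean, log_var):
    # Clip log-variance to [-20, 20] for numerical stability, as the class does.
    log_var = jnp.clip(log_var, -20.0, 20.0)
    std = jnp.exp(0.5 * log_var)              # sigma = exp(log_var / 2)
    eps = jax.random.normal(key, mean.shape)  # eps ~ N(0, I)
    return mean + std * eps                   # z = mu + sigma * eps

z = reparameterize_sketch(jax.random.PRNGKey(0), jnp.zeros((4, 8)), jnp.zeros((4, 8)))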

Example:

mean, log_var = vae.encode(x, rngs=rngs)
z = vae.reparameterize(mean, log_var, rngs=rngs)
__call__¤
def __call__(
    self,
    x: jax.Array,
    *,
    rngs: nnx.Rngs | None = None
) -> dict[str, Any]:
    """Forward pass through the VAE."""

Parameters:

  • x (Array): Input data with shape (batch_size, *input_shape)
  • rngs (Rngs, optional): Random number generators

Returns:

  • dict: Dictionary containing:
      • reconstructed (Array): Reconstructed outputs
      • reconstruction (Array): Alias for compatibility
      • mean (Array): Latent mean vectors
      • log_var (Array): Latent log-variance vectors
      • logvar (Array): Alias for compatibility
      • z (Array): Sampled latent vectors

Example:

outputs = vae(x, rngs=rngs)
print(outputs.keys())
# dict_keys(['reconstructed', 'reconstruction', 'mean', 'log_var', 'logvar', 'z'])
loss_fn¤
def loss_fn(
    self,
    params: dict | None = None,
    batch: dict | None = None,
    rng: jax.Array | None = None,
    x: jax.Array | None = None,
    outputs: dict[str, jax.Array] | None = None,
    beta: float | None = None,
    reconstruction_loss_fn: Callable | None = None,
    **kwargs,
) -> dict[str, jax.Array]:
    """Calculate VAE loss (ELBO)."""

Parameters:

  • params (dict, optional): Model parameters (for Trainer compatibility)
  • batch (dict, optional): Input batch (for Trainer compatibility)
  • rng (Array, optional): Random number generator
  • x (Array, optional): Input data if not in batch
  • outputs (dict, optional): Pre-computed model outputs
  • beta (float, optional): KL divergence weight override
  • reconstruction_loss_fn (Callable, optional): Custom reconstruction loss
  • **kwargs: Additional arguments

Returns:

  • dict: Dictionary containing:
      • reconstruction_loss (Array): Reconstruction error
      • recon_loss (Array): Alias for compatibility
      • kl_loss (Array): KL divergence
      • total_loss (Array): Combined loss
      • loss (Array): Alias for compatibility

Loss Formula:

\[ \mathcal{L} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \beta \cdot D_{KL}(q(z|x) \| p(z)) \]
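
The KL term has a closed form for a diagonal Gaussian posterior against a standard normal prior. A minimal sketch of the whole loss, assuming an MSE reconstruction term (illustrative, not the method's exact implementation):

import jax.numpy as jnp

def vae_loss_sketch(x, outputs, beta=1.0):
    recon = outputs["reconstructed"]
    mean, log_var = outputs["mean"], outputs["log_var"]
    # Reconstruction term: per-example sum of squared errors, batch-averaged.
    reconstruction_loss = jnp.mean(jnp.sum((x - recon) ** 2, axis=-1))
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl_loss = jnp.mean(-0.5 * jnp.sum(1.0 + log_var - mean**2 - jnp.exp(log_var), axis=-1))
    return {"reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
            "total_loss": reconstruction_loss + beta * kl_loss}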

Example:

outputs = vae(x, rngs=rngs)
losses = vae.loss_fn(x=x, outputs=outputs)

print(f"Reconstruction: {losses['reconstruction_loss']:.4f}")
print(f"KL Divergence: {losses['kl_loss']:.4f}")
print(f"Total Loss: {losses['total_loss']:.4f}")
sample¤
def sample(
    self,
    n_samples: int = 1,
    *,
    temperature: float = 1.0,
    rngs: nnx.Rngs | None = None
) -> jax.Array:
    """Sample from the prior distribution."""

Parameters:

  • n_samples (int): Number of samples to generate
  • temperature (float): Scaling factor for sampling diversity (higher = more diverse)
  • rngs (Rngs, optional): Random number generators

Returns:

  • Array: Generated samples with shape (n_samples, *output_shape)

Example:

# Generate 16 samples
samples = vae.sample(n_samples=16, temperature=1.0, rngs=rngs)

# More diverse samples
hot_samples = vae.sample(n_samples=16, temperature=2.0, rngs=rngs)
generate¤
def generate(
    self,
    n_samples: int = 1,
    *,
    temperature: float = 1.0,
    rngs: nnx.Rngs | None = None,
    **kwargs,
) -> jax.Array:
    """Generate samples (alias for sample)."""

Alias for sample() to maintain consistency with GenerativeModel interface.

reconstruct¤
def reconstruct(
    self,
    x: jax.Array,
    *,
    deterministic: bool = False,
    rngs: nnx.Rngs | None = None
) -> jax.Array:
    """Reconstruct inputs."""

Parameters:

  • x (Array): Input data
  • deterministic (bool): If True, use mean instead of sampling
  • rngs (Rngs, optional): Random number generators

Returns:

  • Array: Reconstructed outputs

Example:

# Stochastic reconstruction
recon = vae.reconstruct(x, deterministic=False, rngs=rngs)

# Deterministic reconstruction (use latent mean)
det_recon = vae.reconstruct(x, deterministic=True, rngs=rngs)
interpolate¤
def interpolate(
    self,
    x1: jax.Array,
    x2: jax.Array,
    steps: int = 10,
    *,
    rngs: nnx.Rngs | None = None,
) -> jax.Array:
    """Interpolate between two inputs in latent space."""

Parameters:

  • x1 (Array): First input
  • x2 (Array): Second input
  • steps (int): Number of interpolation steps (including endpoints)
  • rngs (Rngs, optional): Random number generators

Returns:

  • Array: Interpolated outputs with shape (steps, *output_shape)

Example:

x1 = test_images[0]
x2 = test_images[1]
interpolation = vae.interpolate(x1, x2, steps=10, rngs=rngs)
latent_traversal¤
def latent_traversal(
    self,
    x: jax.Array,
    dim: int,
    range_vals: tuple[float, float] = (-3.0, 3.0),
    steps: int = 10,
    *,
    rngs: nnx.Rngs | None = None,
) -> jax.Array:
    """Traverse a single latent dimension."""

Parameters:

  • x (Array): Input data
  • dim (int): Dimension to traverse (0 to latent_dim-1)
  • range_vals (tuple): Range of values for traversal
  • steps (int): Number of traversal steps
  • rngs (Rngs, optional): Random number generators

Returns:

  • Array: Decoded outputs from traversal with shape (steps, *output_shape)

Raises:

  • ValueError: If dimension is out of range

Example:

# Traverse dimension 5
traversal = vae.latent_traversal(
    x=test_image,
    dim=5,
    range_vals=(-3.0, 3.0),
    steps=15,
    rngs=rngs,
)

VAE Variants¤

BetaVAE¤

workshop.generative_models.models.vae.beta_vae.BetaVAE ¤

BetaVAE(config: BetaVAEConfig, *, rngs: Rngs)

Bases: VAE

Beta Variational Autoencoder implementation.

Beta-VAE is a VAE variant that scales the KL divergence term of the loss by a hyperparameter beta. This gives finer control over how disentangled the learned latent representations become.

By setting beta > 1, the model is encouraged to learn more disentangled representations, but potentially at the cost of reconstruction quality.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | BetaVAEConfig | BetaVAEConfig with encoder, decoder, encoder_type, and beta settings | required |
| rngs | Rngs | Random number generator for initialization | required |

beta_default instance-attribute ¤

beta_default = beta_default

beta_warmup_steps instance-attribute ¤

beta_warmup_steps = beta_warmup_steps

reconstruction_loss_type instance-attribute ¤

reconstruction_loss_type = reconstruction_loss_type

loss_fn ¤

loss_fn(params: dict | None = None, batch: dict | None = None, rng: Array | None = None, x: Array | None = None, outputs: dict[str, Array] | None = None, beta: float | None = None, **kwargs) -> dict[str, Array]

Calculate loss for BetaVAE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | dict \| None | Model parameters (optional, for compatibility with Trainer) | None |
| batch | dict \| None | Input batch (optional, for compatibility with Trainer) | None |
| rng | Array \| None | Random number generator (optional, for Trainer compatibility) | None |
| x | Array \| None | Input data (if not provided in batch) | None |
| outputs | dict[str, Array] \| None | Dictionary of model outputs from forward pass | None |
| beta | float \| None | Weight for KL divergence term. If None, uses model's beta_default with optional warmup. | None |
| **kwargs |  | Additional arguments including current training step | {} |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Array] | Dictionary of loss components |

β-VAE for learning disentangled representations.

Class Definition¤

class BetaVAE(VAE):
    """Beta Variational Autoencoder for disentanglement."""

    def __init__(
        self,
        encoder: nnx.Module,
        decoder: nnx.Module,
        latent_dim: int,
        beta_default: float = 1.0,
        beta_warmup_steps: int = 0,
        reconstruction_loss_type: str = "mse",
        *,
        rngs: nnx.Rngs,
    ) -> None:
        """Initialize a BetaVAE."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder | nnx.Module | Required | Encoder network |
| decoder | nnx.Module | Required | Decoder network |
| latent_dim | int | Required | Latent space dimension |
| beta_default | float | 1.0 | Default β value for KL weighting |
| beta_warmup_steps | int | 0 | Steps for β annealing (0 = no annealing) |
| reconstruction_loss_type | str | "mse" | Loss type: "mse" or "bce" |
| rngs | nnx.Rngs | Required | Random number generators |

Key Differences from VAE¤

Modified Loss Function:

\[ \mathcal{L}_\beta = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \beta \cdot D_{KL}(q(z|x) \| p(z)) \]

Beta Annealing:

When beta_warmup_steps > 0, β increases linearly from 0 to beta_default:

\[ \beta(t) = \min\left(\beta_{\text{default}}, \beta_{\text{default}} \cdot \frac{t}{T_{\text{warmup}}}\right) \]
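
A minimal sketch of this schedule (hypothetical helper, shown for clarity):

def beta_at(step: int, beta_default: float, warmup_steps: int) -> float:
    # Linear warmup from 0 to beta_default over warmup_steps, then constant.
    if warmup_steps <= 0:
        return beta_default
    return min(beta_default, beta_default * step / warmup_steps)

beta_at(5000, 4.0, 10000)   # 2.0, halfway through the warmup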

Example¤

beta_vae = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,              # Higher β encourages disentanglement
    beta_warmup_steps=10000,       # Gradually increase β
    reconstruction_loss_type="mse",
    rngs=rngs,
)

# Training step with beta annealing
for step in range(num_steps):
    outputs = beta_vae(batch, rngs=rngs)
    losses = beta_vae.loss_fn(x=batch, outputs=outputs, step=step)
    # losses['beta'] contains current β value

BetaVAEWithCapacity¤

workshop.generative_models.models.vae.beta_vae.BetaVAEWithCapacity ¤

BetaVAEWithCapacity(config: BetaVAEWithCapacityConfig, *, rngs: Rngs)

Bases: BetaVAE

Beta-VAE with Burgess et al. capacity control.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | BetaVAEWithCapacityConfig | BetaVAEWithCapacityConfig with encoder, decoder, beta, and capacity settings | required |
| rngs | Rngs | Random number generator for initialization | required |

use_capacity_control instance-attribute ¤

use_capacity_control = use_capacity_control

capacity_max instance-attribute ¤

capacity_max = capacity_max

capacity_num_iter instance-attribute ¤

capacity_num_iter = capacity_num_iter

gamma instance-attribute ¤

gamma = gamma

loss_fn ¤

loss_fn(params: dict | None = None, batch: dict | None = None, rng: Array | None = None, x: Array | None = None, outputs: dict[str, Array] | None = None, beta: float | None = None, step: int = 0, **kwargs) -> dict[str, Array]

Calculate loss with optional capacity control.

β-VAE with Burgess et al. capacity control mechanism.

Class Definition¤

class BetaVAEWithCapacity(BetaVAE):
    """β-VAE with capacity control."""

    def __init__(
        self,
        encoder: nnx.Module,
        decoder: nnx.Module,
        latent_dim: int,
        beta_default: float = 1.0,
        beta_warmup_steps: int = 0,
        reconstruction_loss_type: str = "mse",
        use_capacity_control: bool = False,
        capacity_max: float = 25.0,
        capacity_num_iter: int = 25000,
        gamma: float = 1000.0,
        *,
        rngs: nnx.Rngs,
    ) -> None:
        """Initialize BetaVAE with capacity control."""

Additional Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| use_capacity_control | bool | False | Enable capacity control |
| capacity_max | float | 25.0 | Maximum capacity in nats |
| capacity_num_iter | int | 25000 | Steps to reach max capacity |
| gamma | float | 1000.0 | Weight for capacity loss |

Capacity Loss¤

\[ \mathcal{L}_C = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \gamma \cdot |D_{KL}(q(z|x) \| p(z)) - C(t)| \]

Where capacity \(C(t)\) increases linearly:

\[ C(t) = \min\left(C_{\max}, C_{\max} \cdot \frac{t}{T_{\text{capacity}}}\right) \]
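
A minimal sketch of the capacity term (hypothetical helper built from the documented parameters):

import jax.numpy as jnp

def capacity_term_sketch(kl, step, capacity_max=25.0, capacity_num_iter=25000, gamma=1000.0):
    # Target capacity C(t) grows linearly to capacity_max, then saturates.
    capacity = jnp.minimum(capacity_max, capacity_max * step / capacity_num_iter)
    # Penalize the absolute deviation of the KL from the target capacity.
    return gamma * jnp.abs(kl - capacity)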

Example¤

capacity_vae = BetaVAEWithCapacity(
    encoder=encoder,
    decoder=decoder,
    latent_dim=32,
    beta_default=4.0,
    use_capacity_control=True,
    capacity_max=25.0,
    capacity_num_iter=25000,
    gamma=1000.0,
    rngs=rngs,
)

# Loss includes capacity terms
losses = capacity_vae.loss_fn(x=batch, outputs=outputs, step=step)
print(losses['current_capacity'])    # Current C value
print(losses['capacity_loss'])       # γ * |KL - C|
print(losses['kl_capacity_diff'])    # KL - C

ConditionalVAE¤

workshop.generative_models.models.vae.conditional.ConditionalVAE ¤

ConditionalVAE(config: ConditionalVAEConfig, *, rngs: Rngs, precision: Precision | None = None)

Bases: VAE

Conditional Variational Autoencoder implementation.

Extends the base VAE with conditioning capabilities by incorporating additional information at both encoding and decoding steps. This follows the standard CVAE pattern using one-hot concatenation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | ConditionalVAEConfig | ConditionalVAEConfig with encoder, decoder, encoder_type, conditioning settings | required |
| rngs | Rngs | Random number generators | required |
| precision | Precision \| None | Numerical precision for computations | None |

rngs instance-attribute ¤

rngs = rngs

precision instance-attribute ¤

precision = precision

latent_dim instance-attribute ¤

latent_dim = latent_dim

kl_weight instance-attribute ¤

kl_weight = kl_weight

condition_dim instance-attribute ¤

condition_dim = condition_dim

condition_type instance-attribute ¤

condition_type = condition_type

encoder instance-attribute ¤

encoder = create_encoder(encoder, encoder_type, conditional=True, num_classes=condition_dim, rngs=rngs)

decoder instance-attribute ¤

decoder = create_decoder(decoder, encoder_type, conditional=True, num_classes=condition_dim, rngs=rngs)

encode ¤

encode(x: Array, y: Array | None = None) -> tuple[Array, Array]

Encode input to the latent space with conditioning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data | required |
| y | Array \| None | Conditioning information (optional) | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array] | Tuple of (mean, log_var) of the latent distribution |

decode ¤

decode(z: Array, y: Array | None = None) -> Array

Decode latent vectors with conditioning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | Array | Latent vectors | required |
| y | Array \| None | Conditioning information (optional) | None |

Returns:

| Type | Description |
| --- | --- |
| Array | Reconstructed output |

sample ¤

sample(n_samples: int = 1, *, temperature: float = 1.0, y: Array | None = None, **kwargs) -> Array

Sample from the model with conditioning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness | 1.0 |
| y | Array \| None | Conditioning information (optional) | None |
| **kwargs |  | Additional arguments for compatibility | {} |

Returns:

| Type | Description |
| --- | --- |
| Array | Generated samples |

generate ¤

generate(n_samples: int = 1, *, temperature: float = 1.0, y: Array | None = None, **kwargs) -> Array

Generate samples from the model with conditioning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness | 1.0 |
| y | Array \| None | Conditioning information (optional) | None |
| **kwargs |  | Additional arguments | {} |

Returns:

| Type | Description |
| --- | --- |
| Array | Generated samples |

reconstruct ¤

reconstruct(x: Array, *, y: Array | None = None, deterministic: bool = False) -> Array

Reconstruct inputs with conditioning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data | required |
| y | Array \| None | Conditioning information (optional) | None |
| deterministic | bool | If True, use mean instead of sampling | False |

Returns:

| Type | Description |
| --- | --- |
| Array | Reconstructed outputs |

Conditional VAE for class-conditional generation.

Class Definition¤

class ConditionalVAE(VAE):
    """Conditional Variational Autoencoder."""

    def __init__(
        self,
        encoder: nnx.Module,
        decoder: nnx.Module,
        latent_dim: int,
        condition_dim: int = 10,
        condition_type: str = "concat",
        *,
        rngs: nnx.Rngs | None = None,
    ):
        """Initialize Conditional VAE."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder | nnx.Module | Required | Conditional encoder |
| decoder | nnx.Module | Required | Conditional decoder |
| latent_dim | int | Required | Latent dimension |
| condition_dim | int | 10 | Conditioning dimension |
| condition_type | str | "concat" | Conditioning strategy |
| rngs | nnx.Rngs | None | Random number generators |

Modified Methods¤

__call__¤
def __call__(
    self,
    x: jax.Array,
    y: jax.Array | None = None,
    *,
    rngs: nnx.Rngs | None = None,
) -> dict[str, Any]:
    """Forward pass with conditioning."""
encode¤
def encode(
    self,
    x: jax.Array,
    y: jax.Array | None = None,
    *,
    rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, jax.Array]:
    """Encode with conditioning."""
decode¤
def decode(
    self,
    z: jax.Array,
    y: jax.Array | None = None,
    *,
    rngs: nnx.Rngs | None = None
) -> jax.Array:
    """Decode with conditioning."""
sample¤
def sample(
    self,
    n_samples: int = 1,
    *,
    temperature: float = 1.0,
    y: jax.Array | None = None,
    rngs: nnx.Rngs | None = None,
) -> jax.Array:
    """Sample with conditioning."""

Example¤

# Create conditional encoder/decoder
conditional_encoder = ConditionalEncoder(
    encoder=base_encoder,
    num_classes=10,
    embed_dim=128,
    rngs=rngs,
)

conditional_decoder = ConditionalDecoder(
    decoder=base_decoder,
    num_classes=10,
    embed_dim=128,
    rngs=rngs,
)

cvae = ConditionalVAE(
    encoder=conditional_encoder,
    decoder=conditional_decoder,
    latent_dim=32,
    condition_dim=10,
    rngs=rngs,
)

# Forward pass with class labels
labels = jnp.array([0, 1, 2, 3, 4])
outputs = cvae(x, y=labels, rngs=rngs)

# Generate specific classes
target_labels = jnp.array([5, 5, 5, 5])
samples = cvae.sample(n_samples=4, y=target_labels, rngs=rngs)

VQVAE¤

workshop.generative_models.models.vae.vq_vae.VQVAE ¤

VQVAE(config: VQVAEConfig, *, rngs: Rngs)

Bases: VAE

Vector Quantized Variational Autoencoder implementation.

Extends the base VAE with discrete latent variables using a codebook.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | VQVAEConfig | VQVAEConfig with encoder, decoder, encoder_type, and VQ settings | required |
| rngs | Rngs | Random number generators for initialization | required |

num_embeddings instance-attribute ¤

num_embeddings = num_embeddings

embedding_dim instance-attribute ¤

embedding_dim = embedding_dim

commitment_cost instance-attribute ¤

commitment_cost = commitment_cost

embeddings instance-attribute ¤

embeddings = Embed(num_embeddings=num_embeddings, features=embedding_dim, rngs=rngs)

quantize ¤

quantize(encoding: Array) -> tuple[Array, dict[str, Array]]

Quantize the input using the codebook.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| encoding | Array | Continuous encoding from the encoder | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, dict[str, Array]] | Tuple of (quantized encoding, auxiliary dict containing losses and indices) |

encode ¤

encode(x: Array) -> tuple[Array, Array]

Encode input to continuous latent representation (before quantization).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data, shape [batch_size, ...] | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array] | Tuple of (mean, log_var) for VAE interface compatibility. For VQ-VAE, log_var is zeros since encoding is deterministic. |

decode ¤

decode(z: Array) -> Array

Decode latent representation to reconstruction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | Array | Quantized representation | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Reconstructed output |

loss_fn ¤

loss_fn(x: Array, outputs: dict[str, Array], **kwargs: Any) -> dict[str, Array]

Compute the VQ-VAE loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input data | required |
| outputs | dict[str, Array] | Dictionary of model outputs from forward pass | required |
| **kwargs | Any | Additional arguments | {} |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Array] | Dictionary of loss components |

sample ¤

sample(n_samples: int = 1, temperature: float = 1.0, **kwargs) -> Array

Sample from the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness (higher = more random) | 1.0 |
| **kwargs |  | Additional keyword arguments for compatibility | {} |

Returns:

| Type | Description |
| --- | --- |
| Array | Generated samples |

generate ¤

generate(n_samples: int = 1, *, temperature: float = 1.0, **kwargs) -> Array

Generate samples from the model.

Implements the required method from GenerativeModel base class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness | 1.0 |
| **kwargs |  | Additional keyword arguments | {} |

Returns:

| Type | Description |
| --- | --- |
| Array | Generated samples |

Vector Quantized VAE with discrete latent representations.

Class Definition¤

class VQVAE(VAE):
    """Vector Quantized Variational Autoencoder."""

    def __init__(
        self,
        encoder: nnx.Module,
        decoder: nnx.Module,
        latent_dim: int,
        num_embeddings: int = 512,
        embedding_dim: int = 64,
        commitment_cost: float = 0.25,
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize VQ-VAE."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder | nnx.Module | Required | Encoder network |
| decoder | nnx.Module | Required | Decoder network |
| latent_dim | int | Required | Latent dimension |
| num_embeddings | int | 512 | Codebook size (number of embeddings) |
| embedding_dim | int | 64 | Dimension of each embedding |
| commitment_cost | float | 0.25 | Weight for commitment loss |
| rngs | nnx.Rngs | Required | Random number generators |

Key Method: quantize¤

def quantize(
    self,
    encoding: jax.Array,
    *,
    rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, dict[str, Any]]:
    """Quantize continuous encoding using codebook."""

Parameters:

  • encoding (Array): Continuous encoding from encoder
  • rngs (Rngs, optional): Random number generators

Returns:

  • tuple: (quantized encoding, auxiliary info)
      • quantized (Array): Discrete quantized vectors
      • aux (dict): Contains commitment_loss, codebook_loss, encoding_indices

Quantization Process:

  1. Find nearest codebook embedding for each encoding vector
  2. Replace encoding with codebook embedding
  3. Use straight-through estimator for gradients (see the sketch below)
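
A minimal sketch of this procedure for flat (batch, d) encodings (illustrative; the class's actual quantize may differ in shape handling):

import jax
import jax.numpy as jnp

def quantize_sketch(encoding, codebook):
    # encoding: (batch, d); codebook: (K, d)
    # 1. Nearest codebook entry by squared Euclidean distance.
    dists = jnp.sum((encoding[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    indices = jnp.argmin(dists, axis=-1)
    quantized = codebook[indices]
    # 2. Codebook loss pulls embeddings toward sg[z_e]; commitment loss
    #    pulls the encoder output toward sg[e].
    codebook_loss = jnp.mean((jax.lax.stop_gradient(encoding) - quantized) ** 2)
    commitment_loss = jnp.mean((encoding - jax.lax.stop_gradient(quantized)) ** 2)
    # 3. Straight-through estimator: forward pass uses the quantized values,
    #    backward pass copies gradients straight through to the encoder.
    quantized_st = encoding + jax.lax.stop_gradient(quantized - encoding)
    return quantized_st, {"codebook_loss": codebook_loss,
                          "commitment_loss": commitment_loss,
                          "encoding_indices": indices}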

Loss Function¤

VQ-VAE uses a specialized loss:

\[ \mathcal{L}_{VQ} = \|x - \hat{x}\|^2 + \|sg[z_e] - e\|^2 + \beta \|z_e - sg[e]\|^2 \]

Where:

  • First term: Reconstruction loss
  • Second term: Codebook loss (update embeddings)
  • Third term: Commitment loss (encourage encoder to commit); see the assembly sketch below
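
Assembled into a scalar training loss (a sketch; β in the formula corresponds to the documented commitment_cost parameter):

def vq_total_loss(reconstruction_loss, codebook_loss, commitment_loss,
                  commitment_cost=0.25):
    # Weighted sum of the three terms above.
    return reconstruction_loss + codebook_loss + commitment_cost * commitment_loss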

Example¤

vqvae = VQVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    num_embeddings=512,    # 512 discrete codes
    embedding_dim=64,
    commitment_cost=0.25,
    rngs=rngs,
)

# Forward pass includes quantization
outputs = vqvae(x, rngs=rngs)
print(outputs['z_e'])              # Pre-quantization encoding
print(outputs['quantized'])        # Quantized (discrete) encoding
print(outputs['commitment_loss'])  # Commitment loss component
print(outputs['codebook_loss'])    # Codebook loss component

# Loss includes VQ-specific terms
losses = vqvae.loss_fn(x=batch, outputs=outputs)
print(losses['reconstruction_loss'])
print(losses['commitment_loss'])
print(losses['codebook_loss'])

Encoders¤

MLPEncoder¤

workshop.generative_models.models.vae.encoders.MLPEncoder ¤

MLPEncoder(config: EncoderConfig, *, rngs: Rngs)

Bases: Module

Simple MLP encoder for VAE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | EncoderConfig | EncoderConfig with hidden_dims, latent_dim, activation, input_shape | required |
| rngs | Rngs | Random number generator | required |

latent_dim instance-attribute ¤

latent_dim = latent_dim

backbone instance-attribute ¤

backbone = MLP(hidden_dims=hidden_dims, activation=activation, in_features=input_features, rngs=rngs)

mean_proj instance-attribute ¤

mean_proj = Linear(in_features=hidden_dims[-1], out_features=latent_dim, rngs=rngs)

logvar_proj instance-attribute ¤

logvar_proj = Linear(in_features=hidden_dims[-1], out_features=latent_dim, rngs=rngs)

Fully-connected encoder for flattened inputs.

Class Definition¤

class MLPEncoder(nnx.Module):
    """MLP encoder for VAE."""

    def __init__(
        self,
        hidden_dims: list,
        latent_dim: int,
        activation: str = "relu",
        input_dim: tuple | None = None,
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize MLP encoder."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dims | list[int] | Required | Hidden layer dimensions |
| latent_dim | int | Required | Latent space dimension |
| activation | str | "relu" | Activation function |
| input_dim | tuple \| None | None | Input dimensions for shape inference |
| rngs | nnx.Rngs | Required | Random number generators |

Example¤

encoder = MLPEncoder(
    hidden_dims=[512, 256, 128],
    latent_dim=32,
    activation="relu",
    input_dim=(784,),
    rngs=rngs,
)

mean, log_var = encoder(x, rngs=rngs)

CNNEncoder¤

workshop.generative_models.models.vae.encoders.CNNEncoder ¤

CNNEncoder(config: EncoderConfig, *, rngs: Rngs)

Bases: Module

CNN-based encoder for VAE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | EncoderConfig | EncoderConfig with hidden_dims, latent_dim, activation, input_shape | required |
| rngs | Rngs | Random number generator | required |

latent_dim instance-attribute ¤

latent_dim = latent_dim

cnn instance-attribute ¤

cnn = CNN(hidden_dims=hidden_dims, activation=activation, in_features=in_channels, rngs=rngs)

flatten instance-attribute ¤

flatten = Flatten(rngs=rngs)

mean_proj instance-attribute ¤

mean_proj = Linear(in_features=flattened_size, out_features=latent_dim, rngs=rngs)

logvar_proj instance-attribute ¤

logvar_proj = Linear(in_features=flattened_size, out_features=latent_dim, rngs=rngs)

Convolutional encoder for image inputs.

Class Definition¤

class CNNEncoder(nnx.Module):
    """CNN encoder for VAE."""

    def __init__(
        self,
        hidden_dims: list,
        latent_dim: int,
        activation: str = "relu",
        input_dim: tuple | None = None,
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize CNN encoder."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dims | list[int] | Required | Channel dimensions for conv layers |
| latent_dim | int | Required | Latent space dimension |
| activation | str | "relu" | Activation function |
| input_dim | tuple \| None | None | Input shape (H, W, C) |
| rngs | nnx.Rngs | Required | Random number generators |

Architecture¤

  • Series of convolutional layers with stride 2
  • Each layer reduces spatial dimensions by half
  • Global pooling before projecting to latent space

Example¤

encoder = CNNEncoder(
    hidden_dims=[32, 64, 128, 256],
    latent_dim=64,
    activation="relu",
    input_dim=(28, 28, 1),
    rngs=rngs,
)

# Input shape: (batch, 28, 28, 1)
mean, log_var = encoder(images, rngs=rngs)
# Output shapes: (batch, 64), (batch, 64)

ConditionalEncoder¤

workshop.generative_models.models.vae.encoders.ConditionalEncoder ¤

ConditionalEncoder(encoder: Module, num_classes: int, *, rngs: Rngs)

Bases: Module

Wrapper that adds conditioning to any encoder.

This wrapper adds class conditioning to an existing encoder by converting labels to one-hot and concatenating them with the input. This follows the standard CVAE pattern from popular implementations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| encoder | Module | Base encoder to wrap | required |
| num_classes | int | Number of classes for conditioning | required |
| rngs | Rngs | Random number generators (for API consistency) | required |

encoder instance-attribute ¤

encoder = encoder

num_classes instance-attribute ¤

num_classes = num_classes

Wrapper that adds conditioning to any encoder.

Class Definition¤

class ConditionalEncoder(nnx.Module):
    """Conditional encoder wrapper."""

    def __init__(
        self,
        encoder: nnx.Module,
        num_classes: int,
        embed_dim: int,
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize conditional encoder."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| encoder | nnx.Module | Required | Base encoder to wrap |
| num_classes | int | Required | Number of conditioning classes |
| embed_dim | int | Required | Embedding dimension for labels |
| rngs | nnx.Rngs | Required | Random number generators |

Example¤

base_encoder = MLPEncoder(
    hidden_dims=[512, 256],
    latent_dim=32,
    rngs=rngs,
)

conditional_encoder = ConditionalEncoder(
    encoder=base_encoder,
    num_classes=10,
    embed_dim=128,
    rngs=rngs,
)

# Pass class labels as integers or one-hot
labels = jnp.array([0, 1, 2, 3])
mean, log_var = conditional_encoder(x, condition=labels, rngs=rngs)
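
The underlying conditioning mechanism, per the docstring above, is one-hot concatenation; a hand-rolled sketch of the idea (not the wrapper's exact code):

import jax
import jax.numpy as jnp

def concat_condition(x, labels, num_classes):
    # Convert integer labels to one-hot vectors and append them to the
    # flattened input features: the standard CVAE conditioning pattern.
    one_hot = jax.nn.one_hot(labels, num_classes)
    return jnp.concatenate([x, one_hot], axis=-1)

x = jnp.ones((4, 784))
labels = jnp.array([0, 1, 2, 3])
x_cond = concat_condition(x, labels, num_classes=10)  # shape (4, 794)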

Decoders¤

MLPDecoder¤

workshop.generative_models.models.vae.decoders.MLPDecoder ¤

MLPDecoder(config: DecoderConfig, *, rngs: Rngs)

Bases: Module

Simple MLP decoder for VAE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | DecoderConfig | DecoderConfig with hidden_dims, output_shape, latent_dim, activation | required |
| rngs | Rngs | Random number generator | required |

output_shape instance-attribute ¤

output_shape = output_shape

latent_dim instance-attribute ¤

latent_dim = latent_dim

output_size instance-attribute ¤

output_size = 1

backbone instance-attribute ¤

backbone = MLP(hidden_dims=decoder_dims, activation=activation, in_features=latent_dim, rngs=rngs)

output_layer instance-attribute ¤

output_layer = Linear(in_features=decoder_dims[-1], out_features=output_size, rngs=rngs)

activation instance-attribute ¤

activation = SigmoidActivation

Fully-connected decoder.

Class Definition¤

class MLPDecoder(nnx.Module):
    """MLP decoder for VAE."""

    def __init__(
        self,
        hidden_dims: list[int],
        output_dim: tuple[int, ...],
        latent_dim: int,
        activation: str = "relu",
        *,
        rngs: nnx.Rngs | None = None,
    ):
        """Initialize MLP decoder."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dims | list[int] | Required | Hidden layer dimensions (reversed from encoder) |
| output_dim | tuple | Required | Output reconstruction shape |
| latent_dim | int | Required | Latent space dimension |
| activation | str | "relu" | Activation function |
| rngs | nnx.Rngs | Required | Random number generators |

Example¤

decoder = MLPDecoder(
    hidden_dims=[128, 256, 512],  # Reversed from encoder
    output_dim=(784,),
    latent_dim=32,
    activation="relu",
    rngs=rngs,
)

reconstructed = decoder(z)  # Shape: (batch, 784)

CNNDecoder¤

workshop.generative_models.models.vae.decoders.CNNDecoder ¤

CNNDecoder(config: DecoderConfig, *, rngs: Rngs)

Bases: Module

CNN-based decoder for VAE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | DecoderConfig | DecoderConfig with hidden_dims, output_shape, latent_dim, activation | required |
| rngs | Rngs | Random number generator | required |

output_shape instance-attribute ¤

output_shape = output_shape

latent_dim instance-attribute ¤

latent_dim = latent_dim

initial_h instance-attribute ¤

initial_h = max(h // 2 ** len(hidden_dims), 1)

initial_w instance-attribute ¤

initial_w = max(w // 2 ** len(hidden_dims), 1)

initial_features instance-attribute ¤

initial_features = hidden_dims[0]

project instance-attribute ¤

project = Linear(in_features=latent_dim, out_features=initial_size, rngs=rngs)

cnn instance-attribute ¤

cnn = CNN(hidden_dims=decoder_dims, activation=activation, use_transpose=True, in_features=initial_features, rngs=rngs)

output_conv instance-attribute ¤

output_conv = Conv(in_features=final_channels, out_features=output_shape[2], kernel_size=(3, 3), padding='SAME', rngs=rngs)

activation instance-attribute ¤

activation = SigmoidActivation

Transposed convolutional decoder for images.

Class Definition¤

class CNNDecoder(nnx.Module):
    """CNN decoder with transposed convolutions."""

    def __init__(
        self,
        hidden_dims: list[int],
        output_dim: tuple[int, ...],
        latent_dim: int,
        activation: str = "relu",
        *,
        rngs: nnx.Rngs | None = None,
    ):
        """Initialize CNN decoder."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| hidden_dims | list[int] | Required | Channel dimensions (reversed from encoder) |
| output_dim | tuple | Required | Output shape (H, W, C) |
| latent_dim | int | Required | Latent dimension |
| activation | str | "relu" | Activation function |
| rngs | nnx.Rngs | Required | Random number generators |

Example¤

decoder = CNNDecoder(
    hidden_dims=[256, 128, 64, 32],  # Reversed channels
    output_dim=(28, 28, 1),
    latent_dim=64,
    activation="relu",
    rngs=rngs,
)

# Input: (batch, 64)
reconstructed = decoder(z)  # Output: (batch, 28, 28, 1)
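
The starting resolution the decoder upsamples from follows its documented attribute formula; worked through for this example:

# Per the documented attribute: initial_h = max(h // 2 ** len(hidden_dims), 1)
h, hidden_dims = 28, [256, 128, 64, 32]
initial_h = max(h // 2 ** len(hidden_dims), 1)   # max(28 // 16, 1) = 1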

ConditionalDecoder¤

workshop.generative_models.models.vae.decoders.ConditionalDecoder ¤

ConditionalDecoder(decoder: Module, num_classes: int, *, rngs: Rngs)

Bases: Module

Wrapper that adds conditioning to any decoder.

This wrapper adds class conditioning to an existing decoder by converting labels to one-hot and concatenating them with the latent vector. This follows the standard CVAE pattern from popular implementations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| decoder | Module | Base decoder to wrap | required |
| num_classes | int | Number of classes for conditioning | required |
| rngs | Rngs | Random number generators | required |

decoder instance-attribute ¤

decoder = decoder

num_classes instance-attribute ¤

num_classes = num_classes

Wrapper that adds conditioning to any decoder.

Class Definition¤

class ConditionalDecoder(nnx.Module):
    """Conditional decoder wrapper."""

    def __init__(
        self,
        decoder: nnx.Module,
        num_classes: int,
        embed_dim: int,
        *,
        rngs: nnx.Rngs,
    ):
        """Initialize conditional decoder."""

Parameters¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| decoder | nnx.Module | Required | Base decoder to wrap |
| num_classes | int | Required | Number of conditioning classes |
| embed_dim | int | Required | Embedding dimension |
| rngs | nnx.Rngs | Required | Random number generators |

Example¤

base_decoder = MLPDecoder(
    hidden_dims=[128, 256, 512],
    output_dim=(784,),
    latent_dim=32,
    rngs=rngs,
)

conditional_decoder = ConditionalDecoder(
    decoder=base_decoder,
    num_classes=10,
    embed_dim=128,
    rngs=rngs,
)

labels = jnp.array([0, 1, 2, 3])
reconstructed = conditional_decoder(z, condition=labels, rngs=rngs)

Utility Functions¤

create_encoder_unified¤

def create_encoder_unified(
    *,
    config: ModelConfiguration,
    encoder_type: str,
    conditional: bool = False,
    num_classes: int | None = None,
    embed_dim: int | None = None,
    rngs: nnx.Rngs,
) -> nnx.Module:
    """Create encoder from unified configuration."""

Parameters:

  • config (ModelConfiguration): Model configuration
  • encoder_type (str): Type of encoder: "dense", "cnn", "resnet"
  • conditional (bool): Whether to wrap in conditional encoder
  • num_classes (int, optional): Number of classes for conditioning
  • embed_dim (int, optional): Embedding dimension
  • rngs (Rngs): Random number generators

Returns:

  • nnx.Module: Configured encoder

create_decoder_unified¤

def create_decoder_unified(
    *,
    config: ModelConfiguration,
    decoder_type: str,
    conditional: bool = False,
    num_classes: int | None = None,
    embed_dim: int | None = None,
    rngs: nnx.Rngs,
) -> nnx.Module:
    """Create decoder from unified configuration."""

Parameters:

  • config (ModelConfiguration): Model configuration
  • decoder_type (str): Type of decoder: "dense", "cnn", "resnet"
  • conditional (bool): Whether to wrap in conditional decoder
  • num_classes (int, optional): Number of classes for conditioning
  • embed_dim (int, optional): Embedding dimension
  • rngs (Rngs): Random number generators

Returns:

  • nnx.Module: Configured decoder

Complete Example¤

import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.vae import BetaVAE
from workshop.generative_models.models.vae.encoders import CNNEncoder
from workshop.generative_models.models.vae.decoders import CNNDecoder

# Initialize
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Create encoder
encoder = CNNEncoder(
    hidden_dims=[32, 64, 128, 256],
    latent_dim=64,
    activation="relu",
    input_dim=(28, 28, 1),
    rngs=rngs,
)

# Create decoder
decoder = CNNDecoder(
    hidden_dims=[256, 128, 64, 32],
    output_dim=(28, 28, 1),
    latent_dim=64,
    activation="relu",
    rngs=rngs,
)

# Create β-VAE
model = BetaVAE(
    encoder=encoder,
    decoder=decoder,
    latent_dim=64,
    beta_default=4.0,
    beta_warmup_steps=10000,
    reconstruction_loss_type="mse",
    rngs=rngs,
)

# Forward pass
x = jnp.ones((32, 28, 28, 1))
outputs = model(x, rngs=rngs)

# Calculate loss
losses = model.loss_fn(x=x, outputs=outputs, step=5000)
print(f"Total Loss: {losses['total_loss']:.4f}")
print(f"Beta: {losses['beta']:.4f}")

# Generate samples
samples = model.sample(n_samples=16, temperature=1.0, rngs=rngs)

# Latent traversal
traversal = model.latent_traversal(x[0], dim=10, steps=15, rngs=rngs)

See Also¤