VAE API Reference¤
Complete API documentation for Variational Autoencoder models in Workshop.
Module Overview¤
from workshop.generative_models.models.vae import (
VAE, # Base VAE class
BetaVAE, # β-VAE with disentanglement
BetaVAEWithCapacity, # β-VAE with capacity control
ConditionalVAE, # Conditional VAE
VQVAE, # Vector Quantized VAE
)
from workshop.generative_models.models.vae.encoders import (
MLPEncoder, # Fully-connected encoder
CNNEncoder, # Convolutional encoder
ResNetEncoder, # ResNet-based encoder
ConditionalEncoder, # Conditional wrapper
)
from workshop.generative_models.models.vae.decoders import (
MLPDecoder, # Fully-connected decoder
CNNDecoder, # Transposed convolutional decoder
ResNetDecoder, # ResNet-based decoder
ConditionalDecoder, # Conditional wrapper
)
Base Classes¤
VAE¤
workshop.generative_models.models.vae.base.VAE¤
Bases: GenerativeModel
Base class for Variational Autoencoders.
This class provides a foundation for implementing various VAE models using Flax NNX. All VAE models should inherit from this class and implement the required methods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | VAEConfig | VAEConfig with encoder, decoder, encoder_type, and kl_weight settings | required |
| rngs | Rngs | Random number generators | required |
| precision | Precision \| None | Numerical precision for computations | None |
encode¤
Encode input to the latent space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input data | required |
Returns:
| Type | Description |
|---|---|
| tuple[Array, Array] | Tuple of (mean, log_var) of the latent distribution |
Raises:
| Type | Description |
|---|---|
| ValueError | If encoder output format is invalid |
decode¤
Decode latent vectors to the output space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| z | Array | Latent vectors | required |
Returns:
| Type | Description |
|---|---|
| Array | Reconstructed outputs |
Raises:
| Type | Description |
|---|---|
| ValueError | If decoder output format is invalid |
reparameterize¤
loss_fn¤
loss_fn(params: dict | None = None, batch: dict | None = None, rng: Array | None = None, x: Array | None = None, outputs: dict[str, Array] | None = None, beta: float | None = None, reconstruction_loss_fn: Callable | None = None, **kwargs) -> dict[str, Array]
Calculate loss for VAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | dict \| None | Model parameters (optional, for compatibility with Trainer) | None |
| batch | dict \| None | Input batch (optional, for compatibility with Trainer) | None |
| rng | Array \| None | Random number generator (optional, for compatibility with Trainer) | None |
| x | Array \| None | Input data (if not provided in batch) | None |
| outputs | dict[str, Array] \| None | Dictionary of model outputs from forward pass | None |
| beta | float \| None | Weight for KL divergence term | None |
| reconstruction_loss_fn | Callable \| None | Optional custom reconstruction loss function | None |
| **kwargs | | Additional arguments | {} |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary of loss components |
sample¤
Sample from the prior distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate. Must be a Python int (static value). When JIT-compiling, mark this as a static argument. | 1 |
| temperature | float | Scaling factor for the standard deviation (higher = more diverse) | 1.0 |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
Note
When JIT-compiling functions that call this method, mark n_samples as static:
Example: @nnx.jit(static_argnums=(1,)) or @jax.jit(static_argnums=(1,))
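A minimal sketch of such a wrapper, assuming an existing vae instance and rngs (the helper name is illustrative):
@nnx.jit(static_argnums=(1,))
def draw_samples(model, n_samples, rngs):
    # n_samples is a static Python int; the function is retraced only when it changes
    return model.sample(n_samples=n_samples, rngs=rngs)
samples = draw_samples(vae, 16, rngs)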
reconstruct¤
generate¤
Generate samples from the model.
This is an alias for the sample method to maintain consistency with the GenerativeModel interface.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Scaling factor for the standard deviation (higher = more diverse) | 1.0 |
| **kwargs | | Additional arguments (unused, for compatibility) | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
interpolate¤
Interpolate between two inputs in latent space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x1 | Array | First input | required |
| x2 | Array | Second input | required |
| steps | int | Number of interpolation steps (including endpoints). Must be a Python int (static value) when JIT-compiling. | 10 |
Returns:
| Type | Description |
|---|---|
| Array | Interpolated outputs |
Note
When JIT-compiling functions that call this method, mark steps as static:
Example: @nnx.jit(static_argnums=(3,)), marking the steps argument as static.
latent_traversal¤
latent_traversal(x: Array, dim: int, range_vals: tuple[float, float] = (-3.0, 3.0), steps: int = 10) -> Array
Traverse a single dimension of the latent space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input data | required |
| dim | int | Dimension to traverse | required |
| range_vals | tuple[float, float] | Range of values for traversal | (-3.0, 3.0) |
| steps | int | Number of steps in the traversal | 10 |
Returns:
| Type | Description |
|---|---|
| Array | Decoded outputs from the traversal |
Raises:
| Type | Description |
|---|---|
| ValueError | If dim is out of range |
The base VAE class implementing standard Variational Autoencoder functionality.
Class Definition¤
class VAE(GenerativeModel):
"""Base class for Variational Autoencoders."""
def __init__(
self,
encoder: nnx.Module,
decoder: nnx.Module,
latent_dim: int,
*,
rngs: nnx.Rngs,
kl_weight: float = 1.0,
precision: jax.lax.Precision | None = None,
) -> None:
"""Initialize a VAE."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| encoder | nnx.Module | Required | Encoder network mapping inputs to latent distributions |
| decoder | nnx.Module | Required | Decoder network mapping latent codes to reconstructions |
| latent_dim | int | Required | Dimensionality of the latent space (must be positive) |
| rngs | nnx.Rngs | Required | Random number generators for initialization and sampling |
| kl_weight | float | 1.0 | Weight for KL divergence term (β parameter) |
| precision | jax.lax.Precision \| None | None | Numerical precision for computations |
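A minimal construction sketch following this constructor, using the MLPEncoder and MLPDecoder documented later on this page (layer sizes and shapes are illustrative):
from flax import nnx
from workshop.generative_models.models.vae import VAE
from workshop.generative_models.models.vae.encoders import MLPEncoder
from workshop.generative_models.models.vae.decoders import MLPDecoder
rngs = nnx.Rngs(params=0, sample=1)
encoder = MLPEncoder(hidden_dims=[512, 256], latent_dim=32, input_dim=(784,), rngs=rngs)
decoder = MLPDecoder(hidden_dims=[256, 512], output_dim=(784,), latent_dim=32, rngs=rngs)
vae = VAE(encoder=encoder, decoder=decoder, latent_dim=32, kl_weight=1.0, rngs=rngs)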
Methods¤
encode¤
def encode(
self,
x: jax.Array,
*,
rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, jax.Array]:
"""Encode input to latent distribution parameters."""
Parameters:
- x (Array): Input data with shape (batch_size, *input_shape)
- rngs (Rngs, optional): Random number generators
Returns:
tuple[Array, Array]: Mean and log-variance of the latent distribution
- mean: shape (batch_size, latent_dim)
- log_var: shape (batch_size, latent_dim)
Raises:
ValueError: If encoder output format is invalid
Example:
mean, log_var = vae.encode(x, rngs=rngs)
print(f"Mean shape: {mean.shape}") # (32, 20)
print(f"Log-var shape: {log_var.shape}") # (32, 20)
decode¤
def decode(
self,
z: jax.Array,
*,
rngs: nnx.Rngs | None = None
) -> jax.Array:
"""Decode latent vectors to reconstructions."""
Parameters:
- z (Array): Latent vectors with shape (batch_size, latent_dim)
- rngs (Rngs, optional): Random number generators
Returns:
Array: Reconstructed outputs with shape (batch_size, *output_shape)
Example:
z = jax.random.normal(key, (32, 20))
reconstructed = vae.decode(z, rngs=rngs)
print(f"Reconstruction shape: {reconstructed.shape}") # (32, 784)
reparameterize¤
@nnx.jit
def reparameterize(
self,
mean: jax.Array,
log_var: jax.Array,
*,
rngs: nnx.Rngs | None = None
) -> jax.Array:
"""Apply the reparameterization trick."""
Parameters:
- mean (Array): Mean vectors with shape (batch_size, latent_dim)
- log_var (Array): Log-variance vectors with shape (batch_size, latent_dim)
- rngs (Rngs, optional): Random number generators
Returns:
Array: Sampled latent vectors with shape (batch_size, latent_dim)
Details:
Implements the reparameterization trick: \(z = \mu + \sigma \odot \epsilon\) where \(\epsilon \sim \mathcal{N}(0, I)\)
Includes numerical stability via log-variance clipping to [-20, 20].
Example:
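# Sketch using the encode/reparameterize methods documented above
mean, log_var = vae.encode(x, rngs=rngs)
z = vae.reparameterize(mean, log_var, rngs=rngs)
print(z.shape)  # (batch_size, latent_dim)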
__call__¤
def __call__(
self,
x: jax.Array,
*,
rngs: nnx.Rngs | None = None
) -> dict[str, Any]:
"""Forward pass through the VAE."""
Parameters:
- x (Array): Input data with shape (batch_size, *input_shape)
- rngs (Rngs, optional): Random number generators
Returns:
dict: Dictionary containing:
- reconstructed (Array): Reconstructed outputs
- reconstruction (Array): Alias for compatibility
- mean (Array): Latent mean vectors
- log_var (Array): Latent log-variance vectors
- logvar (Array): Alias for compatibility
- z (Array): Sampled latent vectors
Example:
outputs = vae(x, rngs=rngs)
print(outputs.keys())
# dict_keys(['reconstructed', 'reconstruction', 'mean', 'log_var', 'logvar', 'z'])
loss_fn¤
def loss_fn(
self,
params: dict | None = None,
batch: dict | None = None,
rng: jax.Array | None = None,
x: jax.Array | None = None,
outputs: dict[str, jax.Array] | None = None,
beta: float | None = None,
reconstruction_loss_fn: Callable | None = None,
**kwargs,
) -> dict[str, jax.Array]:
"""Calculate VAE loss (ELBO)."""
Parameters:
- params (dict, optional): Model parameters (for Trainer compatibility)
- batch (dict, optional): Input batch (for Trainer compatibility)
- rng (Array, optional): Random number generator
- x (Array, optional): Input data if not in batch
- outputs (dict, optional): Pre-computed model outputs
- beta (float, optional): KL divergence weight override
- reconstruction_loss_fn (Callable, optional): Custom reconstruction loss
- **kwargs: Additional arguments
Returns:
dict: Dictionary containing:
- reconstruction_loss (Array): Reconstruction error
- recon_loss (Array): Alias for compatibility
- kl_loss (Array): KL divergence
- total_loss (Array): Combined loss
- loss (Array): Alias for compatibility
Loss Formula:
\(\mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \cdot D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, \mathcal{N}(0, I)\big)\), where \(\beta\) defaults to the model's kl_weight.
Example:
outputs = vae(x, rngs=rngs)
losses = vae.loss_fn(x=x, outputs=outputs)
print(f"Reconstruction: {losses['reconstruction_loss']:.4f}")
print(f"KL Divergence: {losses['kl_loss']:.4f}")
print(f"Total Loss: {losses['total_loss']:.4f}")
sample¤
def sample(
self,
n_samples: int = 1,
*,
temperature: float = 1.0,
rngs: nnx.Rngs | None = None
) -> jax.Array:
"""Sample from the prior distribution."""
Parameters:
- n_samples (int): Number of samples to generate
- temperature (float): Scaling factor for sampling diversity (higher = more diverse)
- rngs (Rngs, optional): Random number generators
Returns:
Array: Generated samples with shape (n_samples, *output_shape)
Example:
# Generate 16 samples
samples = vae.sample(n_samples=16, temperature=1.0, rngs=rngs)
# More diverse samples
hot_samples = vae.sample(n_samples=16, temperature=2.0, rngs=rngs)
generate¤
def generate(
self,
n_samples: int = 1,
*,
temperature: float = 1.0,
rngs: nnx.Rngs | None = None,
**kwargs,
) -> jax.Array:
"""Generate samples (alias for sample)."""
Alias for sample() to maintain consistency with GenerativeModel interface.
reconstruct¤
def reconstruct(
self,
x: jax.Array,
*,
deterministic: bool = False,
rngs: nnx.Rngs | None = None
) -> jax.Array:
"""Reconstruct inputs."""
Parameters:
- x (Array): Input data
- deterministic (bool): If True, use mean instead of sampling
- rngs (Rngs, optional): Random number generators
Returns:
Array: Reconstructed outputs
Example:
# Stochastic reconstruction
recon = vae.reconstruct(x, deterministic=False, rngs=rngs)
# Deterministic reconstruction (use latent mean)
det_recon = vae.reconstruct(x, deterministic=True, rngs=rngs)
interpolate¤
def interpolate(
self,
x1: jax.Array,
x2: jax.Array,
steps: int = 10,
*,
rngs: nnx.Rngs | None = None,
) -> jax.Array:
"""Interpolate between two inputs in latent space."""
Parameters:
- x1 (Array): First input
- x2 (Array): Second input
- steps (int): Number of interpolation steps (including endpoints)
- rngs (Rngs, optional): Random number generators
Returns:
Array: Interpolated outputs with shape (steps, *output_shape)
Example:
x1 = test_images[0]
x2 = test_images[1]
interpolation = vae.interpolate(x1, x2, steps=10, rngs=rngs)
latent_traversal¤
def latent_traversal(
self,
x: jax.Array,
dim: int,
range_vals: tuple[float, float] = (-3.0, 3.0),
steps: int = 10,
*,
rngs: nnx.Rngs | None = None,
) -> jax.Array:
"""Traverse a single latent dimension."""
Parameters:
- x (Array): Input data
- dim (int): Dimension to traverse (0 to latent_dim - 1)
- range_vals (tuple): Range of values for traversal
- steps (int): Number of traversal steps
- rngs (Rngs, optional): Random number generators
Returns:
Array: Decoded outputs from traversal with shape (steps, *output_shape)
Raises:
ValueError: If dimension is out of range
Example:
# Traverse dimension 5
traversal = vae.latent_traversal(
x=test_image,
dim=5,
range_vals=(-3.0, 3.0),
steps=15,
rngs=rngs,
)
VAE Variants¤
BetaVAE¤
workshop.generative_models.models.vae.beta_vae.BetaVAE¤
BetaVAE(config: BetaVAEConfig, *, rngs: Rngs)
Bases: VAE
Beta Variational Autoencoder implementation.
Beta-VAE is a variant of VAE that introduces a hyperparameter beta to the KL divergence term in the loss function. This allows for better control over the disentanglement of latent representations.
By setting beta > 1, the model is encouraged to learn more disentangled representations, but potentially at the cost of reconstruction quality.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | BetaVAEConfig | BetaVAEConfig with encoder, decoder, encoder_type, and beta settings | required |
| rngs | Rngs | Random number generator for initialization | required |
loss_fn¤
loss_fn(params: dict | None = None, batch: dict | None = None, rng: Array | None = None, x: Array | None = None, outputs: dict[str, Array] | None = None, beta: float | None = None, **kwargs) -> dict[str, Array]
Calculate loss for BetaVAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| params | dict \| None | Model parameters (optional, for compatibility with Trainer) | None |
| batch | dict \| None | Input batch (optional, for compatibility with Trainer) | None |
| rng | Array \| None | Random number generator (optional, for Trainer compatibility) | None |
| x | Array \| None | Input data (if not provided in batch) | None |
| outputs | dict[str, Array] \| None | Dictionary of model outputs from forward pass | None |
| beta | float \| None | Weight for KL divergence term. If None, uses model's beta_default with optional warmup. | None |
| **kwargs | | Additional arguments including current training step | {} |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary of loss components |
β-VAE for learning disentangled representations.
Class Definition¤
class BetaVAE(VAE):
"""Beta Variational Autoencoder for disentanglement."""
def __init__(
self,
encoder: nnx.Module,
decoder: nnx.Module,
latent_dim: int,
beta_default: float = 1.0,
beta_warmup_steps: int = 0,
reconstruction_loss_type: str = "mse",
*,
rngs: nnx.Rngs,
) -> None:
"""Initialize a BetaVAE."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| encoder | nnx.Module | Required | Encoder network |
| decoder | nnx.Module | Required | Decoder network |
| latent_dim | int | Required | Latent space dimension |
| beta_default | float | 1.0 | Default β value for KL weighting |
| beta_warmup_steps | int | 0 | Steps for β annealing (0 = no annealing) |
| reconstruction_loss_type | str | "mse" | Loss type: "mse" or "bce" |
| rngs | nnx.Rngs | Required | Random number generators |
Key Differences from VAE¤
Modified Loss Function:
\(\mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \cdot D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)\)
Beta Annealing:
When beta_warmup_steps > 0, β increases linearly from 0 to beta_default:
\(\beta(t) = \beta_{\text{default}} \cdot \min(1, t / T_{\text{warmup}})\), where \(T_{\text{warmup}}\) is beta_warmup_steps. A sketch of this schedule follows.
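The schedule can be written as the following sketch (not the Workshop implementation; the default values are illustrative and match the example below):
def beta_at(step, beta_default=4.0, warmup_steps=10_000):
    # Linear warmup: beta rises from 0 to beta_default over warmup_steps
    if warmup_steps == 0:
        return beta_default
    return beta_default * min(1.0, step / warmup_steps)
beta_at(5_000)  # 2.0 -- halfway through the warmup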
Example¤
beta_vae = BetaVAE(
encoder=encoder,
decoder=decoder,
latent_dim=32,
beta_default=4.0, # Higher β encourages disentanglement
beta_warmup_steps=10000, # Gradually increase β
reconstruction_loss_type="mse",
rngs=rngs,
)
# Training step with beta annealing
for step in range(num_steps):
outputs = beta_vae(batch, rngs=rngs)
losses = beta_vae.loss_fn(x=batch, outputs=outputs, step=step)
# losses['beta'] contains current β value
BetaVAEWithCapacity¤
workshop.generative_models.models.vae.beta_vae.BetaVAEWithCapacity¤
BetaVAEWithCapacity(config: BetaVAEWithCapacityConfig, *, rngs: Rngs)
β-VAE with the capacity-control mechanism of Burgess et al.
Class Definition¤
class BetaVAEWithCapacity(BetaVAE):
"""β-VAE with capacity control."""
def __init__(
self,
encoder: nnx.Module,
decoder: nnx.Module,
latent_dim: int,
beta_default: float = 1.0,
beta_warmup_steps: int = 0,
reconstruction_loss_type: str = "mse",
use_capacity_control: bool = False,
capacity_max: float = 25.0,
capacity_num_iter: int = 25000,
gamma: float = 1000.0,
*,
rngs: nnx.Rngs,
) -> None:
"""Initialize BetaVAE with capacity control."""
Additional Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| use_capacity_control | bool | False | Enable capacity control |
| capacity_max | float | 25.0 | Maximum capacity in nats |
| capacity_num_iter | int | 25000 | Steps to reach max capacity |
| gamma | float | 1000.0 | Weight for capacity loss |
Capacity Loss¤
\(\mathcal{L} = \mathcal{L}_{\text{recon}} + \gamma \cdot \lvert D_{\mathrm{KL}} - C(t) \rvert\)
Where capacity \(C(t)\) increases linearly:
\(C(t) = C_{\max} \cdot \min(1, t / T_{C})\), with \(C_{\max}\) = capacity_max and \(T_{C}\) = capacity_num_iter.
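A sketch of the schedule and the resulting penalty (mirroring the loss keys shown in the example below; not the Workshop implementation):
def capacity_terms(kl, step, capacity_max=25.0, capacity_num_iter=25_000, gamma=1000.0):
    # C(t) ramps linearly from 0 to capacity_max over capacity_num_iter steps
    current_capacity = capacity_max * min(1.0, step / capacity_num_iter)
    return {
        "current_capacity": current_capacity,
        "capacity_loss": gamma * abs(kl - current_capacity),
        "kl_capacity_diff": kl - current_capacity,
    }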
Example¤
capacity_vae = BetaVAEWithCapacity(
encoder=encoder,
decoder=decoder,
latent_dim=32,
beta_default=4.0,
use_capacity_control=True,
capacity_max=25.0,
capacity_num_iter=25000,
gamma=1000.0,
rngs=rngs,
)
# Loss includes capacity terms
losses = capacity_vae.loss_fn(x=batch, outputs=outputs, step=step)
print(losses['current_capacity']) # Current C value
print(losses['capacity_loss']) # γ * |KL - C|
print(losses['kl_capacity_diff']) # KL - C
ConditionalVAE¤
workshop.generative_models.models.vae.conditional.ConditionalVAE¤
Bases: VAE
Conditional Variational Autoencoder implementation.
Extends the base VAE with conditioning capabilities by incorporating additional information at both encoding and decoding steps. This follows the standard CVAE pattern using one-hot concatenation.
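The conditioning pattern can be pictured as follows (an illustrative sketch of one-hot concatenation, not the Workshop internals):
import jax
import jax.numpy as jnp
def concat_condition(h, y, num_classes):
    # One-hot encode the labels and append them to each input (or latent) vector
    y_onehot = jax.nn.one_hot(y, num_classes)        # (batch, num_classes)
    return jnp.concatenate([h, y_onehot], axis=-1)   # (batch, dim + num_classes)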
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ConditionalVAEConfig | ConditionalVAEConfig with encoder, decoder, encoder_type, conditioning settings | required |
| rngs | Rngs | Random number generators | required |
| precision | Precision \| None | Numerical precision for computations | None |
encoder instance-attribute¤
encoder = create_encoder(encoder, encoder_type, conditional=True, num_classes=condition_dim, rngs=rngs)
decoder instance-attribute¤
decoder = create_decoder(decoder, encoder_type, conditional=True, num_classes=condition_dim, rngs=rngs)
encode¤
decode¤
sample¤
Sample from the model with conditioning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness | 1.0 |
| y | Array \| None | Conditioning information (optional) | None |
| **kwargs | | Additional arguments for compatibility | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
generate¤
generate(n_samples: int = 1, *, temperature: float = 1.0, y: Array | None = None, **kwargs) -> Array
Generate samples from the model with conditioning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness | 1.0 |
| y | Array \| None | Conditioning information (optional) | None |
| **kwargs | | Additional arguments | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
reconstruct¤
Conditional VAE for class-conditional generation.
Class Definition¤
class ConditionalVAE(VAE):
"""Conditional Variational Autoencoder."""
def __init__(
self,
encoder: nnx.Module,
decoder: nnx.Module,
latent_dim: int,
condition_dim: int = 10,
condition_type: str = "concat",
*,
rngs: nnx.Rngs | None = None,
):
"""Initialize Conditional VAE."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| encoder | nnx.Module | Required | Conditional encoder |
| decoder | nnx.Module | Required | Conditional decoder |
| latent_dim | int | Required | Latent dimension |
| condition_dim | int | 10 | Conditioning dimension |
| condition_type | str | "concat" | Conditioning strategy |
| rngs | nnx.Rngs | None | Random number generators |
Modified Methods¤
__call__¤
def __call__(
self,
x: jax.Array,
y: jax.Array | None = None,
*,
rngs: nnx.Rngs | None = None,
) -> dict[str, Any]:
"""Forward pass with conditioning."""
encode¤
def encode(
self,
x: jax.Array,
y: jax.Array | None = None,
*,
rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, jax.Array]:
"""Encode with conditioning."""
decode¤
def decode(
self,
z: jax.Array,
y: jax.Array | None = None,
*,
rngs: nnx.Rngs | None = None
) -> jax.Array:
"""Decode with conditioning."""
sample¤
def sample(
self,
n_samples: int = 1,
*,
temperature: float = 1.0,
y: jax.Array | None = None,
rngs: nnx.Rngs | None = None,
) -> jax.Array:
"""Sample with conditioning."""
Example¤
# Create conditional encoder/decoder
conditional_encoder = ConditionalEncoder(
encoder=base_encoder,
num_classes=10,
embed_dim=128,
rngs=rngs,
)
conditional_decoder = ConditionalDecoder(
decoder=base_decoder,
num_classes=10,
embed_dim=128,
rngs=rngs,
)
cvae = ConditionalVAE(
encoder=conditional_encoder,
decoder=conditional_decoder,
latent_dim=32,
condition_dim=10,
rngs=rngs,
)
# Forward pass with class labels
labels = jnp.array([0, 1, 2, 3, 4])
outputs = cvae(x, y=labels, rngs=rngs)
# Generate specific classes
target_labels = jnp.array([5, 5, 5, 5])
samples = cvae.sample(n_samples=4, y=target_labels, rngs=rngs)
VQVAE¤
workshop.generative_models.models.vae.vq_vae.VQVAE¤
VQVAE(config: VQVAEConfig, *, rngs: Rngs)
Bases: VAE
Vector Quantized Variational Autoencoder implementation.
Extends the base VAE with discrete latent variables using a codebook.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | VQVAEConfig | VQVAEConfig with encoder, decoder, encoder_type, and VQ settings | required |
| rngs | Rngs | Random number generators for initialization | required |
embeddings instance-attribute¤
embeddings = Embed(num_embeddings=num_embeddings, features=embedding_dim, rngs=rngs)
quantize¤
encode¤
Encode input to continuous latent representation (before quantization).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input data, shape [batch_size, ...] | required |
Returns:
| Type | Description |
|---|---|
| Array | Tuple of (mean, log_var) for VAE interface compatibility. |
| Array | For VQ-VAE, log_var is zeros since encoding is deterministic. |
decode¤
loss_fn¤
sample¤
Sample from the model.
Parameters:
- n_samples: Number of samples to generate
- temperature: Temperature parameter controlling randomness (higher = more random)
- **kwargs: Additional keyword arguments for compatibility
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
generate¤
Generate samples from the model.
Implements the required method from GenerativeModel base class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | Number of samples to generate | 1 |
| temperature | float | Temperature parameter controlling randomness | 1.0 |
| **kwargs | | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
| Array | Generated samples |
Vector Quantized VAE with discrete latent representations.
Class Definition¤
class VQVAE(VAE):
"""Vector Quantized Variational Autoencoder."""
def __init__(
self,
encoder: nnx.Module,
decoder: nnx.Module,
latent_dim: int,
num_embeddings: int = 512,
embedding_dim: int = 64,
commitment_cost: float = 0.25,
*,
rngs: nnx.Rngs,
):
"""Initialize VQ-VAE."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| encoder | nnx.Module | Required | Encoder network |
| decoder | nnx.Module | Required | Decoder network |
| latent_dim | int | Required | Latent dimension |
| num_embeddings | int | 512 | Codebook size (number of embeddings) |
| embedding_dim | int | 64 | Dimension of each embedding |
| commitment_cost | float | 0.25 | Weight for commitment loss |
| rngs | nnx.Rngs | Required | Random number generators |
Key Method: quantize¤
def quantize(
self,
encoding: jax.Array,
*,
rngs: nnx.Rngs | None = None
) -> tuple[jax.Array, dict[str, Any]]:
"""Quantize continuous encoding using codebook."""
Parameters:
- encoding (Array): Continuous encoding from encoder
- rngs (Rngs, optional): Random number generators
Returns:
tuple: (quantized encoding, auxiliary info)
- quantized (Array): Discrete quantized vectors
- aux (dict): Contains commitment_loss, codebook_loss, encoding_indices
Quantization Process (see the sketch after this list):
- Find nearest codebook embedding for each encoding vector
- Replace encoding with codebook embedding
- Use straight-through estimator for gradients
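A minimal sketch of the nearest-neighbour lookup and straight-through gradient (illustrative only, independent of the Workshop codebook implementation):
import jax
import jax.numpy as jnp
def quantize_sketch(z_e, codebook):
    # z_e: continuous encodings (batch, embedding_dim); codebook: (num_embeddings, embedding_dim)
    d = jnp.sum((z_e[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)  # squared distances
    indices = jnp.argmin(d, axis=-1)   # nearest codebook entry for each vector
    quantized = codebook[indices]      # replace each encoding with its embedding
    # Straight-through estimator: forward pass uses `quantized`,
    # gradients flow back to z_e as if quantization were the identity
    quantized = z_e + jax.lax.stop_gradient(quantized - z_e)
    return quantized, indices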
Loss Function¤
VQ-VAE uses a specialized loss:
\(\mathcal{L} = \lVert x - \hat{x} \rVert^2 + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \, \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2\), with \(\mathrm{sg}[\cdot]\) the stop-gradient operator and \(\beta\) = commitment_cost.
Where:
- First term: Reconstruction loss
- Second term: Codebook loss (update embeddings)
- Third term: Commitment loss (encourage encoder to commit)
Example¤
vqvae = VQVAE(
encoder=encoder,
decoder=decoder,
latent_dim=64,
num_embeddings=512, # 512 discrete codes
embedding_dim=64,
commitment_cost=0.25,
rngs=rngs,
)
# Forward pass includes quantization
outputs = vqvae(x, rngs=rngs)
print(outputs['z_e']) # Pre-quantization encoding
print(outputs['quantized']) # Quantized (discrete) encoding
print(outputs['commitment_loss']) # Commitment loss component
print(outputs['codebook_loss']) # Codebook loss component
# Loss includes VQ-specific terms
losses = vqvae.loss_fn(x=batch, outputs=outputs)
print(losses['reconstruction_loss'])
print(losses['commitment_loss'])
print(losses['codebook_loss'])
Encoders¤
MLPEncoder¤
workshop.generative_models.models.vae.encoders.MLPEncoder¤
MLPEncoder(config: EncoderConfig, *, rngs: Rngs)
Bases: Module
Simple MLP encoder for VAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | EncoderConfig | EncoderConfig with hidden_dims, latent_dim, activation, input_shape | required |
| rngs | Rngs | Random number generator | required |
Fully-connected encoder for flattened inputs.
Class Definition¤
class MLPEncoder(nnx.Module):
"""MLP encoder for VAE."""
def __init__(
self,
hidden_dims: list,
latent_dim: int,
activation: str = "relu",
input_dim: tuple | None = None,
*,
rngs: nnx.Rngs,
):
"""Initialize MLP encoder."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | Required | Hidden layer dimensions |
| latent_dim | int | Required | Latent space dimension |
| activation | str | "relu" | Activation function |
| input_dim | tuple \| None | None | Input dimensions for shape inference |
| rngs | nnx.Rngs | Required | Random number generators |
Example¤
encoder = MLPEncoder(
hidden_dims=[512, 256, 128],
latent_dim=32,
activation="relu",
input_dim=(784,),
rngs=rngs,
)
mean, log_var = encoder(x, rngs=rngs)
CNNEncoder¤
workshop.generative_models.models.vae.encoders.CNNEncoder¤
CNNEncoder(config: EncoderConfig, *, rngs: Rngs)
Bases: Module
CNN-based encoder for VAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | EncoderConfig | EncoderConfig with hidden_dims, latent_dim, activation, input_shape | required |
| rngs | Rngs | Random number generator | required |
Convolutional encoder for image inputs.
Class Definition¤
class CNNEncoder(nnx.Module):
"""CNN encoder for VAE."""
def __init__(
self,
hidden_dims: list,
latent_dim: int,
activation: str = "relu",
input_dim: tuple | None = None,
*,
rngs: nnx.Rngs,
):
"""Initialize CNN encoder."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | Required | Channel dimensions for conv layers |
| latent_dim | int | Required | Latent space dimension |
| activation | str | "relu" | Activation function |
| input_dim | tuple \| None | None | Input shape (H, W, C) |
| rngs | nnx.Rngs | Required | Random number generators |
Architecture¤
- Series of convolutional layers with stride 2
- Each layer reduces spatial dimensions by half
- Global pooling before projecting to latent space
Example¤
encoder = CNNEncoder(
hidden_dims=[32, 64, 128, 256],
latent_dim=64,
activation="relu",
input_dim=(28, 28, 1),
rngs=rngs,
)
# Input shape: (batch, 28, 28, 1)
mean, log_var = encoder(images, rngs=rngs)
# Output shapes: (batch, 64), (batch, 64)
ConditionalEncoder¤
workshop.generative_models.models.vae.encoders.ConditionalEncoder¤
Bases: Module
Wrapper that adds conditioning to any encoder.
This wrapper adds class conditioning to an existing encoder by converting labels to one-hot and concatenating them with the input. This follows the standard CVAE pattern from popular implementations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| encoder | Module | Base encoder to wrap | required |
| num_classes | int | Number of classes for conditioning | required |
| rngs | Rngs | Random number generators (for API consistency) | required |
Wrapper that adds conditioning to any encoder.
Class Definition¤
class ConditionalEncoder(nnx.Module):
"""Conditional encoder wrapper."""
def __init__(
self,
encoder: nnx.Module,
num_classes: int,
embed_dim: int,
*,
rngs: nnx.Rngs,
):
"""Initialize conditional encoder."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| encoder | nnx.Module | Required | Base encoder to wrap |
| num_classes | int | Required | Number of conditioning classes |
| embed_dim | int | Required | Embedding dimension for labels |
| rngs | nnx.Rngs | Required | Random number generators |
Example¤
base_encoder = MLPEncoder(
hidden_dims=[512, 256],
latent_dim=32,
rngs=rngs,
)
conditional_encoder = ConditionalEncoder(
encoder=base_encoder,
num_classes=10,
embed_dim=128,
rngs=rngs,
)
# Pass class labels as integers or one-hot
labels = jnp.array([0, 1, 2, 3])
mean, log_var = conditional_encoder(x, condition=labels, rngs=rngs)
Decoders¤
MLPDecoder¤
workshop.generative_models.models.vae.decoders.MLPDecoder¤
MLPDecoder(config: DecoderConfig, *, rngs: Rngs)
Bases: Module
Simple MLP decoder for VAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DecoderConfig | DecoderConfig with hidden_dims, output_shape, latent_dim, activation | required |
| rngs | Rngs | Random number generator | required |
Fully-connected decoder.
Class Definition¤
class MLPDecoder(nnx.Module):
"""MLP decoder for VAE."""
def __init__(
self,
hidden_dims: list[int],
output_dim: tuple[int, ...],
latent_dim: int,
activation: str = "relu",
*,
rngs: nnx.Rngs | None = None,
):
"""Initialize MLP decoder."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | Required | Hidden layer dimensions (reversed from encoder) |
| output_dim | tuple | Required | Output reconstruction shape |
| latent_dim | int | Required | Latent space dimension |
| activation | str | "relu" | Activation function |
| rngs | nnx.Rngs | Required | Random number generators |
Example¤
decoder = MLPDecoder(
hidden_dims=[128, 256, 512], # Reversed from encoder
output_dim=(784,),
latent_dim=32,
activation="relu",
rngs=rngs,
)
reconstructed = decoder(z) # Shape: (batch, 784)
CNNDecoder¤
workshop.generative_models.models.vae.decoders.CNNDecoder¤
CNNDecoder(config: DecoderConfig, *, rngs: Rngs)
Bases: Module
CNN-based decoder for VAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | DecoderConfig | DecoderConfig with hidden_dims, output_shape, latent_dim, activation | required |
| rngs | Rngs | Random number generator | required |
project instance-attribute¤
project = Linear(in_features=latent_dim, out_features=initial_size, rngs=rngs)
cnn instance-attribute¤
cnn = CNN(hidden_dims=decoder_dims, activation=activation, use_transpose=True, in_features=initial_features, rngs=rngs)
Transposed convolutional decoder for images.
Class Definition¤
class CNNDecoder(nnx.Module):
"""CNN decoder with transposed convolutions."""
def __init__(
self,
hidden_dims: list[int],
output_dim: tuple[int, ...],
latent_dim: int,
activation: str = "relu",
*,
rngs: nnx.Rngs | None = None,
):
"""Initialize CNN decoder."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_dims | list[int] | Required | Channel dimensions (reversed from encoder) |
| output_dim | tuple | Required | Output shape (H, W, C) |
| latent_dim | int | Required | Latent dimension |
| activation | str | "relu" | Activation function |
| rngs | nnx.Rngs | Required | Random number generators |
Example¤
decoder = CNNDecoder(
hidden_dims=[256, 128, 64, 32], # Reversed channels
output_dim=(28, 28, 1),
latent_dim=64,
activation="relu",
rngs=rngs,
)
# Input: (batch, 64)
reconstructed = decoder(z) # Output: (batch, 28, 28, 1)
ConditionalDecoder¤
workshop.generative_models.models.vae.decoders.ConditionalDecoder¤
Bases: Module
Wrapper that adds conditioning to any decoder.
This wrapper adds class conditioning to an existing decoder by converting labels to one-hot and concatenating them with the latent vector. This follows the standard CVAE pattern from popular implementations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| decoder | Module | Base decoder to wrap | required |
| num_classes | int | Number of classes for conditioning | required |
| rngs | Rngs | Random number generators | required |
Wrapper that adds conditioning to any decoder.
Class Definition¤
class ConditionalDecoder(nnx.Module):
"""Conditional decoder wrapper."""
def __init__(
self,
decoder: nnx.Module,
num_classes: int,
embed_dim: int,
*,
rngs: nnx.Rngs,
):
"""Initialize conditional decoder."""
Parameters¤
| Parameter | Type | Default | Description |
|---|---|---|---|
| decoder | nnx.Module | Required | Base decoder to wrap |
| num_classes | int | Required | Number of conditioning classes |
| embed_dim | int | Required | Embedding dimension |
| rngs | nnx.Rngs | Required | Random number generators |
Example¤
base_decoder = MLPDecoder(
hidden_dims=[128, 256, 512],
output_dim=(784,),
latent_dim=32,
rngs=rngs,
)
conditional_decoder = ConditionalDecoder(
decoder=base_decoder,
num_classes=10,
embed_dim=128,
rngs=rngs,
)
labels = jnp.array([0, 1, 2, 3])
reconstructed = conditional_decoder(z, condition=labels, rngs=rngs)
Utility Functions¤
create_encoder_unified¤
def create_encoder_unified(
*,
config: ModelConfiguration,
encoder_type: str,
conditional: bool = False,
num_classes: int | None = None,
embed_dim: int | None = None,
rngs: nnx.Rngs,
) -> nnx.Module:
"""Create encoder from unified configuration."""
Parameters:
- config (ModelConfiguration): Model configuration
- encoder_type (str): Type of encoder: "dense", "cnn", "resnet"
- conditional (bool): Whether to wrap in a conditional encoder
- num_classes (int, optional): Number of classes for conditioning
- embed_dim (int, optional): Embedding dimension
- rngs (Rngs): Random number generators
Returns:
nnx.Module: Configured encoder
create_decoder_unified¤
def create_decoder_unified(
*,
config: ModelConfiguration,
decoder_type: str,
conditional: bool = False,
num_classes: int | None = None,
embed_dim: int | None = None,
rngs: nnx.Rngs,
) -> nnx.Module:
"""Create decoder from unified configuration."""
Parameters:
- config (ModelConfiguration): Model configuration
- decoder_type (str): Type of decoder: "dense", "cnn", "resnet"
- conditional (bool): Whether to wrap in a conditional decoder
- num_classes (int, optional): Number of classes for conditioning
- embed_dim (int, optional): Embedding dimension
- rngs (Rngs): Random number generators
Returns:
nnx.Module: Configured decoder
Complete Example¤
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.models.vae import BetaVAE
from workshop.generative_models.models.vae.encoders import CNNEncoder
from workshop.generative_models.models.vae.decoders import CNNDecoder
# Initialize
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
# Create encoder
encoder = CNNEncoder(
hidden_dims=[32, 64, 128, 256],
latent_dim=64,
activation="relu",
input_dim=(28, 28, 1),
rngs=rngs,
)
# Create decoder
decoder = CNNDecoder(
hidden_dims=[256, 128, 64, 32],
output_dim=(28, 28, 1),
latent_dim=64,
activation="relu",
rngs=rngs,
)
# Create β-VAE
model = BetaVAE(
encoder=encoder,
decoder=decoder,
latent_dim=64,
beta_default=4.0,
beta_warmup_steps=10000,
reconstruction_loss_type="mse",
rngs=rngs,
)
# Forward pass
x = jnp.ones((32, 28, 28, 1))
outputs = model(x, rngs=rngs)
# Calculate loss
losses = model.loss_fn(x=x, outputs=outputs, step=5000)
print(f"Total Loss: {losses['total_loss']:.4f}")
print(f"Beta: {losses['beta']:.4f}")
# Generate samples
samples = model.sample(n_samples=16, temperature=1.0, rngs=rngs)
# Latent traversal
traversal = model.latent_traversal(x[0], dim=10, steps=15, rngs=rngs)
See Also¤
- VAE Concepts — Theory and mathematical foundations
- VAE User Guide — Practical usage and examples
- Training Guide — Training VAE models
- Loss Functions — Available loss functions
- Configuration — Configuration system