Flow Models: Practical User Guide¤

This guide provides practical instructions for working with normalizing flow models in Workshop. We cover creating, training, and using various flow architectures for density estimation and generation.

Quick Start¤

Here's a minimal example to get started with RealNVP:

import jax
import jax.numpy as jnp
from flax import nnx

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP

# Create RNG streams
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Configure the model
config = ModelConfiguration(
    name="realnvp_model",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=784,  # MNIST flattened
    output_dim=784,
    hidden_dims=[512, 512],
    parameters={
        "num_coupling_layers": 8,
        "base_distribution": "normal",
    }
)

# Create model
model = RealNVP(config, rngs=rngs)

# Train on data
batch = jnp.array(...)  # Your training data
outputs = model(batch, rngs=rngs)
log_likelihood = outputs["log_prob"]
loss = -jnp.mean(log_likelihood)

# Generate samples
samples = model.generate(n_samples=16, rngs=rngs)

Creating Flow Models¤

Workshop provides multiple flow architectures. Choose based on your needs (see Flow Concepts for a detailed comparison).

RealNVP (Balanced Performance)¤

RealNVP offers a good balance between expressiveness and computational cost.

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP
from flax import nnx
import jax
import jax.numpy as jnp

# Create RNGs
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Configure RealNVP
config = ModelConfiguration(
    name="realnvp_flow",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=64,  # Feature dimension
    output_dim=64,
    hidden_dims=[256, 256],  # Coupling network hidden layers
    parameters={
        "num_coupling_layers": 8,  # Number of coupling transformations
        "mask_type": "checkerboard",  # or "channel-wise" for images
        "base_distribution": "normal",
        "base_distribution_params": {
            "loc": 0.0,
            "scale": 1.0,
        },
    }
)

# Create model
model = RealNVP(config, rngs=rngs)

# Forward pass (data to latent)
x = jax.random.normal(rngs.sample(), (32, 64))
z, log_det = model.forward(x, rngs=rngs)

# Inverse pass (latent to data)
samples, _ = model.inverse(z, rngs=rngs)

# Compute log probability
log_prob = model.log_prob(x, rngs=rngs)
print(f"Log probability: {jnp.mean(log_prob):.3f}")

Mask Types:

  • "checkerboard": Alternates dimensions (good for tabular data)
  • "channel-wise": Splits along channels (better for images)
# For image data (H, W, C)
config_image = ModelConfiguration(
    name="realnvp_image",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=(28, 28, 1),  # MNIST shape
    output_dim=(28, 28, 1),
    hidden_dims=[512, 512],
    parameters={
        "num_coupling_layers": 12,
        "mask_type": "channel-wise",  # Better for images
        "base_distribution": "normal",
    }
)
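
As a rough illustration of the checkerboard masking described above (not Workshop's internal implementation), an alternating binary mask over flattened features can be built as follows; the coupling layer leaves the masked-out dimensions untouched and transforms the rest:

import jax.numpy as jnp

def checkerboard_mask(dim, even=True):
    """Alternating binary mask over feature indices (illustrative only)."""
    mask = jnp.arange(dim) % 2
    return mask if even else 1 - mask

mask = checkerboard_mask(6)   # [0, 1, 0, 1, 0, 1]
x = jnp.arange(6.0)
x_pass = x * (1 - mask)       # dimensions passed through unchanged
x_update = x * mask           # dimensions the coupling network transforms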

Glow (High-Quality Image Generation)¤

Glow uses a multi-scale architecture with ActNorm, invertible 1×1 convolutions, and coupling layers.

from workshop.generative_models.models.flow import Glow

# Configure Glow
config = ModelConfiguration(
    name="glow_model",
    model_class="workshop.generative_models.models.flow.Glow",
    input_dim=(32, 32, 3),  # Image shape
    hidden_dims=[512, 512],
    parameters={
        "image_shape": (32, 32, 3),
        "num_scales": 3,  # Multi-scale architecture
        "blocks_per_scale": 6,  # Flow steps per scale
    }
)

# Create Glow model
rngs = nnx.Rngs(params=0, sample=1)
model = Glow(config, rngs=rngs)

# Training
images = jax.random.normal(rngs.sample(), (16, 32, 32, 3))
outputs = model(images, rngs=rngs)
loss = -jnp.mean(outputs["log_prob"])

# Generate high-quality samples
samples = model.generate(n_samples=16, rngs=rngs)

Glow Architecture Parameters:

  • num_scales: Number of multi-scale levels (typically 2-4)
  • blocks_per_scale: Flow steps at each scale (typically 4-8)
  • Higher values = more expressive but slower
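
As a back-of-the-envelope check, assuming the standard Glow layout from Kingma & Dhariwal (2018) with a 2×2 squeeze per scale and a channel split after every scale except the last (Workshop's internals may differ), the configuration above works out to:

# Flow-step count and working shapes for num_scales=3, blocks_per_scale=6
H, W, C = 32, 32, 3
num_scales, blocks_per_scale = 3, 6

shape = (H, W, C)
for scale in range(num_scales):
    h, w, c = shape
    shape = (h // 2, w // 2, c * 4)  # squeeze: halve spatial dims, 4x channels
    print(f"scale {scale}: {blocks_per_scale} flow steps at shape {shape}")
    if scale < num_scales - 1:
        shape = (shape[0], shape[1], shape[2] // 2)  # split off half the channels

print("total flow steps:", num_scales * blocks_per_scale)  # 18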

MAF (Fast Density Estimation)¤

MAF (Masked Autoregressive Flow) excels at density estimation but has slow sampling.

from workshop.generative_models.models.flow import MAF

# Configure MAF
config = ModelConfiguration(
    name="maf_model",
    model_class="workshop.generative_models.models.flow.MAF",
    input_dim=64,
    output_dim=64,
    hidden_dims=[512],  # MADE hidden dimensions
    parameters={
        "num_layers": 5,  # Number of MAF layers
        "reverse_ordering": True,  # Alternate variable ordering
    }
)

# Create MAF model
rngs = nnx.Rngs(params=0, sample=1)
model = MAF(config, rngs=rngs)

# Fast forward pass (density estimation)
x = jax.random.normal(rngs.sample(), (100, 64))
log_prob = model.log_prob(x, rngs=rngs)
print(f"Mean log-likelihood: {jnp.mean(log_prob):.3f}")

# Slow inverse pass (sampling)
samples = model.sample(n_samples=10, rngs=rngs)  # Sequential, slower

When to Use MAF:

  • Primary goal is density estimation or anomaly detection
  • Sampling speed is not critical
  • Working with tabular or low-dimensional data
  • Need high-quality likelihood estimates

IAF (Fast Sampling)¤

IAF (Inverse Autoregressive Flow) provides fast sampling at the cost of slow density estimation.

from workshop.generative_models.models.flow import IAF

# Configure IAF
config = ModelConfiguration(
    name="iaf_model",
    model_class="workshop.generative_models.models.flow.IAF",
    input_dim=64,
    output_dim=64,
    hidden_dims=[512],
    parameters={
        "num_layers": 5,
        "reverse_ordering": True,
    }
)

# Create IAF model
rngs = nnx.Rngs(params=0, sample=1)
model = IAF(config, rngs=rngs)

# Fast sampling (parallel computation)
samples = model.sample(n_samples=100, rngs=rngs)  # Fast!

# Slow density estimation (sequential)
log_prob = model.log_prob(samples, rngs=rngs)  # Slower

When to Use IAF:

  • Fast sampling is critical (real-time generation)
  • Often used as variational posterior in VAEs
  • Density estimation is secondary
  • Generation frequency >> inference frequency

Neural Spline Flows (Most Expressive)¤

Neural Spline Flows use rational quadratic splines for highly expressive transformations.

from workshop.generative_models.models.flow import NeuralSplineFlow

# Configure Neural Spline Flow
config = ModelConfiguration(
    name="spline_flow",
    model_class="workshop.generative_models.models.flow.NeuralSplineFlow",
    input_dim=64,
    hidden_dims=[128, 128],
    metadata={
        "flow_params": {
            "num_layers": 8,
            "num_bins": 8,  # Number of spline segments
            "tail_bound": 3.0,  # Spline domain bounds
            "base_distribution": "normal",
        }
    }
)

# Create Neural Spline Flow
rngs = nnx.Rngs(params=0, sample=1)
model = NeuralSplineFlow(config, rngs=rngs)

# More expressive transformations
x = jax.random.normal(rngs.sample(), (32, 64))
log_prob = model.log_prob(x, rngs=rngs)

# Generate samples
samples = model.generate(n_samples=16, rngs=rngs)

Spline Parameters:

  • num_bins: Number of spline segments (8-16 typical)
  • tail_bound: Domain bounds for spline (3.0-5.0 typical)
  • More bins = more expressive but higher memory cost
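
To make the memory trade-off concrete: in the rational-quadratic formulation of Durkan et al. (2019), each transformed dimension needs roughly K bin widths, K bin heights, and K - 1 interior derivatives per coupling layer (Workshop's exact parameterization may differ slightly):

def spline_params_per_dim(num_bins: int) -> int:
    """Approximate spline parameters per transformed dimension per layer."""
    return 3 * num_bins - 1

print(spline_params_per_dim(8))   # 23
print(spline_params_per_dim(16))  # 47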

Training Flow Models¤

Basic Training Loop¤

Flow models are trained by maximum likelihood estimation, i.e., by minimizing the negative log-likelihood of the training data.

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from workshop.generative_models.models.flow import RealNVP

# Initialize model and optimizer
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
model = RealNVP(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))

# Training step function
@nnx.jit
def train_step(model, optimizer, batch, rngs):
    """Single training step."""
    def loss_fn(model):
        # Forward pass through flow
        outputs = model(batch, rngs=rngs, training=True)

        # Negative log-likelihood loss
        log_prob = outputs["log_prob"]
        loss = -jnp.mean(log_prob)

        return loss, {"nll": loss, "mean_log_prob": jnp.mean(log_prob)}

    # Compute loss and gradients
    (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

    # Update parameters
    optimizer.update(grads)

    return metrics

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    epoch_metrics = []

    for batch in train_dataloader:
        # Preprocess: add uniform noise for dequantization
        batch = batch + jax.random.uniform(rngs.sample(), batch.shape) / 256.0

        # Scale to appropriate range
        batch = (batch - 0.5) / 0.5  # Scale to [-1, 1]

        # Training step
        metrics = train_step(model, optimizer, batch, rngs)
        epoch_metrics.append(metrics)

    # Log epoch statistics
    avg_nll = jnp.mean(jnp.array([m["nll"] for m in epoch_metrics]))
    print(f"Epoch {epoch+1}/{num_epochs}, NLL: {avg_nll:.3f}")

Training with Gradient Clipping¤

Gradient clipping helps stabilize flow training:

import optax

# Create optimizer with gradient clipping
optimizer_chain = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip gradients
    optax.adam(learning_rate=1e-4),
)

optimizer = nnx.Optimizer(model, optimizer_chain)

# Training step with clipping
@nnx.jit
def train_step_clipped(model, optimizer, batch, rngs):
    def loss_fn(model):
        outputs = model(batch, rngs=rngs, training=True)
        loss = -jnp.mean(outputs["log_prob"])
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)

    # Optimizer applies gradient clipping automatically
    optimizer.update(grads)

    return {"loss": loss}

Learning Rate Scheduling¤

Use learning rate warmup and decay for better convergence:

import optax

# Learning rate schedule
warmup_steps = 1000
total_steps = 50000

schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,
    peak_value=1e-4,
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
    end_value=1e-6,
)

# Create optimizer with schedule
optimizer = nnx.Optimizer(
    model,
    optax.adam(learning_rate=schedule)
)

# Track global step
global_step = 0

# Training loop
for epoch in range(num_epochs):
    for batch in train_dataloader:
        metrics = train_step(model, optimizer, batch, rngs)
        global_step += 1

        # Learning rate automatically updated by optax

Monitoring Training¤

Track important metrics during training:

# Training with metrics tracking
@nnx.jit
def train_step_with_metrics(model, optimizer, batch, rngs):
    def loss_fn(model):
        # Forward pass
        z, log_det = model.forward(batch, rngs=rngs)

        # Base distribution log prob
        log_p_z = -0.5 * jnp.sum(z**2, axis=-1) - 0.5 * z.shape[-1] * jnp.log(2 * jnp.pi)

        # Total log probability
        log_prob = log_p_z + log_det

        # Loss
        loss = -jnp.mean(log_prob)

        # Additional metrics
        metrics = {
            "loss": loss,
            "log_p_z": jnp.mean(log_p_z),
            "log_det": jnp.mean(log_det),
            "log_prob": jnp.mean(log_prob),
            "z_norm": jnp.mean(jnp.linalg.norm(z, axis=-1)),
        }

        return loss, metrics

    (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(grads)

    return metrics

# Training loop with logging
for epoch in range(num_epochs):
    metrics_list = []

    for batch in train_dataloader:
        batch = preprocess(batch)
        metrics = train_step_with_metrics(model, optimizer, batch, rngs)
        metrics_list.append(metrics)

    # Aggregate epoch metrics
    epoch_metrics = {
        k: jnp.mean(jnp.array([m[k] for m in metrics_list]))
        for k in metrics_list[0].keys()
    }

    print(f"Epoch {epoch+1}: {epoch_metrics}")

Important Metrics to Monitor:

  • loss: Negative log-likelihood (should decrease)
  • log_p_z: Base distribution log-prob (settles near -d/2 · (1 + ln 2π) per sample for a d-dimensional standard Gaussian)
  • log_det: Jacobian log-determinant (tracks transformation magnitude)
  • z_norm: Latent norm (should be near √d for d-dimensional Gaussian)
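
For a concrete reference point, assuming an N(0, I) base distribution over d dimensions, the expected values of these diagnostics are:

import jax.numpy as jnp

d = 64  # latent dimensionality
expected_log_p_z = -0.5 * d * (1.0 + jnp.log(2.0 * jnp.pi))  # about -90.9 for d = 64
expected_z_norm = jnp.sqrt(float(d))                         # about 8.0 for d = 64
print(f"Expected log_p_z: {expected_log_p_z:.1f}, expected z_norm: {expected_z_norm:.1f}")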

Data Preprocessing for Flows¤

Proper preprocessing is crucial for flow models:

Dequantization¤

Images are discrete (0-255), but flows need continuous values:

def dequantize(images, rngs):
    """Add uniform noise to dequantize discrete images."""
    # images should be in [0, 1] range
    noise = jax.random.uniform(rngs.sample(), images.shape)
    return images + noise / 256.0

# Apply during training
batch = dequantize(batch, rngs)

Logit Transform¤

Map bounded data to unbounded space:

def logit_transform(x, alpha=0.05):
    """Apply logit transform with boundary handling."""
    # Squeeze to (alpha, 1-alpha) to avoid infinities
    x = alpha + (1 - 2*alpha) * x

    # Apply logit
    return jnp.log(x) - jnp.log(1 - x)

def inverse_logit_transform(y, alpha=0.05):
    """Inverse of logit transform."""
    x = jax.nn.sigmoid(y)
    return (x - alpha) / (1 - 2*alpha)

# Apply during training
batch = logit_transform(batch)

Normalization¤

Standardize data to zero mean and unit variance:

# Compute statistics on training data
train_mean = jnp.mean(train_data, axis=0, keepdims=True)
train_std = jnp.std(train_data, axis=0, keepdims=True)

# Normalize
batch = (batch - train_mean) / (train_std + 1e-6)

# Remember to denormalize samples
samples = model.generate(n_samples=16, rngs=rngs)
samples = samples * train_std + train_mean

Sampling and Generation¤

Basic Sampling¤

Generate samples from a trained flow model:

# Generate samples
n_samples = 16
samples = model.generate(n_samples=n_samples, rngs=rngs)

# For image data, reshape if needed
if isinstance(config.input_dim, tuple):
    # Already in correct shape
    images = samples
else:
    # Reshape to image dimensions
    H, W, C = 28, 28, 1
    images = samples.reshape(n_samples, H, W, C)

# Denormalize for visualization
images = (images * 0.5) + 0.5  # From [-1, 1] to [0, 1]
images = jnp.clip(images, 0, 1)

Temperature Sampling¤

Control sample diversity with temperature:

def sample_with_temperature(model, n_samples, temperature, rngs):
    """Sample with temperature scaling.

    temperature > 1: More diverse samples
    temperature < 1: More conservative samples
    temperature = 1: Standard sampling
    """
    # Sample from base distribution
    z = jax.random.normal(rngs.sample(), (n_samples, model.latent_dim))

    # Scale by temperature
    z = z * temperature

    # Transform to data space
    samples, _ = model.inverse(z, rngs=rngs)

    return samples

# Conservative samples (sharper, less diverse)
conservative = sample_with_temperature(model, 16, temperature=0.7, rngs=rngs)

# Diverse samples (more variety, less sharp)
diverse = sample_with_temperature(model, 16, temperature=1.3, rngs=rngs)

Conditional Sampling¤

Some flow architectures support conditional generation:

# For conditional flows (if implemented)
from workshop.generative_models.models.flow import ConditionalRealNVP

# Create conditional model
config = ModelConfiguration(
    name="conditional_realnvp",
    model_class="workshop.generative_models.models.flow.ConditionalRealNVP",
    input_dim=784,
    output_dim=784,
    hidden_dims=[512, 512],
    parameters={
        "num_coupling_layers": 8,
        "condition_dim": 10,  # e.g., class labels
    }
)

model = ConditionalRealNVP(config, rngs=rngs)

# Sample conditioned on class labels
class_labels = jax.nn.one_hot(jnp.array([0, 1, 2]), 10)  # 3 classes
conditional_samples = model.generate(
    n_samples=3,
    condition=class_labels,
    rngs=rngs
)

Interpolation in Latent Space¤

Interpolate between two data points:

def interpolate(model, x1, x2, num_steps, rngs):
    """Interpolate between two data points in latent space."""
    # Encode to latent space
    z1, _ = model.forward(x1[None, ...], rngs=rngs)
    z2, _ = model.forward(x2[None, ...], rngs=rngs)

    # Linear interpolation in latent space
    alphas = jnp.linspace(0, 1, num_steps)
    z_interp = jnp.array([
        (1 - alpha) * z1 + alpha * z2
        for alpha in alphas
    ]).squeeze(1)

    # Decode to data space
    x_interp, _ = model.inverse(z_interp, rngs=rngs)

    return x_interp

# Interpolate between two images
x1 = train_data[0]  # First image
x2 = train_data[1]  # Second image
interpolations = interpolate(model, x1, x2, num_steps=10, rngs=rngs)

Density Estimation and Evaluation¤

Computing Log-Likelihood¤

Flow models provide exact log-likelihood:

# Compute log-likelihood for test data
test_data = ...  # Your test dataset

log_likelihoods = []
for batch in test_dataloader:
    # Preprocess same as training
    batch = dequantize(batch, rngs)
    batch = (batch - 0.5) / 0.5

    # Compute log probability
    log_prob = model.log_prob(batch, rngs=rngs)
    log_likelihoods.append(log_prob)

# Average log-likelihood
all_log_probs = jnp.concatenate(log_likelihoods)
avg_log_likelihood = jnp.mean(all_log_probs)
print(f"Test log-likelihood: {avg_log_likelihood:.3f}")

# Bits per dimension (common metric)
input_dim = jnp.prod(jnp.array(config.input_dim))
bits_per_dim = -avg_log_likelihood / (input_dim * jnp.log(2))
print(f"Bits per dimension: {bits_per_dim:.3f}")

Anomaly Detection¤

Use log-likelihood for anomaly detection:

def detect_anomalies(model, data, threshold, rngs):
    """Detect anomalies using log-likelihood threshold."""
    # Compute log probabilities
    log_probs = model.log_prob(data, rngs=rngs)

    # Flag samples below threshold as anomalies
    is_anomaly = log_probs < threshold

    return is_anomaly, log_probs

# Set threshold (e.g., 5th percentile of training data)
train_log_probs = model.log_prob(train_data, rngs=rngs)
threshold = jnp.percentile(train_log_probs, 5)

# Detect anomalies in test data
anomalies, test_log_probs = detect_anomalies(
    model, test_data, threshold, rngs
)

print(f"Detected {jnp.sum(anomalies)} anomalies out of {len(test_data)} samples")

Model Comparison¤

Compare different flow architectures using likelihood:

# Train multiple models
models = {
    "RealNVP": realnvp_model,
    "Glow": glow_model,
    "MAF": maf_model,
    "Spline": spline_model,
}

# Evaluate on test set
results = {}
for name, model in models.items():
    log_probs = []
    for batch in test_dataloader:
        batch = preprocess(batch)
        log_prob = model.log_prob(batch, rngs=rngs)
        log_probs.append(log_prob)

    avg_log_prob = jnp.mean(jnp.concatenate(log_probs))
    results[name] = avg_log_prob

    print(f"{name}: {avg_log_prob:.3f} (higher is better)")

# Best model
best_model = max(results, key=results.get)
print(f"Best model: {best_model}")

Common Patterns¤

Multi-Modal Data Distribution¤

For data with multiple modes, increase model capacity:

# Increase number of layers
config = ModelConfiguration(
    name="multimodal_flow",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=64,
    output_dim=64,
    hidden_dims=[1024, 1024, 1024],  # Deeper networks
    parameters={
        "num_coupling_layers": 16,  # More layers
    }
)

# Or use Neural Spline Flows for higher expressiveness
config_spline = ModelConfiguration(
    name="multimodal_spline",
    model_class="workshop.generative_models.models.flow.NeuralSplineFlow",
    input_dim=64,
    hidden_dims=[256, 256],
    metadata={
        "flow_params": {
            "num_layers": 12,
            "num_bins": 16,  # More bins for expressiveness
        }
    }
)

High-Dimensional Data¤

For very high-dimensional data (e.g., high-resolution images):

# Use Glow with multi-scale architecture
config = ModelConfiguration(
    name="glow_highres",
    model_class="workshop.generative_models.models.flow.Glow",
    input_dim=(64, 64, 3),
    hidden_dims=[512, 512],
    parameters={
        "image_shape": (64, 64, 3),
        "num_scales": 4,  # More scales for higher resolution
        "blocks_per_scale": 8,
    }
)

# Reduce memory by processing in patches
def train_on_patches(model, image, rngs, patch_size=32):
    """Train on image patches to reduce memory."""
    H, W, C = image.shape
    patches = []

    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            patch = image[i:i+patch_size, j:j+patch_size, :]
            patches.append(patch)

    patches = jnp.array(patches)
    return model(patches, rngs=rngs)

Tabular Data with Mixed Types¤

For tabular data with continuous and categorical features:

# Preprocess mixed types
def preprocess_tabular(data, categorical_indices):
    """Preprocess tabular data with mixed types."""
    continuous = data.copy()

    # One-hot encode categorical features (process columns right to left so
    # earlier column indices stay valid as the array widens)
    for idx in sorted(categorical_indices, reverse=True):
        # One-hot encode
        n_categories = int(jnp.max(data[:, idx])) + 1
        one_hot = jax.nn.one_hot(data[:, idx].astype(int), n_categories)

        # Replace categorical column with one-hot
        continuous = jnp.concatenate([
            continuous[:, :idx],
            one_hot,
            continuous[:, idx+1:],
        ], axis=1)

    return continuous

# Use MAF for tabular data (good density estimation)
config = ModelConfiguration(
    name="tabular_maf",
    model_class="workshop.generative_models.models.flow.MAF",
    input_dim=processed_dim,
    hidden_dims=[512, 512],
    parameters={
        "num_layers": 10,  # More layers for complex dependencies
    }
)

Exact Reconstruction¤

Verify model invertibility:

def test_invertibility(model, x, rngs, tolerance=1e-4):
    """Test that forward and inverse are true inverses."""
    # Forward then inverse
    z, _ = model.forward(x, rngs=rngs)
    x_recon, _ = model.inverse(z, rngs=rngs)

    # Compute reconstruction error
    error = jnp.max(jnp.abs(x - x_recon))

    print(f"Max reconstruction error: {error:.6f}")
    assert error < tolerance, f"Reconstruction error {error} exceeds tolerance {tolerance}"

# Test on random data
x = jax.random.normal(rngs.sample(), (10, 64))
test_invertibility(model, x, rngs)

Troubleshooting¤

Issue: NaN Loss During Training¤

Symptoms: Loss becomes NaN after a few iterations.

Solutions:

  1. Add gradient clipping:
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(1e-4)
    )
)
  2. Check data preprocessing:
# Ensure data is properly scaled
assert jnp.all(jnp.isfinite(batch)), "Data contains NaN or Inf"
assert jnp.abs(jnp.mean(batch)) < 10, "Data not properly normalized"
  3. Reduce learning rate:
optimizer = nnx.Optimizer(model, optax.adam(1e-5))  # Lower LR
  4. Check Jacobian stability:
# Monitor log-determinant magnitude
z, log_det = model.forward(batch, rngs=rngs)
print(f"Log-det range: [{jnp.min(log_det):.2f}, {jnp.max(log_det):.2f}]")

# If log_det has extreme values, reduce model capacity or layers

Issue: Poor Sample Quality¤

Symptoms: Generated samples look noisy or unrealistic.

Solutions:

  1. Increase model capacity:
config = ModelConfiguration(
    name="larger_model",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=784,
    hidden_dims=[1024, 1024],  # Larger networks
    parameters={
        "num_coupling_layers": 16,  # More layers
    }
)
  2. Use more expressive architecture:
# Switch from RealNVP to Neural Spline Flows
model = NeuralSplineFlow(spline_config, rngs=rngs)
  3. Improve data preprocessing:
# Apply logit transform for bounded data
batch = logit_transform(batch)
  4. Train longer:
num_epochs = 200  # More epochs
# Monitor validation likelihood to check for convergence

Issue: Slow Training¤

Symptoms: Training takes too long per iteration.

Solutions:

  1. Use JIT compilation:
@nnx.jit
def train_step(model, optimizer, batch, rngs):
    # ... training step code
    pass
  2. Reduce model complexity:
# Fewer coupling layers
parameters={"num_coupling_layers": 6}  # Instead of 16

# Smaller hidden dimensions
hidden_dims=[256, 256]  # Instead of [1024, 1024]
  3. Use IAF for fast sampling (if sampling is the bottleneck):
# IAF has fast sampling
model = IAF(config, rngs=rngs)
  4. Batch processing:
# Increase batch size (if memory allows)
batch_size = 128  # Instead of 32

Issue: Mode Collapse¤

Symptoms: Model generates similar samples repeatedly.

Solutions:

  1. Check latent space coverage:
# Generate many samples and check latent codes
samples = model.generate(n_samples=1000, rngs=rngs)
z_samples, _ = model.forward(samples, rngs=rngs)

# Check if latents cover the expected distribution
z_mean = jnp.mean(z_samples, axis=0)
z_std = jnp.std(z_samples, axis=0)

print(f"Latent mean: {jnp.mean(z_mean):.3f} (should be ~0)")
print(f"Latent std: {jnp.mean(z_std):.3f} (should be ~1)")
  2. Increase model expressiveness:
# Use Neural Spline Flows
# Or increase number of flow layers
  3. Check for numerical issues:
# Ensure stable training
# Use gradient clipping and proper LR

Issue: Memory Errors¤

Symptoms: Out of memory errors during training.

Solutions:

  1. Reduce batch size:
batch_size = 16  # Smaller batches
  2. Use gradient checkpointing (if available; see the sketch after this list):
# Recompute intermediate activations during backward pass
# (implementation-specific)
  3. Reduce model size:
# Fewer layers or smaller hidden dimensions
hidden_dims=[256, 256]
parameters={"num_coupling_layers": 6}
  4. Use mixed precision training:
# Use float16 for some computations (implementation-specific)
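
For point 2, jax.checkpoint (also known as jax.remat) is the generic JAX mechanism for recomputing activations during the backward pass. A minimal sketch on a stand-in coupling network (illustrative only; wiring this into Workshop's flow classes is implementation-specific):

import jax
import jax.numpy as jnp

def coupling_net(params, x):
    """Stand-in coupling network used only to illustrate rematerialization."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

# Wrapping with jax.checkpoint recomputes h during the backward pass
# instead of storing it, trading extra compute for lower memory use.
coupling_net_remat = jax.checkpoint(coupling_net)

key = jax.random.PRNGKey(0)
params = {
    "w1": 0.01 * jax.random.normal(key, (64, 256)), "b1": jnp.zeros(256),
    "w2": 0.01 * jax.random.normal(key, (256, 64)), "b2": jnp.zeros(64),
}
x = jax.random.normal(key, (32, 64))
grads = jax.grad(lambda p: jnp.mean(coupling_net_remat(p, x) ** 2))(params)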

Best Practices¤

DO¤

Preprocess data properly:

# Always dequantize discrete data
batch = dequantize(batch, rngs)

# Normalize to appropriate range
batch = (batch - 0.5) / 0.5

Monitor multiple metrics:

# Track loss, log_det, base log prob
metrics = {
    "loss": loss,
    "log_det": jnp.mean(log_det),
    "log_p_z": jnp.mean(log_p_z),
}

Use gradient clipping:

optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(1e-4)
    )
)

Validate invertibility:

# Periodically test reconstruction
test_invertibility(model, validation_batch, rngs)

Choose architecture for your task:

# MAF for density estimation
# IAF for fast sampling
# RealNVP for balance
# Glow for high-quality images
# Spline Flows for expressiveness

DON'T¤

Don't skip data preprocessing:

# BAD: Using raw discrete images
model(raw_images, rngs=rngs)  # Will perform poorly!

# GOOD: Dequantize and normalize
processed = dequantize(raw_images, rngs)
processed = (processed - 0.5) / 0.5
model(processed, rngs=rngs)

Don't ignore numerical stability:

# BAD: No gradient clipping
# Can lead to NaN losses

# GOOD: Use gradient clipping
optimizer = nnx.Optimizer(
    model,
    optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))
)

Don't use wrong architecture for task:

# BAD: Using IAF when density estimation is primary goal
# IAF has slow forward pass

# GOOD: Use MAF for density estimation
model = MAF(config, rngs=rngs)

Don't overtrain on small datasets:

# Monitor validation likelihood
# Use early stopping
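
A minimal early-stopping sketch (val_log_prob and val_dataloader are hypothetical stand-ins for your own validation routine; train_step and preprocess are the functions from the training section above):

best_val, patience, bad_epochs = float("-inf"), 10, 0

for epoch in range(num_epochs):
    for batch in train_dataloader:
        train_step(model, optimizer, preprocess(batch), rngs)

    val = val_log_prob(model, val_dataloader, rngs)  # hypothetical: mean validation log-likelihood
    if val > best_val:
        best_val, bad_epochs = val, 0   # improvement: reset patience counter
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break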

Hyperparameter Guidelines¤

Number of Flow Layers¤

Data Type             Recommended Layers   Notes
Tabular (low-dim)     6-10                 Simpler distributions
Tabular (high-dim)    10-16                More complex dependencies
Images (low-res)      8-12                 MNIST, CIFAR-10
Images (high-res)     12-24                Higher resolution

Hidden Dimensions¤

Model      Recommended        Notes
RealNVP    [512, 512]         2-3 layers sufficient
Glow       [512, 512]         Larger for high-res
MAF/IAF    [512] to [1024]    Single deep network
Spline     [128, 128]         Splines are expressive

Learning Rates¤

Stage         Learning Rate   Notes
Warmup        1e-7 → 1e-4     First 1000 steps
Training      1e-4            Standard rate
Fine-tuning   1e-5            Near convergence

Batch Sizes¤

Data Type          Batch Size   Notes
Tabular            256-1024     Can use large batches
Images (32×32)     64-128       Memory dependent
Images (64×64)     32-64        Reduce for Glow
Images (128×128)   16-32        Limited by memory

Summary¤

Quick Reference¤

Model Selection:

  • RealNVP: Balanced performance, good for most tasks
  • Glow: Best for high-quality image generation
  • MAF: Optimal for density estimation
  • IAF: Optimal for fast sampling
  • Spline Flows: Most expressive transformations

Training Checklist:

  1. ✅ Preprocess data (dequantize, normalize)
  2. ✅ Use gradient clipping
  3. ✅ Monitor multiple metrics
  4. ✅ Validate invertibility
  5. ✅ Apply learning rate warmup
  6. ✅ Check for NaN/Inf values

Common Workflows:

# Density estimation workflow
model = MAF(config, rngs=rngs)
log_probs = model.log_prob(data, rngs=rngs)

# Generation workflow
model = RealNVP(config, rngs=rngs)
samples = model.generate(n_samples=16, rngs=rngs)

# Anomaly detection workflow
model = MAF(config, rngs=rngs)
threshold = jnp.percentile(train_log_probs, 5)
anomalies = test_log_probs < threshold

References¤

  • Dinh et al. (2016): "Density estimation using Real NVP"
  • Kingma & Dhariwal (2018): "Glow: Generative Flow with Invertible 1x1 Convolutions"
  • Papamakarios et al. (2017): "Masked Autoregressive Flow for Density Estimation"
  • Durkan et al. (2019): "Neural Spline Flows"