Flow Models: Practical User Guide¤
This guide provides practical instructions for working with normalizing flow models in Workshop. We cover creating, training, and using various flow architectures for density estimation and generation.
Quick Start¤
Here's a minimal example to get started with RealNVP:
import jax
import jax.numpy as jnp
from flax import nnx
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP
# Create RNG streams
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
# Configure the model
config = ModelConfiguration(
name="realnvp_model",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=784, # MNIST flattened
output_dim=784,
hidden_dims=[512, 512],
parameters={
"num_coupling_layers": 8,
"base_distribution": "normal",
}
)
# Create model
model = RealNVP(config, rngs=rngs)
# Train on data
batch = jnp.array(...) # Your training data
outputs = model(batch, rngs=rngs)
log_likelihood = outputs["log_prob"]
loss = -jnp.mean(log_likelihood)
# Generate samples
samples = model.generate(n_samples=16, rngs=rngs)
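To actually update the model from this loss, pair it with an optimizer; a minimal sketch using optax (the full pattern is shown in the Training section below):
import optax
# Minimal parameter update for the loss above
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))

@nnx.jit
def train_step(model, optimizer, batch, rngs):
    def loss_fn(model):
        outputs = model(batch, rngs=rngs)
        return -jnp.mean(outputs["log_prob"])
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # apply the Adam update in place
    return loss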
Creating Flow Models¤
Workshop provides multiple flow architectures. Choose based on your needs (see Flow Concepts for detailed comparison).
RealNVP (Recommended for Most Tasks)¤
RealNVP offers a good balance between expressiveness and computational cost.
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP
from flax import nnx
import jax
import jax.numpy as jnp
# Create RNGs
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
# Configure RealNVP
config = ModelConfiguration(
name="realnvp_flow",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=64, # Feature dimension
output_dim=64,
hidden_dims=[256, 256], # Coupling network hidden layers
parameters={
"num_coupling_layers": 8, # Number of coupling transformations
"mask_type": "checkerboard", # or "channel-wise" for images
"base_distribution": "normal",
"base_distribution_params": {
"loc": 0.0,
"scale": 1.0,
},
}
)
# Create model
model = RealNVP(config, rngs=rngs)
# Forward pass (data to latent)
x = jax.random.normal(rngs.sample(), (32, 64))
z, log_det = model.forward(x, rngs=rngs)
# Inverse pass (latent to data)
samples, _ = model.inverse(z, rngs=rngs)
# Compute log probability
log_prob = model.log_prob(x, rngs=rngs)
print(f"Log probability: {jnp.mean(log_prob):.3f}")
Mask Types:
"checkerboard": Alternates dimensions (good for tabular data)"channel-wise": Splits along channels (better for images)
# For image data (H, W, C)
config_image = ModelConfiguration(
name="realnvp_image",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=(28, 28, 1), # MNIST shape
output_dim=(28, 28, 1),
hidden_dims=[512, 512],
parameters={
"num_coupling_layers": 12,
"mask_type": "channel-wise", # Better for images
"base_distribution": "normal",
}
)
Glow (High-Quality Image Generation)¤
Glow uses a multi-scale architecture with ActNorm, invertible 1×1 convolutions, and coupling layers.
from workshop.generative_models.models.flow import Glow
# Configure Glow
config = ModelConfiguration(
name="glow_model",
model_class="workshop.generative_models.models.flow.Glow",
input_dim=(32, 32, 3), # Image shape
hidden_dims=[512, 512],
parameters={
"image_shape": (32, 32, 3),
"num_scales": 3, # Multi-scale architecture
"blocks_per_scale": 6, # Flow steps per scale
}
)
# Create Glow model
rngs = nnx.Rngs(params=0, sample=1)
model = Glow(config, rngs=rngs)
# Training
images = jax.random.normal(rngs.sample(), (16, 32, 32, 3))
outputs = model(images, rngs=rngs)
loss = -jnp.mean(outputs["log_prob"])
# Generate high-quality samples
samples = model.generate(n_samples=16, rngs=rngs)
Glow Architecture Parameters:
- num_scales: Number of multi-scale levels (typically 2-4)
- blocks_per_scale: Flow steps at each scale (typically 4-8)
- Higher values = more expressive but slower
MAF (Fast Density Estimation)¤
MAF (Masked Autoregressive Flow) excels at density estimation but has slow sampling.
from workshop.generative_models.models.flow import MAF
# Configure MAF
config = ModelConfiguration(
name="maf_model",
model_class="workshop.generative_models.models.flow.MAF",
input_dim=64,
output_dim=64,
hidden_dims=[512], # MADE hidden dimensions
parameters={
"num_layers": 5, # Number of MAF layers
"reverse_ordering": True, # Alternate variable ordering
}
)
# Create MAF model
rngs = nnx.Rngs(params=0, sample=1)
model = MAF(config, rngs=rngs)
# Fast forward pass (density estimation)
x = jax.random.normal(rngs.sample(), (100, 64))
log_prob = model.log_prob(x, rngs=rngs)
print(f"Mean log-likelihood: {jnp.mean(log_prob):.3f}")
# Slow inverse pass (sampling)
samples = model.sample(n_samples=10, rngs=rngs) # Sequential, slower
When to Use MAF:
- Primary goal is density estimation or anomaly detection
- Sampling speed is not critical
- Working with tabular or low-dimensional data
- Need high-quality likelihood estimates
IAF (Fast Sampling)¤
IAF (Inverse Autoregressive Flow) provides fast sampling at the cost of slow density estimation.
from workshop.generative_models.models.flow import IAF
# Configure IAF
config = ModelConfiguration(
name="iaf_model",
model_class="workshop.generative_models.models.flow.IAF",
input_dim=64,
output_dim=64,
hidden_dims=[512],
parameters={
"num_layers": 5,
"reverse_ordering": True,
}
)
# Create IAF model
rngs = nnx.Rngs(params=0, sample=1)
model = IAF(config, rngs=rngs)
# Fast sampling (parallel computation)
samples = model.sample(n_samples=100, rngs=rngs) # Fast!
# Slow density estimation (sequential)
log_prob = model.log_prob(samples, rngs=rngs) # Slower
When to Use IAF:
- Fast sampling is critical (real-time generation)
- The flow serves as a variational posterior in a VAE (a common IAF use)
- Density estimation is secondary
- You sample far more often than you evaluate densities
Neural Spline Flows (Most Expressive)¤
Neural Spline Flows use rational quadratic splines for highly expressive transformations.
from workshop.generative_models.models.flow import NeuralSplineFlow
# Configure Neural Spline Flow
config = ModelConfiguration(
name="spline_flow",
model_class="workshop.generative_models.models.flow.NeuralSplineFlow",
input_dim=64,
hidden_dims=[128, 128],
metadata={
"flow_params": {
"num_layers": 8,
"num_bins": 8, # Number of spline segments
"tail_bound": 3.0, # Spline domain bounds
"base_distribution": "normal",
}
}
)
# Create Neural Spline Flow
rngs = nnx.Rngs(params=0, sample=1)
model = NeuralSplineFlow(config, rngs=rngs)
# More expressive transformations
x = jax.random.normal(rngs.sample(), (32, 64))
log_prob = model.log_prob(x, rngs=rngs)
# Generate samples
samples = model.generate(n_samples=16, rngs=rngs)
Spline Parameters:
- num_bins: Number of spline segments (8-16 typical)
- tail_bound: Half-width of the interval on which the spline is defined; outside it the transform is linear (3.0-5.0 typical)
- More bins = more expressive but higher memory cost
Training Flow Models¤
Basic Training Loop¤
Flow models are trained using maximum likelihood estimation.
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from workshop.generative_models.models.flow import RealNVP
# Initialize model and optimizer
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
model = RealNVP(config, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-4))
# Training step function
@nnx.jit
def train_step(model, optimizer, batch, rngs):
"""Single training step."""
def loss_fn(model):
# Forward pass through flow
outputs = model(batch, rngs=rngs, training=True)
# Negative log-likelihood loss
log_prob = outputs["log_prob"]
loss = -jnp.mean(log_prob)
return loss, {"nll": loss, "mean_log_prob": jnp.mean(log_prob)}
# Compute loss and gradients
(loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
# Update parameters
optimizer.update(grads)
return metrics
# Training loop
num_epochs = 50
for epoch in range(num_epochs):
epoch_metrics = []
for batch in train_dataloader:
# Preprocess: add uniform noise for dequantization
batch = batch + jax.random.uniform(rngs.sample(), batch.shape) / 256.0
# Scale to appropriate range
batch = (batch - 0.5) / 0.5 # Scale to [-1, 1]
# Training step
metrics = train_step(model, optimizer, batch, rngs)
epoch_metrics.append(metrics)
# Log epoch statistics
avg_nll = jnp.mean(jnp.array([m["nll"] for m in epoch_metrics]))
print(f"Epoch {epoch+1}/{num_epochs}, NLL: {avg_nll:.3f}")
Training with Gradient Clipping¤
Gradient clipping helps stabilize flow training:
import optax
# Create optimizer with gradient clipping
optimizer_chain = optax.chain(
optax.clip_by_global_norm(1.0), # Clip gradients
optax.adam(learning_rate=1e-4),
)
optimizer = nnx.Optimizer(model, optimizer_chain)
# Training step with clipping
@nnx.jit
def train_step_clipped(model, optimizer, batch, rngs):
def loss_fn(model):
outputs = model(batch, rngs=rngs, training=True)
loss = -jnp.mean(outputs["log_prob"])
return loss
loss, grads = nnx.value_and_grad(loss_fn)(model)
# Optimizer applies gradient clipping automatically
optimizer.update(grads)
return {"loss": loss}
Learning Rate Scheduling¤
Use learning rate warmup and decay for better convergence:
import optax
# Learning rate schedule
warmup_steps = 1000
total_steps = 50000
schedule = optax.warmup_cosine_decay_schedule(
init_value=1e-7,
peak_value=1e-4,
warmup_steps=warmup_steps,
decay_steps=total_steps - warmup_steps,
end_value=1e-6,
)
# Create optimizer with schedule
optimizer = nnx.Optimizer(
model,
optax.adam(learning_rate=schedule)
)
# Track global step
global_step = 0
# Training loop
for epoch in range(num_epochs):
for batch in train_dataloader:
metrics = train_step(model, optimizer, batch, rngs)
global_step += 1
# Learning rate automatically updated by optax
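To log the learning rate actually in use, note that optax schedules are plain callables over the step count:
# Inspect the scheduled learning rate at the current global step
current_lr = schedule(global_step)
print(f"Step {global_step}: learning rate = {current_lr:.2e}")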
Monitoring Training¤
Track important metrics during training:
# Training with metrics tracking
@nnx.jit
def train_step_with_metrics(model, optimizer, batch, rngs):
def loss_fn(model):
# Forward pass
z, log_det = model.forward(batch, rngs=rngs)
# Base distribution log prob
log_p_z = -0.5 * jnp.sum(z**2, axis=-1) - 0.5 * z.shape[-1] * jnp.log(2 * jnp.pi)
# Total log probability
log_prob = log_p_z + log_det
# Loss
loss = -jnp.mean(log_prob)
# Additional metrics
metrics = {
"loss": loss,
"log_p_z": jnp.mean(log_p_z),
"log_det": jnp.mean(log_det),
"log_prob": jnp.mean(log_prob),
"z_norm": jnp.mean(jnp.linalg.norm(z, axis=-1)),
}
return loss, metrics
(loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
optimizer.update(grads)
return metrics
# Training loop with logging
for epoch in range(num_epochs):
metrics_list = []
for batch in train_dataloader:
batch = preprocess(batch)
metrics = train_step_with_metrics(model, optimizer, batch, rngs)
metrics_list.append(metrics)
# Aggregate epoch metrics
epoch_metrics = {
k: jnp.mean(jnp.array([m[k] for m in metrics_list]))
for k in metrics_list[0].keys()
}
print(f"Epoch {epoch+1}: {epoch_metrics}")
Important Metrics to Monitor:
- loss: Negative log-likelihood (should decrease)
- log_p_z: Base-distribution log-probability (for a standard Gaussian base it should settle near roughly -1.42 per dimension, not near 0)
- log_det: Jacobian log-determinant (tracks transformation magnitude)
- z_norm: Latent norm (should be near √d for a d-dimensional Gaussian)
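As a rough sanity check, these can be compared against the values expected under a d-dimensional standard-Gaussian base distribution (a back-of-the-envelope sketch):
# Expected values under a d-dimensional standard normal (d = 64 as an example)
d = 64
expected_log_p_z = -0.5 * d * (jnp.log(2 * jnp.pi) + 1.0)  # E[log N(z; 0, I)] ≈ -1.42 * d
expected_z_norm = jnp.sqrt(d)                               # E[||z||] ≈ sqrt(d) for large d
print(f"Target log_p_z ≈ {expected_log_p_z:.1f}, target z_norm ≈ {expected_z_norm:.2f}")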
Data Preprocessing for Flows¤
Proper preprocessing is crucial for flow models:
Dequantization¤
Images are discrete (0-255), but flows need continuous values:
def dequantize(images, rngs):
"""Add uniform noise to dequantize discrete images."""
# images should be in [0, 1] range
noise = jax.random.uniform(rngs.sample(), images.shape)
return images + noise / 256.0
# Apply during training
batch = dequantize(batch, rngs)
Logit Transform¤
Map bounded data to unbounded space:
def logit_transform(x, alpha=0.05):
"""Apply logit transform with boundary handling."""
# Squeeze to (alpha, 1-alpha) to avoid infinities
x = alpha + (1 - 2*alpha) * x
# Apply logit
return jnp.log(x) - jnp.log(1 - x)
def inverse_logit_transform(y, alpha=0.05):
"""Inverse of logit transform."""
x = jax.nn.sigmoid(y)
return (x - alpha) / (1 - 2*alpha)
# Apply during training
batch = logit_transform(batch)
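If you use the logit transform, include its Jacobian term when reporting log-likelihoods in the original data space; a sketch of the per-example correction for the logit_transform above:
def logit_transform_with_logdet(x, alpha=0.05):
    """Logit transform plus the per-example log|det J| of the preprocessing."""
    s = alpha + (1 - 2 * alpha) * x
    y = jnp.log(s) - jnp.log(1 - s)
    # dy/dx = (1 - 2*alpha) / (s * (1 - s)), applied elementwise
    log_det = jnp.sum(
        jnp.log(1 - 2 * alpha) - jnp.log(s) - jnp.log(1 - s),
        axis=tuple(range(1, x.ndim)),
    )
    return y, log_det

# log-likelihood in the original space = model log-prob of y + log_det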
Normalization¤
Standardize data to zero mean and unit variance:
# Compute statistics on training data
train_mean = jnp.mean(train_data, axis=0, keepdims=True)
train_std = jnp.std(train_data, axis=0, keepdims=True)
# Normalize
batch = (batch - train_mean) / (train_std + 1e-6)
# Remember to denormalize samples
samples = model.generate(n_samples=16, rngs=rngs)
samples = samples * train_std + train_mean
Sampling and Generation¤
Basic Sampling¤
Generate samples from a trained flow model:
# Generate samples
n_samples = 16
samples = model.generate(n_samples=n_samples, rngs=rngs)
# For image data, reshape if needed
if isinstance(config.input_dim, tuple):
# Already in correct shape
images = samples
else:
# Reshape to image dimensions
H, W, C = 28, 28, 1
images = samples.reshape(n_samples, H, W, C)
# Denormalize for visualization
images = (images * 0.5) + 0.5 # From [-1, 1] to [0, 1]
images = jnp.clip(images, 0, 1)
Temperature Sampling¤
Control sample diversity with temperature:
def sample_with_temperature(model, n_samples, temperature, rngs):
"""Sample with temperature scaling.
temperature > 1: More diverse samples
temperature < 1: More conservative samples
temperature = 1: Standard sampling
"""
# Sample from base distribution
z = jax.random.normal(rngs.sample(), (n_samples, model.latent_dim))
# Scale by temperature
z = z * temperature
# Transform to data space
samples, _ = model.inverse(z, rngs=rngs)
return samples
# Conservative samples (sharper, less diverse)
conservative = sample_with_temperature(model, 16, temperature=0.7, rngs=rngs)
# Diverse samples (more variety, less sharp)
diverse = sample_with_temperature(model, 16, temperature=1.3, rngs=rngs)
Conditional Sampling¤
Some flow architectures support conditional generation:
# For conditional flows (if implemented)
from workshop.generative_models.models.flow import ConditionalRealNVP
# Create conditional model
config = ModelConfiguration(
name="conditional_realnvp",
model_class="workshop.generative_models.models.flow.ConditionalRealNVP",
input_dim=784,
output_dim=784,
hidden_dims=[512, 512],
parameters={
"num_coupling_layers": 8,
"condition_dim": 10, # e.g., class labels
}
)
model = ConditionalRealNVP(config, rngs=rngs)
# Sample conditioned on class labels
class_labels = jax.nn.one_hot(jnp.array([0, 1, 2]), 10) # 3 classes
conditional_samples = model.generate(
n_samples=3,
condition=class_labels,
rngs=rngs
)
Interpolation in Latent Space¤
Interpolate between two data points:
def interpolate(model, x1, x2, num_steps, rngs):
"""Interpolate between two data points in latent space."""
# Encode to latent space
z1, _ = model.forward(x1[None, ...], rngs=rngs)
z2, _ = model.forward(x2[None, ...], rngs=rngs)
# Linear interpolation in latent space
alphas = jnp.linspace(0, 1, num_steps)
z_interp = jnp.array([
(1 - alpha) * z1 + alpha * z2
for alpha in alphas
]).squeeze(1)
# Decode to data space
x_interp, _ = model.inverse(z_interp, rngs=rngs)
return x_interp
# Interpolate between two images
x1 = train_data[0] # First image
x2 = train_data[1] # Second image
interpolations = interpolate(model, x1, x2, num_steps=10, rngs=rngs)
Density Estimation and Evaluation¤
Computing Log-Likelihood¤
Flow models provide exact log-likelihood:
# Compute log-likelihood for test data
test_data = ... # Your test dataset
log_likelihoods = []
for batch in test_dataloader:
# Preprocess same as training
batch = dequantize(batch, rngs)
batch = (batch - 0.5) / 0.5
# Compute log probability
log_prob = model.log_prob(batch, rngs=rngs)
log_likelihoods.append(log_prob)
# Average log-likelihood
all_log_probs = jnp.concatenate(log_likelihoods)
avg_log_likelihood = jnp.mean(all_log_probs)
print(f"Test log-likelihood: {avg_log_likelihood:.3f}")
# Bits per dimension (common metric)
input_dim = jnp.prod(jnp.array(config.input_dim))
bits_per_dim = -avg_log_likelihood / (input_dim * jnp.log(2))
print(f"Bits per dimension: {bits_per_dim:.3f}")
Anomaly Detection¤
Use log-likelihood for anomaly detection:
def detect_anomalies(model, data, threshold, rngs):
"""Detect anomalies using log-likelihood threshold."""
# Compute log probabilities
log_probs = model.log_prob(data, rngs=rngs)
# Flag samples below threshold as anomalies
is_anomaly = log_probs < threshold
return is_anomaly, log_probs
# Set threshold (e.g., 5th percentile of training data)
train_log_probs = model.log_prob(train_data, rngs=rngs)
threshold = jnp.percentile(train_log_probs, 5)
# Detect anomalies in test data
anomalies, test_log_probs = detect_anomalies(
model, test_data, threshold, rngs
)
print(f"Detected {jnp.sum(anomalies)} anomalies out of {len(test_data)} samples")
Model Comparison¤
Compare different flow architectures using likelihood:
# Train multiple models
models = {
"RealNVP": realnvp_model,
"Glow": glow_model,
"MAF": maf_model,
"Spline": spline_model,
}
# Evaluate on test set
results = {}
for name, model in models.items():
log_probs = []
for batch in test_dataloader:
batch = preprocess(batch)
log_prob = model.log_prob(batch, rngs=rngs)
log_probs.append(log_prob)
avg_log_prob = jnp.mean(jnp.concatenate(log_probs))
results[name] = avg_log_prob
print(f"{name}: {avg_log_prob:.3f} (higher is better)")
# Best model
best_model = max(results, key=results.get)
print(f"Best model: {best_model}")
Common Patterns¤
Multi-Modal Data Distribution¤
For data with multiple modes, increase model capacity:
# Increase number of layers
config = ModelConfiguration(
name="multimodal_flow",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=64,
output_dim=64,
hidden_dims=[1024, 1024, 1024], # Deeper networks
parameters={
"num_coupling_layers": 16, # More layers
}
)
# Or use Neural Spline Flows for higher expressiveness
config_spline = ModelConfiguration(
name="multimodal_spline",
model_class="workshop.generative_models.models.flow.NeuralSplineFlow",
input_dim=64,
hidden_dims=[256, 256],
metadata={
"flow_params": {
"num_layers": 12,
"num_bins": 16, # More bins for expressiveness
}
}
)
High-Dimensional Data¤
For very high-dimensional data (e.g., high-resolution images):
# Use Glow with multi-scale architecture
config = ModelConfiguration(
name="glow_highres",
model_class="workshop.generative_models.models.flow.Glow",
input_dim=(64, 64, 3),
hidden_dims=[512, 512],
parameters={
"image_shape": (64, 64, 3),
"num_scales": 4, # More scales for higher resolution
"blocks_per_scale": 8,
}
)
# Reduce memory by processing in patches
def train_on_patches(model, image, rngs, patch_size=32):
    """Train on non-overlapping image patches to reduce memory.

    Assumes H and W are divisible by patch_size so all patches share one shape.
    """
    H, W, C = image.shape
    patches = []
    for i in range(0, H, patch_size):
        for j in range(0, W, patch_size):
            patch = image[i:i + patch_size, j:j + patch_size, :]
            patches.append(patch)
    patches = jnp.array(patches)
    return model(patches, rngs=rngs)
Tabular Data with Mixed Types¤
For tabular data with continuous and categorical features:
# Preprocess mixed types
def preprocess_tabular(data, categorical_indices):
    """Preprocess tabular data with mixed types."""
    continuous = data
    # Process indices from right to left so that lower column indices stay valid
    # after each one-hot expansion widens the array.
    for idx in sorted(categorical_indices, reverse=True):
        # One-hot encode the categorical column
        n_categories = int(jnp.max(data[:, idx])) + 1
        one_hot = jax.nn.one_hot(data[:, idx].astype(int), n_categories)
        # Replace the categorical column with its one-hot encoding
        continuous = jnp.concatenate([
            continuous[:, :idx],
            one_hot,
            continuous[:, idx + 1:],
        ], axis=1)
    return continuous
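A hypothetical usage, which also yields the processed_dim used in the configuration below:
# Hypothetical example: columns 2 and 5 hold integer category codes
processed = preprocess_tabular(raw_data, categorical_indices=[2, 5])
processed_dim = processed.shape[1]  # becomes input_dim below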
# Use MAF for tabular data (good density estimation)
config = ModelConfiguration(
name="tabular_maf",
model_class="workshop.generative_models.models.flow.MAF",
input_dim=processed_dim,
hidden_dims=[512, 512],
parameters={
"num_layers": 10, # More layers for complex dependencies
}
)
Exact Reconstruction¤
Verify model invertibility:
def test_invertibility(model, x, rngs, tolerance=1e-4):
"""Test that forward and inverse are true inverses."""
# Forward then inverse
z, _ = model.forward(x, rngs=rngs)
x_recon, _ = model.inverse(z, rngs=rngs)
# Compute reconstruction error
error = jnp.max(jnp.abs(x - x_recon))
print(f"Max reconstruction error: {error:.6f}")
assert error < tolerance, f"Reconstruction error {error} exceeds tolerance {tolerance}"
# Test on random data
x = jax.random.normal(rngs.sample(), (10, 64))
test_invertibility(model, x, rngs)
Troubleshooting¤
Issue: NaN Loss During Training¤
Symptoms: Loss becomes NaN after a few iterations.
Solutions:
- Add gradient clipping (see the optax.chain example under Training with Gradient Clipping)
- Check data preprocessing:
# Ensure data is properly scaled
assert jnp.all(jnp.isfinite(batch)), "Data contains NaN or Inf"
assert jnp.abs(jnp.mean(batch)) < 10, "Data not properly normalized"
- Reduce the learning rate (e.g., from 1e-4 to 1e-5)
- Check Jacobian stability:
# Monitor log-determinant magnitude
z, log_det = model.forward(batch, rngs=rngs)
print(f"Log-det range: [{jnp.min(log_det):.2f}, {jnp.max(log_det):.2f}]")
# If log_det has extreme values, reduce model capacity or layers
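If NaNs still appear intermittently, one option is to skip non-finite updates with optax.apply_if_finite (a sketch; the error budget is arbitrary):
import optax
# Skip parameter updates whose gradients contain NaN/Inf; abort after 5 in a row
safe_tx = optax.apply_if_finite(
    optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4)),
    max_consecutive_errors=5,
)
optimizer = nnx.Optimizer(model, safe_tx)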
Issue: Poor Sample Quality¤
Symptoms: Generated samples look noisy or unrealistic.
Solutions:
- Increase model capacity:
config = ModelConfiguration(
name="larger_model",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=784,
hidden_dims=[1024, 1024], # Larger networks
parameters={
"num_coupling_layers": 16, # More layers
}
)
- Use a more expressive architecture (e.g., Neural Spline Flows or Glow)
- Improve data preprocessing (dequantization, logit transform, normalization)
- Train for more epochs
Issue: Slow Training¤
Symptoms: Training takes too long per iteration.
Solutions:
- Use JIT compilation (decorate the training step with @nnx.jit, as in the examples above)
- Reduce model complexity:
# Fewer coupling layers
parameters={"num_coupling_layers": 6} # Instead of 16
# Smaller hidden dimensions
hidden_dims=[256, 256] # Instead of [1024, 1024]
- Use IAF if sampling speed is the bottleneck
- Use larger batches so each compiled step does more work
Issue: Mode Collapse¤
Symptoms: Model generates similar samples repeatedly.
Solutions:
- Check latent space coverage:
# Encode real data (not regenerated samples) and check latent-space coverage
z_data, _ = model.forward(train_data[:1000], rngs=rngs)
# For a well-trained flow these statistics should match the standard Gaussian base
z_mean = jnp.mean(z_data, axis=0)
z_std = jnp.std(z_data, axis=0)
print(f"Latent mean: {jnp.mean(z_mean):.3f} (should be ~0)")
print(f"Latent std: {jnp.mean(z_std):.3f} (should be ~1)")
- Increase model expressiveness (more coupling layers, larger hidden dimensions, or spline couplings)
- Check for numerical issues (NaN/Inf in inputs, gradients, or log-determinants)
Issue: Memory Errors¤
Symptoms: Out of memory errors during training.
Solutions:
- Reduce the batch size
- Use gradient checkpointing (if available)
- Reduce model size:
# Fewer layers or smaller hidden dimensions
hidden_dims=[256, 256]
parameters={"num_coupling_layers": 6}
- Use mixed precision training (with care: flow likelihoods can be sensitive to reduced precision)
Best Practices¤
DO¤
✅ Preprocess data properly:
# Always dequantize discrete data
batch = dequantize(batch, rngs)
# Normalize to appropriate range
batch = (batch - 0.5) / 0.5
✅ Monitor multiple metrics:
# Track loss, log_det, base log prob
metrics = {
"loss": loss,
"log_det": jnp.mean(log_det),
"log_p_z": jnp.mean(log_p_z),
}
✅ Use gradient clipping (e.g., optax.clip_by_global_norm in the optimizer chain)
✅ Validate invertibility (e.g., with the test_invertibility helper above)
✅ Choose architecture for your task:
# MAF for density estimation
# IAF for fast sampling
# RealNVP for balance
# Glow for high-quality images
# Spline Flows for expressiveness
DON'T¤
❌ Don't skip data preprocessing:
# BAD: Using raw discrete images
model(raw_images, rngs=rngs) # Will perform poorly!
# GOOD: Dequantize and normalize
processed = dequantize(raw_images, rngs)
processed = (processed - 0.5) / 0.5
model(processed, rngs=rngs)
❌ Don't ignore numerical stability:
# BAD: No gradient clipping
# Can lead to NaN losses
# GOOD: Use gradient clipping
optimizer = nnx.Optimizer(
model,
optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))
)
❌ Don't use wrong architecture for task:
# BAD: Using IAF when density estimation is primary goal
# IAF has slow forward pass
# GOOD: Use MAF for density estimation
model = MAF(config, rngs=rngs)
❌ Don't overtrain on small datasets: flows can memorize a small training set, so monitor validation log-likelihood.
Hyperparameter Guidelines¤
Number of Flow Layers¤
| Data Type | Recommended Layers | Notes |
|---|---|---|
| Tabular (low-dim) | 6-10 | Simpler distributions |
| Tabular (high-dim) | 10-16 | More complex dependencies |
| Images (low-res) | 8-12 | MNIST, CIFAR-10 |
| Images (high-res) | 12-24 | Higher resolution |
Hidden Dimensions¤
| Model | Recommended | Notes |
|---|---|---|
| RealNVP | [512, 512] | 2-3 layers sufficient |
| Glow | [512, 512] | Larger for high-res |
| MAF/IAF | [512] - [1024] | Single deep network |
| Spline | [128, 128] | Splines are expressive |
Learning Rates¤
| Stage | Learning Rate | Notes |
|---|---|---|
| Warmup | 1e-7 → 1e-4 | First 1000 steps |
| Training | 1e-4 | Standard rate |
| Fine-tuning | 1e-5 | Near convergence |
Batch Sizes¤
| Data Type | Batch Size | Notes |
|---|---|---|
| Tabular | 256-1024 | Can use large batches |
| Images (32×32) | 64-128 | Memory dependent |
| Images (64×64) | 32-64 | Reduce for Glow |
| Images (128×128) | 16-32 | Limited by memory |
Summary¤
Quick Reference¤
Model Selection:
- RealNVP: Balanced performance, good for most tasks
- Glow: Best for high-quality image generation
- MAF: Optimal for density estimation
- IAF: Optimal for fast sampling
- Spline Flows: Most expressive transformations
Training Checklist:
- ✅ Preprocess data (dequantize, normalize)
- ✅ Use gradient clipping
- ✅ Monitor multiple metrics
- ✅ Validate invertibility
- ✅ Apply learning rate warmup
- ✅ Check for NaN/Inf values
Common Workflows:
# Density estimation workflow
model = MAF(config, rngs=rngs)
log_probs = model.log_prob(data, rngs=rngs)
# Generation workflow
model = RealNVP(config, rngs=rngs)
samples = model.generate(n_samples=16, rngs=rngs)
# Anomaly detection workflow
model = MAF(config, rngs=rngs)
threshold = jnp.percentile(train_log_probs, 5)
anomalies = test_log_probs < threshold
Next Steps¤
- Theory: See Flow Concepts for mathematical foundations
- API Reference: Check Flow API for complete documentation
- Tutorial: Follow Flow MNIST Example for hands-on practice
References¤
- Dinh et al. (2016): "Density estimation using Real NVP"
- Kingma & Dhariwal (2018): "Glow: Generative Flow with Invertible 1x1 Convolutions"
- Papamakarios et al. (2017): "Masked Autoregressive Flow for Density Estimation"
- Durkan et al. (2019): "Neural Spline Flows"