Flow Models: MNIST Tutorial¤
Level: Beginner | Runtime: ~5-10 minutes (GPU), ~15-20 minutes (CPU) | Format: Python + Jupyter
This tutorial demonstrates normalizing flow models using RealNVP on MNIST. You'll learn about bijective transformations, exact likelihood computation, and how flow models differ from VAEs and GANs.
Files¤
- Python Script: examples/generative_models/image/flow/flow_mnist.py
- Jupyter Notebook: examples/generative_models/image/flow/flow_mnist.ipynb
Dual-Format Implementation
This example is available in two synchronized formats:
- Python Script (.py) - For version control, IDE development, and CI/CD integration
- Jupyter Notebook (.ipynb) - For interactive learning, experimentation, and exploration
Both formats contain identical content and can be used interchangeably. Choose the format that best suits your workflow.
Quick Start¤
# Activate Workshop environment
source activate.sh
# Run the Python script
python examples/generative_models/image/flow/flow_mnist.py
# Or launch Jupyter notebook
jupyter lab examples/generative_models/image/flow/flow_mnist.ipynb
Overview¤
Learning Objectives:
- Understand dequantization for discrete data
- Create and configure a RealNVP flow model
- Verify model invertibility (bijective property)
- Compute exact log-probabilities
- Generate samples from flow models
- Visualize latent space representations
- Calculate bits per dimension metric
Prerequisites:
- Basic understanding of neural networks and generative models
- Familiarity with JAX and Flax NNX basics
- Understanding of probability distributions
- Workshop installed with CUDA support (recommended)
Estimated Time: ~5-10 minutes (GPU), ~15-20 minutes (CPU)
What's Covered¤
- Bijective Transformations: RealNVP coupling layers for invertible mappings
- Exact Likelihoods: Computing exact log-probabilities (unlike VAEs/GANs)
- Dequantization: Converting discrete MNIST to continuous data
- Sample Generation: Generating new digits from the base distribution
- Latent Space: Visualizing learned representations
- Evaluation Metrics: Bits per dimension for model quality
Expected Results:
- Runtime: ~5-10 minutes on GPU, ~15-20 minutes on CPU
- Exact log-probabilities for test data
- Latent space visualizations
- Generated samples (noise-like from an untrained model, digit-like after training)
- Bits/dim: ~7-8 (untrained), ~3-4 (trained)
Prerequisites¤
Installation¤
Imports¤
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP, MAF, IAF
# Select the compute platform explicitly (optional)
jax.config.update("jax_platform_name", "gpu")  # or "cpu"
Data Loading and Preprocessing¤
Flow models require continuous data, but MNIST images are discrete (0-255 integers). We need to apply dequantization.
import tensorflow_datasets as tfds
import numpy as np
def load_mnist_data():
"""Load and preprocess MNIST dataset."""
# Load dataset
ds_train, ds_test = tfds.load(
'mnist',
split=['train', 'test'],
as_supervised=True,
batch_size=-1
)
# Convert to numpy
train_images, train_labels = tfds.as_numpy(ds_train)
test_images, test_labels = tfds.as_numpy(ds_test)
# Normalize to [0, 1]
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0
# Flatten images to 784 dimensions
train_images = train_images.reshape(-1, 784)
test_images = test_images.reshape(-1, 784)
print(f"Train shape: {train_images.shape}") # (60000, 784)
print(f"Test shape: {test_images.shape}") # (10000, 784)
return train_images, train_labels, test_images, test_labels
# Load data
train_images, train_labels, test_images, test_labels = load_mnist_data()
Dequantization¤
Add uniform noise to convert discrete values to continuous:
def dequantize(images, key):
"""Add uniform noise to dequantize discrete images.
Args:
images: Images in [0, 1] range
key: JAX random key
Returns:
Dequantized continuous images
"""
noise = jax.random.uniform(key, images.shape)
return images + noise / 256.0
def preprocess_for_flow(images, key):
"""Complete preprocessing pipeline for flow models.
Args:
images: Images in [0, 1] range
key: JAX random key
Returns:
Preprocessed images
"""
# Dequantize
images = dequantize(images, key)
# Scale to [-1, 1] for better training
images = (images - 0.5) / 0.5
return images
# Test preprocessing
key = jax.random.key(0)
sample_batch = train_images[:10]
preprocessed = preprocess_for_flow(sample_batch, key)
print(f"Original range: [{train_images.min():.3f}, {train_images.max():.3f}]")
print(f"Preprocessed range: [{preprocessed.min():.3f}, {preprocessed.max():.3f}]")
Data Loader¤
Create a simple data loader:
class DataLoader:
"""Simple data loader for training."""
    def __init__(self, images, batch_size=128, shuffle=True):
        self.images = images
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_samples = len(images)
        self.n_batches = self.n_samples // batch_size
        self.epoch = 0  # used to derive a fresh shuffle key each epoch
    def __iter__(self):
        # Shuffle with a fresh key each epoch; reusing a fixed key would
        # reproduce the same ordering on every pass through the data
        if self.shuffle:
            self.epoch += 1
            indices = jax.random.permutation(
                jax.random.key(self.epoch), self.n_samples
            )
        else:
            indices = jnp.arange(self.n_samples)
# Yield batches
for i in range(self.n_batches):
batch_indices = indices[i * self.batch_size:(i + 1) * self.batch_size]
yield self.images[batch_indices]
def __len__(self):
return self.n_batches
# Create data loaders
train_loader = DataLoader(train_images, batch_size=128, shuffle=True)
test_loader = DataLoader(test_images, batch_size=128, shuffle=False)
Model Creation¤
Let's create a RealNVP model for MNIST:
from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP
# Initialize RNG streams
rngs = nnx.Rngs(params=0, dropout=1, sample=2)
# Configure RealNVP
config = ModelConfiguration(
name="realnvp_mnist",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=784, # 28*28 flattened
output_dim=784,
hidden_dims=[512, 512], # Coupling network architecture
parameters={
"num_coupling_layers": 8,
"mask_type": "checkerboard",
"base_distribution": "normal",
"base_distribution_params": {
"loc": 0.0,
"scale": 1.0,
},
}
)
# Create model
model = RealNVP(config, rngs=rngs)
print(f"Created RealNVP with {config.parameters['num_coupling_layers']} coupling layers")
Test Forward and Inverse¤
Verify the model works and is invertible:
def test_model(model, rngs):
"""Test model forward and inverse passes."""
# Create test batch
test_x = jax.random.normal(rngs.sample(), (16, 784))
# Forward pass (data -> latent)
z, log_det_fwd = model.forward(test_x, rngs=rngs)
print(f"Forward: x{test_x.shape} -> z{z.shape}, log_det{log_det_fwd.shape}")
# Inverse pass (latent -> data)
x_recon, log_det_inv = model.inverse(z, rngs=rngs)
print(f"Inverse: z{z.shape} -> x{x_recon.shape}, log_det{log_det_inv.shape}")
# Check invertibility
reconstruction_error = jnp.max(jnp.abs(test_x - x_recon))
print(f"Reconstruction error: {reconstruction_error:.6f}")
# Check log-determinants
log_det_sum = log_det_fwd + log_det_inv
print(f"Log-det sum (should be ~0): {jnp.mean(log_det_sum):.6f}")
# Compute log probability
log_prob = model.log_prob(test_x, rngs=rngs)
print(f"Log probability: {jnp.mean(log_prob):.3f}")
# Test the model
test_model(model, rngs)
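The log-probability printed above follows from the change-of-variables formula: log p_X(x) = log p_Z(f(x)) + log|det df/dx|. As a sketch, assuming (as configured above) a standard normal base distribution and the forward signature used in test_model, you can reproduce it by hand:
# Illustrative: recompute the flow log-likelihood manually via the
# change-of-variables formula (assumes a standard normal base distribution)
import jax.scipy.stats as jstats

x_check = jax.random.normal(rngs.sample(), (4, 784))
z_check, log_det = model.forward(x_check, rngs=rngs)
base_log_prob = jstats.norm.logpdf(z_check).sum(axis=-1)  # log p_Z(z) per sample
manual_log_prob = base_log_prob + log_det
print("Manual:         ", manual_log_prob)
print("model.log_prob: ", model.log_prob(x_check, rngs=rngs))  # should closely match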
Training¤
Training Setup¤
import optax
# Create optimizer with gradient clipping
learning_rate = 1e-4
optimizer_chain = optax.chain(
optax.clip_by_global_norm(1.0), # Gradient clipping
optax.adam(learning_rate),
)
optimizer = nnx.Optimizer(model, optimizer_chain)
# Training configuration
num_epochs = 10
batch_size = 128
print(f"Training for {num_epochs} epochs")
print(f"Batch size: {batch_size}")
print(f"Learning rate: {learning_rate}")
Training Step¤
@nnx.jit
def train_step(model, optimizer, batch, rngs):
"""Single training step.
Args:
model: Flow model
optimizer: Model optimizer
batch: Training batch
rngs: Random number generators
Returns:
Dictionary of metrics
"""
def loss_fn(model):
# Forward pass
outputs = model(batch, rngs=rngs, training=True)
# Negative log-likelihood loss
log_prob = outputs["log_prob"]
loss = -jnp.mean(log_prob)
# Additional metrics
metrics = {
"loss": loss,
"nll": loss,
"log_prob": jnp.mean(log_prob),
"log_det": jnp.mean(outputs["logdet"]),
}
return loss, metrics
# Compute loss and gradients
(loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
# Update parameters
optimizer.update(grads)
return metrics
# Test training step
test_batch = train_images[:batch_size]
test_batch_processed = preprocess_for_flow(test_batch, rngs.sample())
test_metrics = train_step(model, optimizer, test_batch_processed, rngs)
print(f"Test metrics: {test_metrics}")
Training Loop¤
def train_flow_model(model, optimizer, train_loader, num_epochs, rngs):
"""Train the flow model.
Args:
model: Flow model
optimizer: Optimizer
train_loader: Training data loader
num_epochs: Number of epochs
rngs: Random number generators
Returns:
Training history
"""
history = {
"loss": [],
"log_prob": [],
"log_det": [],
}
for epoch in range(num_epochs):
epoch_metrics = []
# Training loop
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
for batch in pbar:
# Preprocess batch
batch_processed = preprocess_for_flow(batch, rngs.sample())
# Training step
metrics = train_step(model, optimizer, batch_processed, rngs)
epoch_metrics.append(metrics)
# Update progress bar
pbar.set_postfix({
"loss": f"{metrics['loss']:.3f}",
"log_prob": f"{metrics['log_prob']:.3f}",
})
# Compute epoch averages
avg_metrics = {
k: float(jnp.mean(jnp.array([m[k] for m in epoch_metrics])))
for k in epoch_metrics[0].keys()
}
# Store history
for k in ["loss", "log_prob", "log_det"]:
history[k].append(avg_metrics[k])
# Print epoch summary
print(f"\nEpoch {epoch+1} Summary:")
print(f" Loss: {avg_metrics['loss']:.4f}")
print(f" Log Prob: {avg_metrics['log_prob']:.4f}")
print(f" Log Det: {avg_metrics['log_det']:.4f}")
return history
# Train the model
print("Starting training...")
history = train_flow_model(
model, optimizer, train_loader, num_epochs, rngs
)
print("Training complete!")
Visualize Training¤
def plot_training_history(history):
"""Plot training metrics."""
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Loss
axes[0].plot(history["loss"])
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Negative Log-Likelihood")
axes[0].set_title("Training Loss")
axes[0].grid(True)
# Log Probability
axes[1].plot(history["log_prob"])
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Log Probability")
axes[1].set_title("Average Log Probability")
axes[1].grid(True)
# Log Determinant
axes[2].plot(history["log_det"])
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Log Determinant")
axes[2].set_title("Average Log Determinant")
axes[2].grid(True)
plt.tight_layout()
plt.show()
# Plot training history
plot_training_history(history)
Evaluation¤
Compute Test Log-Likelihood¤
def evaluate_flow(model, test_loader, rngs):
"""Evaluate flow model on test set.
Args:
model: Trained flow model
test_loader: Test data loader
rngs: Random number generators
Returns:
Dictionary of evaluation metrics
"""
all_log_probs = []
for batch in tqdm(test_loader, desc="Evaluating"):
# Preprocess
batch_processed = preprocess_for_flow(batch, rngs.sample())
# Compute log probability
log_prob = model.log_prob(batch_processed, rngs=rngs)
all_log_probs.append(log_prob)
# Concatenate all log probabilities
all_log_probs = jnp.concatenate(all_log_probs)
# Compute metrics
avg_log_prob = float(jnp.mean(all_log_probs))
# Bits per dimension
input_dim = 784
bits_per_dim = -avg_log_prob / (input_dim * jnp.log(2))
return {
"avg_log_prob": avg_log_prob,
"bits_per_dim": float(bits_per_dim),
"all_log_probs": all_log_probs,
}
# Evaluate on test set
print("Evaluating model...")
eval_results = evaluate_flow(model, test_loader, rngs)
print(f"\nTest Set Evaluation:")
print(f" Average Log-Likelihood: {eval_results['avg_log_prob']:.4f}")
print(f" Bits per Dimension: {eval_results['bits_per_dim']:.4f}")
Visualize Log-Likelihood Distribution¤
def plot_log_likelihood_distribution(log_probs):
"""Plot distribution of log-likelihoods."""
plt.figure(figsize=(10, 5))
plt.hist(log_probs, bins=50, density=True, alpha=0.7, edgecolor='black')
plt.axvline(jnp.mean(log_probs), color='red', linestyle='--',
label=f'Mean: {jnp.mean(log_probs):.2f}')
plt.xlabel('Log-Likelihood')
plt.ylabel('Density')
plt.title('Distribution of Test Log-Likelihoods')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# Plot distribution
plot_log_likelihood_distribution(eval_results['all_log_probs'])
Generation and Visualization¤
Generate Samples¤
def generate_samples(model, n_samples, rngs):
"""Generate samples from the flow model.
Args:
model: Trained flow model
n_samples: Number of samples
rngs: Random number generators
Returns:
Generated samples
"""
# Generate from model
samples = model.generate(n_samples=n_samples, rngs=rngs)
# Denormalize from [-1, 1] to [0, 1]
samples = (samples * 0.5) + 0.5
# Clip to valid range
samples = jnp.clip(samples, 0, 1)
return samples
# Generate 64 samples
n_samples = 64
generated_samples = generate_samples(model, n_samples, rngs)
print(f"Generated {n_samples} samples")
print(f"Sample shape: {generated_samples.shape}")
print(f"Sample range: [{generated_samples.min():.3f}, {generated_samples.max():.3f}]")
Visualize Generated Samples¤
def visualize_samples(samples, n_rows=8, n_cols=8, title="Generated Samples"):
"""Visualize generated samples in a grid.
Args:
samples: Generated samples (flattened)
n_rows: Number of rows in grid
n_cols: Number of columns in grid
title: Plot title
"""
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
for i in range(n_rows):
for j in range(n_cols):
idx = i * n_cols + j
if idx < len(samples):
# Reshape to 28x28
image = samples[idx].reshape(28, 28)
axes[i, j].imshow(image, cmap='gray')
axes[i, j].axis('off')
plt.suptitle(title, fontsize=16, y=0.995)
plt.tight_layout()
plt.show()
# Visualize generated samples
visualize_samples(generated_samples, title="RealNVP Generated MNIST Digits")
Compare with Real Data¤
# Show real vs generated
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
# Real samples
real_samples = test_images[:8]
for i in range(8):
axes[0, i].imshow(real_samples[i].reshape(28, 28), cmap='gray')
axes[0, i].axis('off')
# Row label via ax.text, since set_ylabel is hidden once axis('off') is applied
axes[0, 0].text(-0.35, 0.5, 'Real', fontsize=12, va='center', ha='right',
                transform=axes[0, 0].transAxes)
# Generated samples
gen_samples = generate_samples(model, 8, rngs)
for i in range(8):
axes[1, i].imshow(gen_samples[i].reshape(28, 28), cmap='gray')
axes[1, i].axis('off')
axes[1, 0].text(-0.35, 0.5, 'Generated', fontsize=12, va='center', ha='right',
                transform=axes[1, 0].transAxes)
plt.suptitle('Real vs Generated MNIST Digits', fontsize=14)
plt.tight_layout()
plt.show()
Latent Space Interpolation¤
Interpolate between two digits in latent space:
def interpolate_in_latent_space(model, x1, x2, num_steps, rngs):
"""Interpolate between two data points in latent space.
Args:
model: Trained flow model
x1: First data point
x2: Second data point
num_steps: Number of interpolation steps
rngs: Random number generators
Returns:
Interpolated samples
"""
# Encode to latent space
z1, _ = model.forward(x1[None, ...], rngs=rngs)
z2, _ = model.forward(x2[None, ...], rngs=rngs)
# Linear interpolation
alphas = jnp.linspace(0, 1, num_steps)
z_interp = jnp.array([
(1 - alpha) * z1 + alpha * z2
for alpha in alphas
]).squeeze(1)
# Decode to data space
x_interp, _ = model.inverse(z_interp, rngs=rngs)
# Denormalize
x_interp = (x_interp * 0.5) + 0.5
x_interp = jnp.clip(x_interp, 0, 1)
return x_interp
# Select two test images
idx1, idx2 = 0, 7 # First and eighth test images
test_batch = preprocess_for_flow(test_images[:10], rngs.sample())
x1, x2 = test_batch[idx1], test_batch[idx2]
# Interpolate
num_steps = 10
interpolations = interpolate_in_latent_space(model, x1, x2, num_steps, rngs)
# Visualize interpolation
fig, axes = plt.subplots(1, num_steps, figsize=(num_steps, 1))
for i in range(num_steps):
axes[i].imshow(interpolations[i].reshape(28, 28), cmap='gray')
axes[i].axis('off')
plt.suptitle(f'Latent Space Interpolation: {test_labels[idx1]} → {test_labels[idx2]}',
fontsize=14, y=1.05)
plt.tight_layout()
plt.show()
Temperature Sampling¤
Control sample diversity with temperature:
def sample_with_temperature(model, n_samples, temperature, rngs):
"""Generate samples with temperature scaling.
Args:
model: Trained flow model
n_samples: Number of samples
temperature: Temperature parameter (>1: more diverse, <1: more conservative)
rngs: Random number generators
Returns:
Generated samples
"""
# Sample from base distribution
z = jax.random.normal(rngs.sample(), (n_samples, 784))
# Apply temperature
z = z * temperature
# Decode to data space
samples, _ = model.inverse(z, rngs=rngs)
# Denormalize
samples = (samples * 0.5) + 0.5
samples = jnp.clip(samples, 0, 1)
return samples
# Generate samples with different temperatures
temperatures = [0.5, 0.7, 1.0, 1.3, 1.5]
n_samples_per_temp = 8
fig, axes = plt.subplots(len(temperatures), n_samples_per_temp,
figsize=(n_samples_per_temp, len(temperatures)))
for i, temp in enumerate(temperatures):
samples = sample_with_temperature(model, n_samples_per_temp, temp, rngs)
for j in range(n_samples_per_temp):
axes[i, j].imshow(samples[j].reshape(28, 28), cmap='gray')
axes[i, j].axis('off')
    # Row label via ax.text, since set_ylabel is hidden by axis('off')
    axes[i, 0].text(-0.35, 0.5, f'T={temp}', fontsize=10, va='center', ha='right',
                    transform=axes[i, 0].transAxes)
plt.suptitle('Temperature Sampling (Lower = More Conservative)', fontsize=14, y=0.995)
plt.tight_layout()
plt.show()
Anomaly Detection¤
Use log-likelihood for anomaly detection:
# Compute threshold from training data
train_subset = train_images[:5000]
train_subset_processed = preprocess_for_flow(train_subset, rngs.sample())
train_log_probs = model.log_prob(train_subset_processed, rngs=rngs)
# Set threshold at 5th percentile
threshold = float(jnp.percentile(train_log_probs, 5))
print(f"Anomaly threshold (5th percentile): {threshold:.3f}")
# Evaluate on test set
test_subset = test_images[:1000]
test_subset_processed = preprocess_for_flow(test_subset, rngs.sample())
test_log_probs = model.log_prob(test_subset_processed, rngs=rngs)
# Identify anomalies
is_anomaly = test_log_probs < threshold
n_anomalies = int(jnp.sum(is_anomaly))
print(f"Detected {n_anomalies} anomalies out of {len(test_subset)} samples")
print(f"Anomaly rate: {n_anomalies / len(test_subset) * 100:.2f}%")
# Visualize some anomalies
anomaly_indices = jnp.where(is_anomaly)[0]
if len(anomaly_indices) > 0:
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for i in range(min(8, len(anomaly_indices))):
idx = int(anomaly_indices[i])
image = test_subset[idx].reshape(28, 28)
log_prob = float(test_log_probs[idx])
axes[0, i].imshow(image, cmap='gray')
axes[0, i].set_title(f'LL: {log_prob:.1f}', fontsize=8)
axes[0, i].axis('off')
# Show normal samples for comparison
normal_indices = jnp.where(~is_anomaly)[0]
for i in range(8):
idx = int(normal_indices[i])
image = test_subset[idx].reshape(28, 28)
log_prob = float(test_log_probs[idx])
axes[1, i].imshow(image, cmap='gray')
axes[1, i].set_title(f'LL: {log_prob:.1f}', fontsize=8)
axes[1, i].axis('off')
    # Row labels via ax.text, since set_ylabel is hidden by axis('off')
    axes[0, 0].text(-0.35, 0.5, 'Anomalies', fontsize=10, va='center', ha='right',
                    transform=axes[0, 0].transAxes)
    axes[1, 0].text(-0.35, 0.5, 'Normal', fontsize=10, va='center', ha='right',
                    transform=axes[1, 0].transAxes)
plt.suptitle('Anomaly Detection (Low vs High Log-Likelihood)', fontsize=14)
plt.tight_layout()
plt.show()
Comparing Flow Architectures¤
Compare RealNVP, MAF, and IAF:
def compare_flow_models():
"""Compare different flow architectures on MNIST."""
# Create different models
models = {}
# RealNVP (already trained)
models["RealNVP"] = model
# MAF configuration
maf_config = ModelConfiguration(
name="maf_mnist",
model_class="workshop.generative_models.models.flow.MAF",
input_dim=784,
output_dim=784,
hidden_dims=[512],
parameters={
"num_layers": 5,
"reverse_ordering": True,
}
)
models["MAF"] = MAF(maf_config, rngs=rngs)
# IAF configuration
iaf_config = ModelConfiguration(
name="iaf_mnist",
model_class="workshop.generative_models.models.flow.IAF",
input_dim=784,
output_dim=784,
hidden_dims=[512],
parameters={
"num_layers": 5,
"reverse_ordering": True,
}
)
models["IAF"] = IAF(iaf_config, rngs=rngs)
# Note: For a fair comparison, you would train MAF and IAF too
# Here we just demonstrate the API
print("Model Comparison:")
print("-" * 60)
for name, flow_model in models.items():
# Test forward pass
test_batch = test_images[:100]
test_batch_processed = preprocess_for_flow(test_batch, rngs.sample())
try:
log_prob = flow_model.log_prob(test_batch_processed, rngs=rngs)
avg_ll = float(jnp.mean(log_prob))
print(f"{name:10s} - Avg Log-Likelihood: {avg_ll:.3f}")
except Exception as e:
print(f"{name:10s} - Error: {str(e)[:50]}")
print("-" * 60)
print("\nKey Differences:")
print(" RealNVP: Fast forward & inverse (good balance)")
print(" MAF: Fast forward (good for density estimation)")
print(" IAF: Fast inverse (good for sampling)")
# Compare models
compare_flow_models()
Troubleshooting¤
Issue: NaN Loss¤
If you encounter NaN losses:
# 1. Check data preprocessing
assert jnp.all(jnp.isfinite(train_images)), "Data contains NaN/Inf"
# 2. Reduce learning rate
optimizer = nnx.Optimizer(
model,
optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(1e-5) # Lower learning rate
)
)
# 3. Monitor log-determinants
def train_step_with_checks(model, optimizer, batch, rngs):
def loss_fn(model):
outputs = model(batch, rngs=rngs)
log_det = outputs["logdet"]
# Check for extreme values
if jnp.any(jnp.abs(log_det) > 100):
print(f"Warning: Large log-det: {jnp.max(jnp.abs(log_det)):.2f}")
loss = -jnp.mean(outputs["log_prob"])
return loss
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return {"loss": loss}
Issue: Poor Sample Quality¤
If generated samples look poor:
# 1. Train longer
num_epochs = 50 # More epochs
# 2. Increase model capacity
config = ModelConfiguration(
name="larger_realnvp",
model_class="workshop.generative_models.models.flow.RealNVP",
input_dim=784,
hidden_dims=[1024, 1024], # Larger networks
parameters={
"num_coupling_layers": 12, # More layers
}
)
# 3. Use Neural Spline Flows for more expressiveness
from workshop.generative_models.models.flow import NeuralSplineFlow
spline_config = ModelConfiguration(
name="spline_mnist",
model_class="workshop.generative_models.models.flow.NeuralSplineFlow",
input_dim=784,
hidden_dims=[256, 256],
metadata={
"flow_params": {
"num_layers": 12,
"num_bins": 16, # More bins
}
}
)
Summary¤
In this tutorial, you learned:
✅ Data Preprocessing: Dequantization and normalization for flow models
✅ Model Creation: Setting up RealNVP with proper configuration
✅ Training: Training loop with gradient clipping and monitoring
✅ Evaluation: Computing exact log-likelihoods and bits per dimension
✅ Generation: Sampling from the trained model
✅ Advanced Techniques:
- Latent space interpolation
- Temperature sampling
- Anomaly detection
✅ Model Comparison: Understanding different flow architectures
Next Steps¤
- Explore Other Models: Try Glow for higher quality or MAF for better density estimation
- Conditional Flows: Implement class-conditional generation
- High-Resolution: Apply flows to higher resolution images
- Hybrid Models: Combine flows with VAEs or diffusion models
Further Reading¤
- Theory: Flow Concepts
- User Guide: Flow Models Guide
- API: Flow API Reference
References¤
- Dinh et al. (2016): "Density estimation using Real NVP"
- Kingma & Dhariwal (2018): "Glow: Generative Flow with Invertible 1x1 Convolutions"
- Papamakarios et al. (2017): "Masked Autoregressive Flow for Density Estimation"