Flow Models: MNIST Tutorial¤

Level: Beginner | Runtime: ~5-10 minutes (GPU), ~15-20 minutes (CPU) | Format: Python + Jupyter

This tutorial demonstrates normalizing flow models using RealNVP on MNIST. You'll learn about bijective transformations, exact likelihood computation, and how flow models differ from VAEs and GANs.
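
The "exact likelihood" property comes from the change-of-variables formula: for a bijection z = f(x), log p_x(x) = log p_z(f(x)) + log|det ∂f/∂x|. The sketch below checks this on a hypothetical one-dimensional affine bijection (a toy stand-in, not the tutorial's RealNVP), where the answer is known in closed form:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy bijection for illustration: z = f(x) = (x - shift) / scale
shift, scale = 2.0, 3.0

def forward(x):
    """Map data x to latent z and return log|dz/dx| per element."""
    z = (x - shift) / scale
    log_det = jnp.full_like(x, -jnp.log(scale))
    return z, log_det

def flow_log_prob(x):
    """Change of variables: log p_x(x) = log p_z(f(x)) + log|det J_f(x)|."""
    z, log_det = forward(x)
    return norm.logpdf(z) + log_det

x = jnp.array([-1.0, 0.5, 2.0, 4.0])
# If z ~ N(0, 1), then x = shift + scale * z ~ N(shift, scale^2)
exact = norm.logpdf(x, loc=shift, scale=scale)
max_abs_diff = float(jnp.max(jnp.abs(flow_log_prob(x) - exact)))
print(f"Max difference from closed form: {max_abs_diff:.2e}")
```

RealNVP applies the same formula, with the log-determinant accumulated over its coupling layers.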

Files¤

Dual-Format Implementation

This example is available in two synchronized formats:

  • Python Script (.py) - For version control, IDE development, and CI/CD integration
  • Jupyter Notebook (.ipynb) - For interactive learning, experimentation, and exploration

Both formats contain identical content and can be used interchangeably. Choose the format that best suits your workflow.

Quick Start¤

# Activate Workshop environment
source activate.sh

# Run the Python script
python examples/generative_models/image/flow/flow_mnist.py

# Or launch Jupyter notebook
jupyter lab examples/generative_models/image/flow/flow_mnist.ipynb

Overview¤

Learning Objectives:

  • Understand dequantization for discrete data
  • Create and configure a RealNVP flow model
  • Verify model invertibility (bijective property)
  • Compute exact log-probabilities
  • Generate samples from flow models
  • Visualize latent space representations
  • Calculate bits per dimension metric

Prerequisites:

  • Basic understanding of neural networks and generative models
  • Familiarity with JAX and Flax NNX basics
  • Understanding of probability distributions
  • Workshop installed with CUDA support (recommended)

Estimated Time: 10-15 minutes

What's Covered¤

  • Bijective Transformations


    RealNVP coupling layers for invertible mappings

  • Exact Likelihoods


    Computing exact log-probabilities (unlike VAEs/GANs)

  • Dequantization


    Converting discrete MNIST to continuous data

  • Sample Generation


    Generating new digits from base distribution

  • Latent Space


    Visualizing learned representations

  • Evaluation Metrics


    Bits per dimension for model quality

Expected Results:

  • Runtime: ~5-10 minutes on GPU, ~15-20 minutes on CPU
  • Exact log-probabilities for test data
  • Latent space visualizations
  • Generated samples (noise-like for an untrained model)
  • Bits/dim: ~7-8 (untrained), ~3-4 (trained)

Prerequisites¤

Installation¤

# Install Workshop with CUDA support (recommended)
uv sync --extra cuda-dev

# Or CPU-only
uv sync

Imports¤

import jax
import jax.numpy as jnp
from flax import nnx
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP, MAF, IAF

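# Select the compute platform; use "cpu" if no GPU is available.
# (Reproducibility is controlled by the PRNG keys and seeds used below.)
jax.config.update("jax_platform_name", "gpu")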

Data Loading and Preprocessing¤

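Flow models define a continuous density, but MNIST pixels are discrete integers in 0-255. A continuous model fitted directly to discrete values can place arbitrarily high density on those points, so we first dequantize the data by adding uniform noise.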

import tensorflow_datasets as tfds
import numpy as np

def load_mnist_data():
    """Load and preprocess MNIST dataset."""
    # Load dataset
    ds_train, ds_test = tfds.load(
        'mnist',
        split=['train', 'test'],
        as_supervised=True,
        batch_size=-1
    )

    # Convert to numpy
    train_images, train_labels = tfds.as_numpy(ds_train)
    test_images, test_labels = tfds.as_numpy(ds_test)

    # Normalize to [0, 1]
    train_images = train_images.astype(np.float32) / 255.0
    test_images = test_images.astype(np.float32) / 255.0

    # Flatten images to 784 dimensions
    train_images = train_images.reshape(-1, 784)
    test_images = test_images.reshape(-1, 784)

    print(f"Train shape: {train_images.shape}")  # (60000, 784)
    print(f"Test shape: {test_images.shape}")    # (10000, 784)

    return train_images, train_labels, test_images, test_labels

# Load data
train_images, train_labels, test_images, test_labels = load_mnist_data()

Dequantization¤

Add uniform noise to convert discrete values to continuous:

def dequantize(images, key):
    """Add uniform noise to dequantize discrete images.

    Args:
        images: Images in [0, 1] range
        key: JAX random key

    Returns:
        Dequantized continuous images
    """
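    # Recover the integer levels 0-255, add u ~ U[0, 1), and rescale by 256
    # so each level k fills its own bin [k/256, (k+1)/256) and values stay in [0, 1)
    noise = jax.random.uniform(key, images.shape)
    return (images * 255.0 + noise) / 256.0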

def preprocess_for_flow(images, key):
    """Complete preprocessing pipeline for flow models.

    Args:
        images: Images in [0, 1] range
        key: JAX random key

    Returns:
        Preprocessed images
    """
    # Dequantize
    images = dequantize(images, key)

    # Scale to [-1, 1] for better training
    images = (images - 0.5) / 0.5

    return images

# Test preprocessing
key = jax.random.key(0)
sample_batch = train_images[:10]
preprocessed = preprocess_for_flow(sample_batch, key)

print(f"Original range: [{train_images.min():.3f}, {train_images.max():.3f}]")
print(f"Preprocessed range: [{preprocessed.min():.3f}, {preprocessed.max():.3f}]")
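
Uniform dequantization spreads each integer level k over the bin [k, k+1) before rescaling, so flooring the rescaled value should recover k. The check below illustrates this on synthetic integer pixels rather than MNIST. The helper name `dequantize_levels` is hypothetical, and NumPy float64 is used so that rounding at bin edges does not interfere with the check:

```python
import numpy as np

def dequantize_levels(pixels, rng, num_levels=256):
    """Map integer pixels in {0, ..., num_levels-1} to continuous values in [0, 1)."""
    noise = rng.random(pixels.shape)  # u ~ U[0, 1)
    return (pixels + noise) / num_levels

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(4, 784))  # synthetic stand-in for MNIST pixels

x = dequantize_levels(pixels, rng)
recovered = np.floor(x * 256.0).astype(np.int64)

print(f"Range: [{x.min():.4f}, {x.max():.4f}]")
print(f"All levels recovered: {np.array_equal(recovered, pixels)}")
```

Because the noise never leaves a pixel's own bin, the continuous density integrated over a bin gives a valid probability for the original discrete value.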

Data Loader¤

Create a simple data loader:

class DataLoader:
    """Simple data loader for training."""

    def __init__(self, images, batch_size=128, shuffle=True, seed=0):
        self.images = images
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_samples = len(images)
        # Note: a final partial batch is dropped
        self.n_batches = self.n_samples // batch_size
        self.key = jax.random.key(seed)

    def __iter__(self):
        # Split off a fresh subkey so every epoch gets a different order
        if self.shuffle:
            self.key, subkey = jax.random.split(self.key)
            indices = jax.random.permutation(subkey, self.n_samples)
        else:
            indices = jnp.arange(self.n_samples)

        # Yield batches
        for i in range(self.n_batches):
            batch_indices = indices[i * self.batch_size:(i + 1) * self.batch_size]
            yield self.images[batch_indices]

    def __len__(self):
        return self.n_batches

# Create data loaders
train_loader = DataLoader(train_images, batch_size=128, shuffle=True)
test_loader = DataLoader(test_images, batch_size=128, shuffle=False)

Model Creation¤

Let's create a RealNVP model for MNIST:

from workshop.generative_models.core.configuration import ModelConfiguration
from workshop.generative_models.models.flow import RealNVP

# Initialize RNG streams
rngs = nnx.Rngs(params=0, dropout=1, sample=2)

# Configure RealNVP
config = ModelConfiguration(
    name="realnvp_mnist",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=784,  # 28*28 flattened
    output_dim=784,
    hidden_dims=[512, 512],  # Coupling network architecture
    parameters={
        "num_coupling_layers": 8,
        "mask_type": "checkerboard",
        "base_distribution": "normal",
        "base_distribution_params": {
            "loc": 0.0,
            "scale": 1.0,
        },
    }
)

# Create model
model = RealNVP(config, rngs=rngs)

print(f"Created RealNVP with {config.parameters['num_coupling_layers']} coupling layers")
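
To make the "coupling layer" idea concrete, here is a minimal sketch of a single affine coupling layer in plain JAX. It is an illustration under assumptions, not Workshop's `RealNVP` implementation: the conditioner network, its sizes, the alternating mask, and all function names are hypothetical. Masked inputs pass through unchanged and parameterize a scale and shift for the remaining inputs, which is what makes the inverse and the log-determinant cheap:

```python
import jax
import jax.numpy as jnp

def init_params(key, dim, hidden=32):
    """Random weights for a tiny conditioner network (hypothetical sizes)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (dim, hidden)),
        "w2": 0.1 * jax.random.normal(k2, (hidden, 2 * dim)),
    }

def conditioner(params, x_masked):
    """Predict log-scale s and shift t from the masked (unchanged) inputs."""
    h = jnp.tanh(x_masked @ params["w1"])
    s, t = jnp.split(h @ params["w2"], 2, axis=-1)
    return jnp.tanh(s), t  # tanh keeps the log-scale bounded

def coupling_forward(params, x, mask):
    """y = mask*x + (1-mask)*(x*exp(s) + t); log-det sums s over updated dims."""
    s, t = conditioner(params, x * mask)
    y = mask * x + (1.0 - mask) * (x * jnp.exp(s) + t)
    log_det = jnp.sum((1.0 - mask) * s, axis=-1)
    return y, log_det

def coupling_inverse(params, y, mask):
    """The masked part of y equals the masked part of x, so s and t are recomputable."""
    s, t = conditioner(params, y * mask)
    x = mask * y + (1.0 - mask) * ((y - t) * jnp.exp(-s))
    log_det = -jnp.sum((1.0 - mask) * s, axis=-1)
    return x, log_det

dim = 8
mask = (jnp.arange(dim) % 2).astype(jnp.float32)  # alternating 0/1 mask
params = init_params(jax.random.key(0), dim)
x = jax.random.normal(jax.random.key(1), (4, dim))

y, log_det_fwd = coupling_forward(params, x, mask)
x_recon, log_det_inv = coupling_inverse(params, y, mask)
recon_error = float(jnp.max(jnp.abs(x - x_recon)))
log_det_gap = float(jnp.max(jnp.abs(log_det_fwd + log_det_inv)))

# Cross-check the analytic log-det against a brute-force Jacobian for one sample
jac = jax.jacobian(lambda v: coupling_forward(params, v, mask)[0])(x[0])
_, logabsdet = jnp.linalg.slogdet(jac)
jac_log_det_gap = float(jnp.abs(logabsdet - log_det_fwd[0]))

print(f"Reconstruction error: {recon_error:.2e}")
print(f"Forward + inverse log-det: {log_det_gap:.2e}")
print(f"Analytic vs. brute-force log-det: {jac_log_det_gap:.2e}")
```

A single layer leaves the masked half untouched, so practical models stack several layers with alternating masks, as the `num_coupling_layers` setting above suggests.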

Test Forward and Inverse¤

Verify the model works and is invertible:

def test_model(model, rngs):
    """Test model forward and inverse passes."""
    # Create test batch
    test_x = jax.random.normal(rngs.sample(), (16, 784))

    # Forward pass (data -> latent)
    z, log_det_fwd = model.forward(test_x, rngs=rngs)
    print(f"Forward: x{test_x.shape} -> z{z.shape}, log_det{log_det_fwd.shape}")

    # Inverse pass (latent -> data)
    x_recon, log_det_inv = model.inverse(z, rngs=rngs)
    print(f"Inverse: z{z.shape} -> x{x_recon.shape}, log_det{log_det_inv.shape}")

    # Check invertibility
    reconstruction_error = jnp.max(jnp.abs(test_x - x_recon))
    print(f"Reconstruction error: {reconstruction_error:.6f}")

    # Check log-determinants
    log_det_sum = log_det_fwd + log_det_inv
    print(f"Log-det sum (should be ~0): {jnp.mean(log_det_sum):.6f}")

    # Compute log probability
    log_prob = model.log_prob(test_x, rngs=rngs)
    print(f"Log probability: {jnp.mean(log_prob):.3f}")

# Test the model
test_model(model, rngs)

Training¤

Training Setup¤

import optax

# Create optimizer with gradient clipping
learning_rate = 1e-4
optimizer_chain = optax.chain(
    optax.clip_by_global_norm(1.0),  # Gradient clipping
    optax.adam(learning_rate),
)

optimizer = nnx.Optimizer(model, optimizer_chain)

# Training configuration
num_epochs = 10
batch_size = 128

print(f"Training for {num_epochs} epochs")
print(f"Batch size: {batch_size}")
print(f"Learning rate: {learning_rate}")
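
For intuition about the gradient clipping step in the optimizer chain, the following is a hand-written sketch of clipping by global norm in plain JAX. It illustrates the idea only and is not Optax's implementation; the helper names are hypothetical:

```python
import jax
import jax.numpy as jnp

def global_norm(tree):
    """L2 norm taken over all leaves of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))

def clip_tree_by_global_norm(tree, max_norm):
    """Rescale the whole pytree so its global norm is at most max_norm."""
    norm = global_norm(tree)
    factor = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * factor, tree)

# Toy gradients with a large norm
grads = {"w": jnp.full((3, 3), 2.0), "b": jnp.full((3,), -4.0)}
clipped = clip_tree_by_global_norm(grads, max_norm=1.0)

norm_before = float(global_norm(grads))
norm_after = float(global_norm(clipped))
print(f"Norm before: {norm_before:.3f}, after: {norm_after:.3f}")
```

All leaves are scaled by the same factor, so the update direction is preserved and only its length is limited. This matters for flows because the exponentiated scale terms can produce occasional very large gradients.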

Training Step¤

@nnx.jit
def train_step(model, optimizer, batch, rngs):
    """Single training step.

    Args:
        model: Flow model
        optimizer: Model optimizer
        batch: Training batch
        rngs: Random number generators

    Returns:
        Dictionary of metrics
    """
    def loss_fn(model):
        # Forward pass
        outputs = model(batch, rngs=rngs, training=True)

        # Negative log-likelihood loss
        log_prob = outputs["log_prob"]
        loss = -jnp.mean(log_prob)

        # Additional metrics
        metrics = {
            "loss": loss,
            "nll": loss,
            "log_prob": jnp.mean(log_prob),
            "log_det": jnp.mean(outputs["logdet"]),
        }

        return loss, metrics

    # Compute loss and gradients
    (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)

    # Update parameters
    optimizer.update(grads)

    return metrics

# Test training step
test_batch = train_images[:batch_size]
test_batch_processed = preprocess_for_flow(test_batch, rngs.sample())
test_metrics = train_step(model, optimizer, test_batch_processed, rngs)
print(f"Test metrics: {test_metrics}")

Training Loop¤

def train_flow_model(model, optimizer, train_loader, num_epochs, rngs):
    """Train the flow model.

    Args:
        model: Flow model
        optimizer: Optimizer
        train_loader: Training data loader
        num_epochs: Number of epochs
        rngs: Random number generators

    Returns:
        Training history
    """
    history = {
        "loss": [],
        "log_prob": [],
        "log_det": [],
    }

    for epoch in range(num_epochs):
        epoch_metrics = []

        # Training loop
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in pbar:
            # Preprocess batch
            batch_processed = preprocess_for_flow(batch, rngs.sample())

            # Training step
            metrics = train_step(model, optimizer, batch_processed, rngs)
            epoch_metrics.append(metrics)

            # Update progress bar
            pbar.set_postfix({
                "loss": f"{metrics['loss']:.3f}",
                "log_prob": f"{metrics['log_prob']:.3f}",
            })

        # Compute epoch averages
        avg_metrics = {
            k: float(jnp.mean(jnp.array([m[k] for m in epoch_metrics])))
            for k in epoch_metrics[0].keys()
        }

        # Store history
        for k in ["loss", "log_prob", "log_det"]:
            history[k].append(avg_metrics[k])

        # Print epoch summary
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Loss: {avg_metrics['loss']:.4f}")
        print(f"  Log Prob: {avg_metrics['log_prob']:.4f}")
        print(f"  Log Det: {avg_metrics['log_det']:.4f}")

    return history

# Train the model
print("Starting training...")
history = train_flow_model(
    model, optimizer, train_loader, num_epochs, rngs
)
print("Training complete!")

Visualize Training¤

def plot_training_history(history):
    """Plot training metrics."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Loss
    axes[0].plot(history["loss"])
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Negative Log-Likelihood")
    axes[0].set_title("Training Loss")
    axes[0].grid(True)

    # Log Probability
    axes[1].plot(history["log_prob"])
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Log Probability")
    axes[1].set_title("Average Log Probability")
    axes[1].grid(True)

    # Log Determinant
    axes[2].plot(history["log_det"])
    axes[2].set_xlabel("Epoch")
    axes[2].set_ylabel("Log Determinant")
    axes[2].set_title("Average Log Determinant")
    axes[2].grid(True)

    plt.tight_layout()
    plt.show()

# Plot training history
plot_training_history(history)

Evaluation¤

Compute Test Log-Likelihood¤

def evaluate_flow(model, test_loader, rngs):
    """Evaluate flow model on test set.

    Args:
        model: Trained flow model
        test_loader: Test data loader
        rngs: Random number generators

    Returns:
        Dictionary of evaluation metrics
    """
    all_log_probs = []

    for batch in tqdm(test_loader, desc="Evaluating"):
        # Preprocess
        batch_processed = preprocess_for_flow(batch, rngs.sample())

        # Compute log probability
        log_prob = model.log_prob(batch_processed, rngs=rngs)
        all_log_probs.append(log_prob)

    # Concatenate all log probabilities
    all_log_probs = jnp.concatenate(all_log_probs)

    # Compute metrics
    avg_log_prob = float(jnp.mean(all_log_probs))

    # Bits per dimension, measured in the preprocessed [-1, 1] space
    input_dim = 784
    bits_per_dim = -avg_log_prob / (input_dim * jnp.log(2))

    return {
        "avg_log_prob": avg_log_prob,
        "bits_per_dim": float(bits_per_dim),
        "all_log_probs": all_log_probs,
    }

# Evaluate on test set
print("Evaluating model...")
eval_results = evaluate_flow(model, test_loader, rngs)

print(f"\nTest Set Evaluation:")
print(f"  Average Log-Likelihood: {eval_results['avg_log_prob']:.4f}")
print(f"  Bits per Dimension: {eval_results['bits_per_dim']:.4f}")
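
The bits-per-dimension value above is measured in the preprocessed space (dequantized, then scaled to [-1, 1]). Comparing against numbers reported on the original 0-255 pixel scale requires a change-of-variables correction for that rescaling. The sketch below assumes each pixel value v in [0, 256) was mapped affinely to y = 2v/256 - 1; the helper name is hypothetical:

```python
import math

def bits_per_dim_pixel_scale(avg_log_prob, input_dim=784, num_levels=256,
                             scaled_width=2.0):
    """Convert an average log-density in scaled space to bits/dim on the pixel scale.

    Assumes each pixel value v in [0, num_levels) was mapped affinely onto an
    interval of width `scaled_width`, so dy/dv = scaled_width / num_levels.
    """
    # log p_v(v) = log p_y(y) + input_dim * log(dy/dv)
    log_jacobian = input_dim * math.log(scaled_width / num_levels)
    nll_nats = -(avg_log_prob + log_jacobian)
    return nll_nats / (input_dim * math.log(2.0))

# Sanity check: a density uniform on [-1, 1]^784 has log-density -784 * log(2)
# in scaled space, which corresponds to 8 bits per pixel on the 0-255 scale.
uniform_log_prob = -784 * math.log(2.0)
bpd_uniform = bits_per_dim_pixel_scale(uniform_log_prob)
print(f"Uniform baseline: {bpd_uniform:.3f} bits/dim")
```

Under this assumption the pixel-scale value equals the scaled-space value plus log2(128) = 7 bits, so the two conventions should not be mixed when comparing models.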

Visualize Log-Likelihood Distribution¤

def plot_log_likelihood_distribution(log_probs):
    """Plot distribution of log-likelihoods."""
    plt.figure(figsize=(10, 5))

    plt.hist(log_probs, bins=50, density=True, alpha=0.7, edgecolor='black')
    plt.axvline(jnp.mean(log_probs), color='red', linestyle='--',
                label=f'Mean: {jnp.mean(log_probs):.2f}')
    plt.xlabel('Log-Likelihood')
    plt.ylabel('Density')
    plt.title('Distribution of Test Log-Likelihoods')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Plot distribution
plot_log_likelihood_distribution(eval_results['all_log_probs'])

Generation and Visualization¤

Generate Samples¤

def generate_samples(model, n_samples, rngs):
    """Generate samples from the flow model.

    Args:
        model: Trained flow model
        n_samples: Number of samples
        rngs: Random number generators

    Returns:
        Generated samples
    """
    # Generate from model
    samples = model.generate(n_samples=n_samples, rngs=rngs)

    # Denormalize from [-1, 1] to [0, 1]
    samples = (samples * 0.5) + 0.5

    # Clip to valid range
    samples = jnp.clip(samples, 0, 1)

    return samples

# Generate 64 samples
n_samples = 64
generated_samples = generate_samples(model, n_samples, rngs)

print(f"Generated {n_samples} samples")
print(f"Sample shape: {generated_samples.shape}")
print(f"Sample range: [{generated_samples.min():.3f}, {generated_samples.max():.3f}]")

Visualize Generated Samples¤

def visualize_samples(samples, n_rows=8, n_cols=8, title="Generated Samples"):
    """Visualize generated samples in a grid.

    Args:
        samples: Generated samples (flattened)
        n_rows: Number of rows in grid
        n_cols: Number of columns in grid
        title: Plot title
    """
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))

    for i in range(n_rows):
        for j in range(n_cols):
            idx = i * n_cols + j
            if idx < len(samples):
                # Reshape to 28x28
                image = samples[idx].reshape(28, 28)
                axes[i, j].imshow(image, cmap='gray')
            axes[i, j].axis('off')

    plt.suptitle(title, fontsize=16, y=0.995)
    plt.tight_layout()
    plt.show()

# Visualize generated samples
visualize_samples(generated_samples, title="RealNVP Generated MNIST Digits")

Compare with Real Data¤

# Show real vs generated
fig, axes = plt.subplots(2, 8, figsize=(12, 3))

# Real samples
real_samples = test_images[:8]
for i in range(8):
    axes[0, i].imshow(real_samples[i].reshape(28, 28), cmap='gray')
    axes[0, i].axis('off')
axes[0, 0].set_ylabel('Real', fontsize=12, rotation=0, labelpad=30)

# Generated samples
gen_samples = generate_samples(model, 8, rngs)
for i in range(8):
    axes[1, i].imshow(gen_samples[i].reshape(28, 28), cmap='gray')
    axes[1, i].axis('off')
axes[1, 0].set_ylabel('Generated', fontsize=12, rotation=0, labelpad=30)

plt.suptitle('Real vs Generated MNIST Digits', fontsize=14)
plt.tight_layout()
plt.show()

Latent Space Interpolation¤

Interpolate between two digits in latent space:

def interpolate_in_latent_space(model, x1, x2, num_steps, rngs):
    """Interpolate between two data points in latent space.

    Args:
        model: Trained flow model
        x1: First data point
        x2: Second data point
        num_steps: Number of interpolation steps
        rngs: Random number generators

    Returns:
        Interpolated samples
    """
    # Encode to latent space
    z1, _ = model.forward(x1[None, ...], rngs=rngs)
    z2, _ = model.forward(x2[None, ...], rngs=rngs)

    # Linear interpolation
    alphas = jnp.linspace(0, 1, num_steps)
    z_interp = jnp.array([
        (1 - alpha) * z1 + alpha * z2
        for alpha in alphas
    ]).squeeze(1)

    # Decode to data space
    x_interp, _ = model.inverse(z_interp, rngs=rngs)

    # Denormalize
    x_interp = (x_interp * 0.5) + 0.5
    x_interp = jnp.clip(x_interp, 0, 1)

    return x_interp

# Select two test images
idx1, idx2 = 0, 7  # First and eighth test images
test_batch = preprocess_for_flow(test_images[:10], rngs.sample())
x1, x2 = test_batch[idx1], test_batch[idx2]

# Interpolate
num_steps = 10
interpolations = interpolate_in_latent_space(model, x1, x2, num_steps, rngs)

# Visualize interpolation
fig, axes = plt.subplots(1, num_steps, figsize=(num_steps, 1))
for i in range(num_steps):
    axes[i].imshow(interpolations[i].reshape(28, 28), cmap='gray')
    axes[i].axis('off')

plt.suptitle(f'Latent Space Interpolation: {test_labels[idx1]} -> {test_labels[idx2]}',
             fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

Temperature Sampling¤

Control sample diversity with temperature:

def sample_with_temperature(model, n_samples, temperature, rngs):
    """Generate samples with temperature scaling.

    Args:
        model: Trained flow model
        n_samples: Number of samples
        temperature: Temperature parameter (>1: more diverse, <1: more conservative)
        rngs: Random number generators

    Returns:
        Generated samples
    """
    # Sample from base distribution
    z = jax.random.normal(rngs.sample(), (n_samples, 784))

    # Apply temperature
    z = z * temperature

    # Decode to data space
    samples, _ = model.inverse(z, rngs=rngs)

    # Denormalize
    samples = (samples * 0.5) + 0.5
    samples = jnp.clip(samples, 0, 1)

    return samples

# Generate samples with different temperatures
temperatures = [0.5, 0.7, 1.0, 1.3, 1.5]
n_samples_per_temp = 8

fig, axes = plt.subplots(len(temperatures), n_samples_per_temp,
                         figsize=(n_samples_per_temp, len(temperatures)))

for i, temp in enumerate(temperatures):
    samples = sample_with_temperature(model, n_samples_per_temp, temp, rngs)

    for j in range(n_samples_per_temp):
        axes[i, j].imshow(samples[j].reshape(28, 28), cmap='gray')
        axes[i, j].axis('off')

    axes[i, 0].set_ylabel(f'T={temp}', fontsize=10, rotation=0, labelpad=30)

plt.suptitle('Temperature Sampling (Lower = More Conservative)', fontsize=14, y=0.995)
plt.tight_layout()
plt.show()

Anomaly Detection¤

Use log-likelihood for anomaly detection:

# Compute threshold from training data
train_subset = train_images[:5000]
train_subset_processed = preprocess_for_flow(train_subset, rngs.sample())
train_log_probs = model.log_prob(train_subset_processed, rngs=rngs)

# Set threshold at 5th percentile
threshold = float(jnp.percentile(train_log_probs, 5))
print(f"Anomaly threshold (5th percentile): {threshold:.3f}")

# Evaluate on test set
test_subset = test_images[:1000]
test_subset_processed = preprocess_for_flow(test_subset, rngs.sample())
test_log_probs = model.log_prob(test_subset_processed, rngs=rngs)

# Identify anomalies
is_anomaly = test_log_probs < threshold
n_anomalies = int(jnp.sum(is_anomaly))

print(f"Detected {n_anomalies} anomalies out of {len(test_subset)} samples")
print(f"Anomaly rate: {n_anomalies / len(test_subset) * 100:.2f}%")

# Visualize some anomalies
anomaly_indices = jnp.where(is_anomaly)[0]
if len(anomaly_indices) > 0:
    fig, axes = plt.subplots(2, 8, figsize=(12, 3))

    for i in range(min(8, len(anomaly_indices))):
        idx = int(anomaly_indices[i])
        image = test_subset[idx].reshape(28, 28)
        log_prob = float(test_log_probs[idx])

        axes[0, i].imshow(image, cmap='gray')
        axes[0, i].set_title(f'LL: {log_prob:.1f}', fontsize=8)
        axes[0, i].axis('off')

    # Show normal samples for comparison
    normal_indices = jnp.where(~is_anomaly)[0]
    for i in range(8):
        idx = int(normal_indices[i])
        image = test_subset[idx].reshape(28, 28)
        log_prob = float(test_log_probs[idx])

        axes[1, i].imshow(image, cmap='gray')
        axes[1, i].set_title(f'LL: {log_prob:.1f}', fontsize=8)
        axes[1, i].axis('off')

    axes[0, 0].set_ylabel('Anomalies', fontsize=10, rotation=0, labelpad=40)
    axes[1, 0].set_ylabel('Normal', fontsize=10, rotation=0, labelpad=40)

    plt.suptitle('Anomaly Detection (Low vs High Log-Likelihood)', fontsize=14)
    plt.tight_layout()
    plt.show()
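
When reading the anomaly counts, keep in mind that a threshold at the 5th percentile of in-distribution scores flags roughly 5% of in-distribution data by construction. The sketch below uses synthetic Gaussian scores as stand-ins for log-likelihoods (they are not produced by a trained flow) to separate that baseline false-positive rate from the detection rate on genuinely shifted data:

```python
import jax
import jax.numpy as jnp

key_ref, key_in, key_out = jax.random.split(jax.random.key(0), 3)

# Synthetic stand-ins for log-likelihoods (hypothetical values)
ref_scores = -100.0 + 10.0 * jax.random.normal(key_ref, (5000,))  # reference set
in_scores = -100.0 + 10.0 * jax.random.normal(key_in, (2000,))    # same distribution
out_scores = -150.0 + 10.0 * jax.random.normal(key_out, (2000,))  # shifted "anomalies"

# Same rule as in the tutorial: flag anything below the 5th percentile
threshold = jnp.percentile(ref_scores, 5)

false_positive_rate = float(jnp.mean(in_scores < threshold))
detection_rate = float(jnp.mean(out_scores < threshold))

print(f"Threshold: {float(threshold):.2f}")
print(f"Flagged in-distribution: {false_positive_rate:.1%}")
print(f"Flagged shifted data: {detection_rate:.1%}")
```

An anomaly rate near 5% on the MNIST test set is therefore the expected baseline. A meaningful evaluation needs data from a different source, such as another dataset or corrupted digits.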

Comparing Flow Architectures¤

Compare RealNVP, MAF, and IAF:

def compare_flow_models():
    """Compare different flow architectures on MNIST."""
    # Create different models
    models = {}

    # RealNVP (already trained)
    models["RealNVP"] = model

    # MAF configuration
    maf_config = ModelConfiguration(
        name="maf_mnist",
        model_class="workshop.generative_models.models.flow.MAF",
        input_dim=784,
        output_dim=784,
        hidden_dims=[512],
        parameters={
            "num_layers": 5,
            "reverse_ordering": True,
        }
    )
    models["MAF"] = MAF(maf_config, rngs=rngs)

    # IAF configuration
    iaf_config = ModelConfiguration(
        name="iaf_mnist",
        model_class="workshop.generative_models.models.flow.IAF",
        input_dim=784,
        output_dim=784,
        hidden_dims=[512],
        parameters={
            "num_layers": 5,
            "reverse_ordering": True,
        }
    )
    models["IAF"] = IAF(iaf_config, rngs=rngs)

    # Note: For a fair comparison, you would train MAF and IAF too
    # Here we just demonstrate the API

    print("Model Comparison:")
    print("-" * 60)

    for name, flow_model in models.items():
        # Test forward pass
        test_batch = test_images[:100]
        test_batch_processed = preprocess_for_flow(test_batch, rngs.sample())

        try:
            log_prob = flow_model.log_prob(test_batch_processed, rngs=rngs)
            avg_ll = float(jnp.mean(log_prob))

            print(f"{name:10s} - Avg Log-Likelihood: {avg_ll:.3f}")
        except Exception as e:
            print(f"{name:10s} - Error: {str(e)[:50]}")

    print("-" * 60)
    print("\nKey Differences:")
    print("  RealNVP: Fast density evaluation and fast sampling (good balance)")
    print("  MAF:     Fast density evaluation, slow sequential sampling")
    print("  IAF:     Fast sampling, slow sequential density evaluation")

# Compare models
compare_flow_models()
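
The speed trade-off between MAF and IAF follows from their autoregressive structure. The sketch below is a toy MAF-style affine transform in plain JAX, not Workshop's `MAF` class: the strictly lower-triangular weight matrices stand in for a masked network, and all names are hypothetical. Mapping data to latent takes one parallel pass, while mapping latent to data must fill in one dimension at a time:

```python
import jax
import jax.numpy as jnp

dim = 6
k_mu, k_alpha, k_x = jax.random.split(jax.random.key(0), 3)

# Strictly lower-triangular weights: output i only sees inputs with index < i
mask = jnp.tril(jnp.ones((dim, dim)), k=-1)
w_mu = 0.5 * jax.random.normal(k_mu, (dim, dim)) * mask
w_alpha = 0.5 * jax.random.normal(k_alpha, (dim, dim)) * mask

def conditioner(x):
    """Shift mu_i and log-scale alpha_i as functions of x_{<i} only."""
    return jnp.tanh(x @ w_mu.T), jnp.tanh(x @ w_alpha.T)

def data_to_latent(x):
    """Density-evaluation direction: all dimensions in one parallel pass."""
    mu, alpha = conditioner(x)
    z = (x - mu) * jnp.exp(-alpha)
    log_det = -jnp.sum(alpha, axis=-1)
    return z, log_det

def latent_to_data(z):
    """Sampling direction: x_i depends on x_{<i}, so dimensions are filled in order."""
    x = jnp.zeros_like(z)
    for i in range(dim):
        mu, alpha = conditioner(x)  # entries with index < i are already final
        x = x.at[..., i].set(z[..., i] * jnp.exp(alpha[..., i]) + mu[..., i])
    return x

x = jax.random.normal(k_x, (4, dim))
z, log_det = data_to_latent(x)
x_roundtrip = latent_to_data(z)
roundtrip_error = float(jnp.max(jnp.abs(x - x_roundtrip)))
print(f"Round-trip error: {roundtrip_error:.2e}")
```

For flattened MNIST the sequential direction would need 784 conditioner passes per sample. IAF uses the same construction with the two directions swapped, which makes sampling parallel and density evaluation sequential.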

Troubleshooting¤

Issue: NaN Loss¤

If you encounter NaN losses:

# 1. Check data preprocessing
assert jnp.all(jnp.isfinite(train_images)), "Data contains NaN/Inf"

# 2. Reduce learning rate
optimizer = nnx.Optimizer(
    model,
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(1e-5)  # Lower learning rate
    )
)

# 3. Monitor log-determinants
def train_step_with_checks(model, optimizer, batch, rngs):
    def loss_fn(model):
        outputs = model(batch, rngs=rngs)
        log_det = outputs["logdet"]

        # Check for extreme values
        if jnp.any(jnp.abs(log_det) > 100):
            print(f"Warning: Large log-det: {jnp.max(jnp.abs(log_det)):.2f}")

        loss = -jnp.mean(outputs["log_prob"])
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return {"loss": loss}

Issue: Poor Sample Quality¤

If generated samples look poor:

# 1. Train longer
num_epochs = 50  # More epochs

# 2. Increase model capacity
config = ModelConfiguration(
    name="larger_realnvp",
    model_class="workshop.generative_models.models.flow.RealNVP",
    input_dim=784,
    hidden_dims=[1024, 1024],  # Larger networks
    parameters={
        "num_coupling_layers": 12,  # More layers
    }
)

# 3. Use Neural Spline Flows for more expressiveness
from workshop.generative_models.models.flow import NeuralSplineFlow

spline_config = ModelConfiguration(
    name="spline_mnist",
    model_class="workshop.generative_models.models.flow.NeuralSplineFlow",
    input_dim=784,
    hidden_dims=[256, 256],
    metadata={
        "flow_params": {
            "num_layers": 12,
            "num_bins": 16,  # More bins
        }
    }
)

Summary¤

In this tutorial, you learned:

Data Preprocessing: Dequantization and normalization for flow models

Model Creation: Setting up RealNVP with proper configuration

Training: Training loop with gradient clipping and monitoring

Evaluation: Computing exact log-likelihoods and bits per dimension

Generation: Sampling from the trained model

Advanced Techniques:

  • Latent space interpolation
  • Temperature sampling
  • Anomaly detection

Model Comparison: Understanding different flow architectures

Next Steps¤

  • Explore Other Models: Try Glow for higher quality or MAF for better density estimation
  • Conditional Flows: Implement class-conditional generation
  • High-Resolution: Apply flows to higher resolution images
  • Hybrid Models: Combine flows with VAEs or diffusion models

Further Reading¤

References¤

  • Dinh et al. (2016): "Density estimation using Real NVP"
  • Kingma & Dhariwal (2018): "Glow: Generative Flow with Invertible 1x1 Convolutions"
  • Papamakarios et al. (2017): "Masked Autoregressive Flow for Density Estimation"