Advanced Flow Examples¤

This guide demonstrates advanced normalizing flow architectures using Workshop and Flax NNX, including continuous normalizing flows (CNF), FFJORD, custom coupling flows, and conditional flows.

Overview¤

  • Continuous Normalizing Flows (CNF): Neural ODEs for flexible continuous-time transformations
  • FFJORD: Free-form Jacobian of Reversible Dynamics with efficient trace estimation
  • Advanced Coupling Flows: Custom coupling architectures with attention and residual connections
  • Conditional Flows: Conditional generation with class, text, or image inputs

Prerequisites¤

# Install Workshop with all dependencies
uv pip install "workshop[cuda]"  # With GPU support
# or
uv pip install workshop  # CPU only

import jax
import jax.numpy as jnp
from flax import nnx
import optax
from workshop.generative_models.core import DeviceManager
from workshop.generative_models.models.flow import FlowModel

Continuous Normalizing Flows (CNF)¤

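CNFs define the transformation as the solution of an ODE, dz/dt = f(z, t), whose dynamics f are a neural network. Because the flow of a sufficiently smooth ODE is invertible by construction, the dynamics network needs no special architecture, which gives CNFs more flexibility than discrete flows built from restricted layers such as couplings.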
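For reference, the implementation in this section relies on two equations: the state ODE and the instantaneous change of variables formula. The notation here ($f_\theta$ for the dynamics network, $p_0$ for the base density, integration over $t \in [0, 1]$) is chosen for this sketch and is not taken from the Workshop API.

```latex
\begin{aligned}
\frac{dz(t)}{dt} &= f_\theta\bigl(z(t), t\bigr), \qquad z(0) = z_0 \sim p_0, \quad x = z(1), \\
\frac{d \log p_t\bigl(z(t)\bigr)}{dt} &= -\operatorname{Tr}\!\left(\frac{\partial f_\theta}{\partial z(t)}\right), \\
\log p_1(x) &= \log p_0(z_0) - \int_0^1 \operatorname{Tr}\!\left(\frac{\partial f_\theta}{\partial z(t)}\right) dt .
\end{aligned}
```

The last line is what a CNF log-likelihood evaluates: integrate the state back to $z_0$ while accumulating the trace integral alongside it.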

Architecture¤

graph LR
    A["z ~ N(0, I)"] --> B[ODE Solver]
    B --> C[x = z_T]
    C --> D[Data]

    style A fill:#e1f5ff
    style B fill:#f3e5f5
    style C fill:#e8f5e9
    style D fill:#fff3e0

CNF Implementation¤

from flax import nnx
import jax
import jax.numpy as jnp
from functools import partial

class ContinuousNormalizingFlow(nnx.Module):
    """Continuous Normalizing Flow using Neural ODEs."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 3,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Dynamics network f(z, t)
        # Maps (z, t) -> dz/dt
        layers = []
        for i in range(num_layers):
            in_dim = input_dim + 1 if i == 0 else hidden_dim  # +1 for time
            out_dim = input_dim if i == num_layers - 1 else hidden_dim

            layers.append(nnx.Linear(in_dim, out_dim, rngs=rngs))

            if i < num_layers - 1:
                layers.append(nnx.tanh)  # nnx.Sequential accepts plain callables

        self.dynamics_net = nnx.Sequential(*layers)

    def dynamics(self, t: float, z: jax.Array) -> jax.Array:
        """
        Compute dz/dt at time t.

        Args:
            t: Current time (scalar)
            z: Current state [batch, input_dim]

        Returns:
            Time derivative dz/dt [batch, input_dim]
        """
        batch_size = z.shape[0]

        # Concatenate time to input
        t_expanded = jnp.full((batch_size, 1), t)
        z_t = jnp.concatenate([z, t_expanded], axis=-1)

        # Compute dynamics
        return self.dynamics_net(z_t)

    def forward(
        self,
        z0: jax.Array,
        t_span: tuple[float, float] = (0.0, 1.0),
        num_steps: int = 100,
    ) -> jax.Array:
        """
        Integrate from z0 at t=t0 to t=t1.

        Args:
            z0: Initial state [batch, input_dim]
            t_span: Time interval (t0, t1)
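            num_steps: Number of output time points (odeint chooses its
                internal step size adaptively)

        Returns:
            Final state z1 [batch, input_dim]
        """
        from jax.experimental.ode import odeint

        t0, t1 = t_span

        # odeint expects increasing times, so integrate over s in [0, 1]
        # with t = t0 + s * (t1 - t0). This also covers t1 < t0 (inverse).
        s_eval = jnp.linspace(0.0, 1.0, num_steps)

        # Chain rule: dz/ds = (t1 - t0) * f(z, t)
        def ode_func(z, s):
            t = t0 + s * (t1 - t0)
            return (t1 - t0) * self.dynamics(t, z)

        # Integrate (returns [num_steps, batch, input_dim])
        z_trajectory = odeint(ode_func, z0, s_eval)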

        # Return final state
        return z_trajectory[-1]

    def inverse(
        self,
        z1: jax.Array,
        t_span: tuple[float, float] = (1.0, 0.0),
        num_steps: int = 100,
    ) -> jax.Array:
        """
        Integrate backwards from z1 at t=t1 to t=t0.

        Args:
            z1: Final state [batch, input_dim]
            t_span: Time interval (t1, t0) - reversed
            num_steps: Number of integration steps

        Returns:
            Initial state z0 [batch, input_dim]
        """
        return self.forward(z1, t_span, num_steps)

    def log_prob(
        self,
        x: jax.Array,
        base_log_prob_fn,
        num_steps: int = 100,
    ) -> jax.Array:
        """
        Compute log probability using instantaneous change of variables.

        Args:
            x: Data samples [batch, input_dim]
            base_log_prob_fn: Log probability function for base distribution
            num_steps: Integration steps

        Returns:
            Log probabilities [batch]
        """
        from jax.experimental.ode import odeint

        batch_size = x.shape[0]

        # Integrate backwards from t=1 to t=0 to get z0 and the log-density
        # change. odeint expects increasing times, so use s = 1 - t, which
        # flips the sign of the state derivative.
        def augmented_dynamics(augmented_state, s):
            z, _ = augmented_state
            t = 1.0 - s
            dz_dt = self.dynamics(t, z)

            # Exact trace of the Jacobian, computed per sample (expensive).
            # FFJORD replaces this with Hutchinson's trace estimator.
            def dynamics_fn(z_single):
                return self.dynamics(t, z_single[None, :])[0]

            jacobian = jax.vmap(jax.jacfwd(dynamics_fn))(z)  # [batch, dim, dim]
            trace = jnp.trace(jacobian, axis1=-2, axis2=-1)  # [batch]

            # dz/ds = -f(z, t); the accumulated term integrates -Tr(df/dz)
            return -dz_dt, -trace

        # Initial augmented state
        initial_state = (x, jnp.zeros(batch_size))

        # Integrate over s in [0, 1]
        s_eval = jnp.linspace(0.0, 1.0, num_steps)
        z_trajectory, log_det_trajectory = odeint(
            augmented_dynamics, initial_state, s_eval
        )

        # Keep only the final time point of each trajectory
        z0, log_det_jacobian = z_trajectory[-1], log_det_trajectory[-1]

        # log p(x) = log p0(z0) - integral of Tr(df/dz) over t in [0, 1]
        base_log_prob = base_log_prob_fn(z0)
        return base_log_prob + log_det_jacobian


def train_cnf(
    model: ContinuousNormalizingFlow,
    train_data: jnp.ndarray,
    num_epochs: int = 100,
    batch_size: int = 128,
):
    """Train continuous normalizing flow."""

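    # Newer Flax releases may instead require nnx.Optimizer(model, tx, wrt=nnx.Param)
    # and optimizer.update(model, grads); adjust to the installed version.
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))

    # Base distribution (standard normal)
    def base_log_prob(z):
        return -0.5 * jnp.sum(z ** 2, axis=-1) - 0.5 * z.shape[-1] * jnp.log(2 * jnp.pi)

    # Negative log likelihood as a function of the model, so it can be differentiated
    def loss_fn(model, batch):
        log_probs = model.log_prob(batch, base_log_prob, num_steps=50)
        return -jnp.mean(log_probs)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx in range(0, len(train_data), batch_size):
            batch = train_data[batch_idx:batch_idx + batch_size]

            # Loss and gradients with respect to the model parameters
            loss, grads = nnx.value_and_grad(loss_fn)(model, batch)

            # Update
            optimizer.update(grads)

            epoch_loss += float(loss)
            num_batches += 1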

        avg_loss = epoch_loss / num_batches

        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

    return model


def sample_from_cnf(
    model: ContinuousNormalizingFlow,
    num_samples: int,
    *,
    rngs: nnx.Rngs,
) -> jax.Array:
    """Sample from learned distribution."""

    # Sample from base distribution
    z0 = jax.random.normal(rngs.sample(), (num_samples, model.input_dim))

    # Transform to data space
    x = model.forward(z0, t_span=(0.0, 1.0), num_steps=100)

    return x
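One way to sanity-check the augmented ODE used for the log-determinant is to run it on dynamics with a known answer. The sketch below is not part of Workshop: it assumes a hypothetical linear vector field dz/dt = A z, for which the state at t=1 is expm(A) z0 and the log-determinant is Tr(A). The matrix `A`, the inputs, and the tolerances are illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.scipy.linalg import expm

# Hypothetical linear dynamics dz/dt = A z (A is an arbitrary illustrative matrix).
A = jnp.array([[0.3, -0.2],
               [0.1, -0.5]])


def augmented_dynamics(state, t):
    z, _ = state
    dz_dt = z @ A.T
    # Exact per-sample divergence; for linear dynamics it equals Tr(A) everywhere.
    jac = jax.vmap(jax.jacfwd(lambda v: A @ v))(z)   # [batch, 2, 2]
    div = jnp.trace(jac, axis1=-2, axis2=-1)         # [batch]
    return dz_dt, div


z0 = jnp.array([[1.0, 2.0], [-0.5, 0.3]])
ts = jnp.array([0.0, 1.0])  # odeint only reports the state at these times

z_traj, log_det_traj = odeint(
    augmented_dynamics, (z0, jnp.zeros(z0.shape[0])), ts, rtol=1e-6, atol=1e-6
)
z1, log_det = z_traj[-1], log_det_traj[-1]

# Closed-form answers for the linear case
z1_exact = z0 @ expm(A).T
log_det_exact = jnp.trace(A)  # integral of Tr(A) over t in [0, 1]

print("max state error:", jnp.max(jnp.abs(z1 - z1_exact)))
print("log-det (numeric, exact):", log_det, log_det_exact)
```

The same pattern (state and accumulated divergence integrated together as a tuple) is what the `log_prob` method relies on, with the dynamics network in place of `A`.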

FFJORD¤

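FFJORD (Free-Form Jacobian of Reversible Dynamics) is a CNF that replaces the exact Jacobian trace with Hutchinson's stochastic trace estimator. Each divergence evaluation then needs a single Jacobian-vector product rather than a full D×D Jacobian, which keeps the log-determinant computation tractable in high dimensions.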
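The estimator rests on a short identity. For any random vector $\epsilon$ with zero mean and identity covariance (standard normal noise is one such choice) and any square matrix $J$:

```latex
\mathbb{E}_{\epsilon}\bigl[\epsilon^\top J \epsilon\bigr]
= \sum_{i,j} J_{ij}\, \mathbb{E}[\epsilon_i \epsilon_j]
= \sum_{i} J_{ii}
= \operatorname{Tr}(J),
\qquad \mathbb{E}[\epsilon] = 0,\quad \mathbb{E}[\epsilon \epsilon^\top] = I .
```

A single draw of $\epsilon$ therefore gives an unbiased, though noisy, estimate of the trace, and the product $J\epsilon$ can be obtained from automatic differentiation without forming $J$.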

Hutchinson's Trace Estimator¤

class FFJORD(nnx.Module):
    """
    FFJORD: Scalable Continuous Normalizing Flow.

    Uses Hutchinson's trace estimator so each divergence evaluation needs one
    Jacobian-vector product instead of a full Jacobian.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 3,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.input_dim = input_dim

        # Time-conditioned dynamics network
        self.dynamics_net = self._build_dynamics_net(
            input_dim,
            hidden_dim,
            num_layers,
            rngs,
        )

    def _build_dynamics_net(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int,
        rngs: nnx.Rngs,
    ) -> nnx.Module:
        """Build dynamics network with time conditioning."""

        class TimeConditionedMLP(nnx.Module):
            def __init__(self, in_dim, hidden, n_layers, *, rngs):
                super().__init__()

                layers = []
                for i in range(n_layers):
                    layer_in = in_dim + 1 if i == 0 else hidden  # +1 for time
                    layer_out = in_dim if i == n_layers - 1 else hidden

                    layers.append(nnx.Linear(layer_in, layer_out, rngs=rngs))

                    if i < n_layers - 1:
                        layers.append(nnx.softplus)  # Sequential accepts plain callables

                self.net = nnx.Sequential(*layers)

            def __call__(self, z, t):
                batch_size = z.shape[0]
                t_expanded = jnp.full((batch_size, 1), t)
                z_t = jnp.concatenate([z, t_expanded], axis=-1)
                return self.net(z_t)

        return TimeConditionedMLP(input_dim, hidden_dim, num_layers, rngs=rngs)

    def dynamics(self, t: float, z: jax.Array) -> jax.Array:
        """Compute dz/dt."""
        return self.dynamics_net(z, t)

    def divergence_approx(
        self,
        t: float,
        z: jax.Array,
        epsilon: jax.Array,
    ) -> jax.Array:
        """
        Approximate divergence using Hutchinson's trace estimator.

        Tr(J) = E[ε^T J ε] for ε ~ N(0, I); one sample of ε gives an unbiased estimate

        Args:
            t: Time
            z: State [batch, dim]
            epsilon: Random vector [batch, dim]

        Returns:
            Trace estimate [batch]
        """

        # The network acts on each row independently, so one batched JVP
        # returns J(z_i) @ ε_i for every sample i.
        def dynamics_fn(z_batch):
            return self.dynamics(t, z_batch)

        # Compute Jacobian-vector product efficiently
        _, jvp_result = jax.jvp(dynamics_fn, (z,), (epsilon,))

        # Hutchinson estimator: ε^T J ε
        trace_estimate = jnp.sum(epsilon * jvp_result, axis=-1)

        return trace_estimate

    def ode_with_log_prob(
        self,
        z: jax.Array,
        t_span: tuple[float, float] = (0.0, 1.0),
        num_steps: int = 100,
        *,
        rngs: nnx.Rngs,
    ) -> tuple[jax.Array, jax.Array]:
        """
        Integrate ODE and compute log probability.

        Args:
            z: Initial/final state [batch, dim]
            t_span: Time interval
            num_steps: Integration steps
            rngs: For trace estimation

        Returns:
            Tuple of (final_state, log_determinant)
        """
        from jax.experimental.ode import odeint

        batch_size = z.shape[0]
        t0, t1 = t_span

        # Sample noise for trace estimation (reuse across time)
        epsilon = jax.random.normal(rngs.sample(), (batch_size, self.input_dim))

        # odeint expects increasing times, so integrate over s in [0, 1]
        # with t = t0 + s * (t1 - t0); this also handles t1 < t0.
        def augmented_dynamics(augmented_state, s):
            z_current, _ = augmented_state
            t = t0 + s * (t1 - t0)

            # Dynamics, rescaled by dt/ds = t1 - t0
            dz_ds = (t1 - t0) * self.dynamics(t, z_current)

            # Trace estimate
            trace = self.divergence_approx(t, z_current, epsilon)

            # Accumulate the integral of Tr(J) along the path in both
            # directions, i.e. the log-determinant of the forward map.
            # inverse_and_log_det negates it.
            return dz_ds, abs(t1 - t0) * trace

        # Initial state
        initial_augmented = (z, jnp.zeros(batch_size))

        # Evaluation points in the rescaled time s
        s_eval = jnp.linspace(0.0, 1.0, num_steps)

        # Integrate
        z_trajectory, log_det_trajectory = odeint(
            augmented_dynamics, initial_augmented, s_eval
        )

        # Extract final values
        z_final = z_trajectory[-1]
        log_det_jacobian = log_det_trajectory[-1]

        return z_final, log_det_jacobian

    def forward_and_log_det(
        self,
        z0: jax.Array,
        *,
        rngs: nnx.Rngs,
    ) -> tuple[jax.Array, jax.Array]:
        """Forward transformation with log determinant."""
        return self.ode_with_log_prob(z0, t_span=(0.0, 1.0), rngs=rngs)

    def inverse_and_log_det(
        self,
        x: jax.Array,
        *,
        rngs: nnx.Rngs,
    ) -> tuple[jax.Array, jax.Array]:
        """Inverse transformation with log determinant."""
        z0, log_det = self.ode_with_log_prob(x, t_span=(1.0, 0.0), rngs=rngs)
        return z0, -log_det  # Negate for inverse

    def log_prob(
        self,
        x: jax.Array,
        *,
        rngs: nnx.Rngs,
    ) -> jax.Array:
        """Compute log probability."""

        # Transform to base space
        z0, log_det = self.inverse_and_log_det(x, rngs=rngs)

        # Base distribution log prob (standard normal)
        base_log_prob = -0.5 * jnp.sum(z0 ** 2, axis=-1) - 0.5 * self.input_dim * jnp.log(2 * jnp.pi)

        return base_log_prob + log_det


def train_ffjord(
    model: FFJORD,
    train_data: jnp.ndarray,
    num_epochs: int = 100,
    batch_size: int = 128,
):
    """Train FFJORD model."""

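    rngs = nnx.Rngs(42)
    # Newer Flax releases may instead require nnx.Optimizer(model, tx, wrt=nnx.Param)
    # and optimizer.update(model, grads); adjust to the installed version.
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))

    # Negative log likelihood; rngs is passed as an argument rather than
    # captured from the enclosing scope
    def loss_fn(model, batch, rngs):
        return -jnp.mean(model.log_prob(batch, rngs=rngs))

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx in range(0, len(train_data), batch_size):
            batch = train_data[batch_idx:batch_idx + batch_size]

            # Loss and gradients with respect to the model parameters
            loss, grads = nnx.value_and_grad(loss_fn)(model, batch, rngs)

            # Update
            optimizer.update(grads)

            epoch_loss += float(loss)
            num_batches += 1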

        if epoch % 10 == 0:
            avg_loss = epoch_loss / num_batches
            print(f"Epoch {epoch}/{num_epochs}, NLL: {avg_loss:.4f}")

    return model
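The accuracy of the Hutchinson estimator is easy to check in isolation. The following sketch compares the exact Jacobian trace with an average of many single-sample estimates. The vector field `f`, the matrix `W`, and the sample count are illustrative assumptions standing in for a dynamics network.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy vector field R^3 -> R^3, standing in for a dynamics network.
W = jnp.array([[0.5, -0.3, 0.2],
               [0.1, 0.4, -0.6],
               [-0.2, 0.3, 0.7]])


def f(z):
    return jnp.sin(W @ z)


z = jnp.array([0.2, -0.4, 0.9])

# Exact trace from the full D x D Jacobian
exact_trace = jnp.trace(jax.jacfwd(f)(z))


def hutchinson_sample(eps):
    # One JVP gives J @ eps without materializing J
    _, j_eps = jax.jvp(f, (z,), (eps,))
    return jnp.dot(eps, j_eps)


num_samples = 20_000
eps = jax.random.normal(jax.random.PRNGKey(0), (num_samples, z.shape[0]))
estimates = jax.vmap(hutchinson_sample)(eps)
hutchinson_trace = jnp.mean(estimates)

print("exact trace:       ", exact_trace)
print("Hutchinson average:", hutchinson_trace)
print("per-sample std:    ", jnp.std(estimates))
```

During training only one sample is drawn per batch element, so individual estimates are noisy. The noise averages out over the batch and over optimization steps, which is why a single draw is usually considered sufficient.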

Advanced Coupling Flows¤

Custom coupling architectures with attention mechanisms and residual connections for improved expressiveness.
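All coupling layers in this section share the same affine form. Writing $s$ and $t$ for the log-scale and shift outputs of the transformation network (they correspond to `log_scale` and `shift` in the code) and splitting the input as $x = (x_1, x_2)$:

```latex
\begin{aligned}
y_1 &= x_1, & y_2 &= x_2 \odot \exp\bigl(s(x_1)\bigr) + t(x_1), \\
x_1 &= y_1, & x_2 &= \bigl(y_2 - t(y_1)\bigr) \odot \exp\bigl(-s(y_1)\bigr), \\
\log\left|\det \frac{\partial y}{\partial x}\right| &= \sum_i s_i(x_1).
\end{aligned}
```

Because $x_1$ passes through unchanged, the inverse can recompute $s$ and $t$ exactly, so the transformation network itself never needs to be inverted and can be arbitrarily complex.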

Attention Coupling Layer¤

class AttentionCouplingLayer(nnx.Module):
    """Coupling layer with self-attention in the transformation network."""

    def __init__(
        self,
        features: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.features = features
        self.split_dim = features // 2

        # Transformation network with attention
        self.transform_net = nnx.Sequential(
            nnx.Linear(self.split_dim, hidden_dim, rngs=rngs),
            nnx.relu,  # nnx.Sequential accepts plain callables

            # Reshape to [batch, tokens, hidden_dim] for attention.
            # A flat feature vector yields a single token here.
            lambda x: x.reshape(x.shape[0], -1, hidden_dim),

            # Self-attention
            nnx.MultiHeadAttention(
                num_heads=num_heads,
                in_features=hidden_dim,
                decode=False,  # no autoregressive cache
                rngs=rngs,
            ),

            # Reshape back
            lambda x: x.reshape(x.shape[0], -1),

            nnx.Linear(hidden_dim, self.split_dim * 2, rngs=rngs),  # *2 for scale and shift
        )

    def __call__(
        self,
        x: jax.Array,
        reverse: bool = False,
    ) -> tuple[jax.Array, jax.Array]:
        """
        Forward or inverse transformation.

        Args:
            x: Input [batch, features]
            reverse: If True, compute inverse

        Returns:
            Tuple of (output, log_det_jacobian)
        """
        # Split input
        x1, x2 = jnp.split(x, [self.split_dim], axis=-1)

        if not reverse:
            # Forward: x2' = x2 * exp(s(x1)) + t(x1)
            transform_params = self.transform_net(x1)
            log_scale, shift = jnp.split(transform_params, 2, axis=-1)

            # Bound log scale for stability
            log_scale = jnp.tanh(log_scale)

            x2_transformed = x2 * jnp.exp(log_scale) + shift

            output = jnp.concatenate([x1, x2_transformed], axis=-1)
            log_det = jnp.sum(log_scale, axis=-1)

        else:
            # Inverse: x2 = (x2' - t(x1)) / exp(s(x1))
            transform_params = self.transform_net(x1)
            log_scale, shift = jnp.split(transform_params, 2, axis=-1)

            log_scale = jnp.tanh(log_scale)

            x2_original = (x2 - shift) * jnp.exp(-log_scale)

            output = jnp.concatenate([x1, x2_original], axis=-1)
            log_det = -jnp.sum(log_scale, axis=-1)

        return output, log_det


class ResidualCouplingFlow(nnx.Module):
    """Coupling flow with residual connections."""

    def __init__(
        self,
        features: int,
        hidden_dim: int = 256,
        num_blocks: int = 3,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.features = features
        self.split_dim = features // 2

        # Residual transformation blocks
        self.blocks = []
        for _ in range(num_blocks):
            block = nnx.Sequential(
                nnx.Linear(self.split_dim, hidden_dim, rngs=rngs),
                nnx.relu,  # nnx.Sequential accepts plain callables
                nnx.Linear(hidden_dim, hidden_dim, rngs=rngs),
                nnx.relu,
                nnx.Linear(hidden_dim, self.split_dim, rngs=rngs),
            )
            self.blocks.append(block)

        # Final projection
        self.final_proj = nnx.Linear(self.split_dim, self.split_dim * 2, rngs=rngs)

    def transform_network(self, x1: jax.Array) -> jax.Array:
        """Apply residual blocks."""
        h = x1

        for block in self.blocks:
            h = h + block(h)  # Residual connection

        return self.final_proj(h)

    def __call__(
        self,
        x: jax.Array,
        reverse: bool = False,
    ) -> tuple[jax.Array, jax.Array]:
        """Forward or inverse transformation."""

        x1, x2 = jnp.split(x, [self.split_dim], axis=-1)

        # Get transformation parameters
        params = self.transform_network(x1)
        log_scale, shift = jnp.split(params, 2, axis=-1)

        # Stabilize log scale
        log_scale = 2.0 * jnp.tanh(log_scale / 2.0)

        if not reverse:
            x2_new = x2 * jnp.exp(log_scale) + shift
            log_det = jnp.sum(log_scale, axis=-1)
        else:
            x2_new = (x2 - shift) * jnp.exp(-log_scale)
            log_det = -jnp.sum(log_scale, axis=-1)

        output = jnp.concatenate([x1, x2_new], axis=-1)
        return output, log_det


class AdvancedCouplingFlow(nnx.Module):
    """Multi-scale coupling flow with attention and residual connections."""

    def __init__(
        self,
        features: int,
        num_layers: int = 8,
        hidden_dim: int = 256,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.features = features

        # Alternating coupling layers, each followed by a fixed permutation
        self.layers = []
        permutations = []
        for i in range(num_layers):
            if i % 3 == 0:
                # Use attention every 3 layers
                layer = AttentionCouplingLayer(
                    features=features,
                    hidden_dim=hidden_dim,
                    rngs=rngs,
                )
            else:
                # Use residual coupling
                layer = ResidualCouplingFlow(
                    features=features,
                    hidden_dim=hidden_dim,
                    num_blocks=2,
                    rngs=rngs,
                )

            self.layers.append(layer)

            # Permutation after each layer (alternate between patterns)
            if i % 2 == 0:
                permutation = jnp.arange(features)[::-1]  # Reverse
            else:
                permutation = jnp.roll(jnp.arange(features), features // 2)  # Roll

            # Store as static tuples of ints rather than arrays
            permutations.append(tuple(permutation.tolist()))

        self.permutations = tuple(permutations)

    def __call__(
        self,
        x: jax.Array,
        reverse: bool = False,
    ) -> tuple[jax.Array, jax.Array]:
        """Forward or inverse pass through all layers."""

        log_det_total = jnp.zeros(x.shape[0])

        steps = list(zip(self.layers, self.permutations))

        if not reverse:
            for layer, perm in steps:
                x, log_det = layer(x)
                log_det_total += log_det
                x = x[:, jnp.asarray(perm)]  # Permutations have zero log-det
        else:
            # Undo each step in reverse order: inverse permutation first,
            # then the inverse coupling transform
            for layer, perm in reversed(steps):
                x = x[:, jnp.argsort(jnp.asarray(perm))]
                x, log_det = layer(x, reverse=True)
                log_det_total += log_det

        return x, log_det_total

    def log_prob(self, x: jax.Array) -> jax.Array:
        """Compute log probability."""

        # Transform to base space
        z, log_det = self(x, reverse=True)

        # Base distribution log prob
        base_log_prob = -0.5 * jnp.sum(z ** 2, axis=-1) - 0.5 * self.features * jnp.log(2 * jnp.pi)

        return base_log_prob + log_det

    def sample(self, num_samples: int, *, rngs: nnx.Rngs) -> jax.Array:
        """Sample from the flow."""

        # Sample from base
        z = jax.random.normal(rngs.sample(), (num_samples, self.features))

        # Transform to data space
        x, _ = self(z, reverse=False)

        return x
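Any coupling flow should satisfy two invariants that are worth testing before training: the inverse must recover the input, and the reported log-determinant must match the true Jacobian. The sketch below checks both on a tiny affine coupling written in plain JAX. The helper names `make_params` and `coupling`, the conditioner weights, and the sizes are hypothetical and only illustrate the pattern.

```python
import jax
import jax.numpy as jnp


def make_params(key, split_dim, hidden_dim=16):
    """Random weights for a tiny two-layer conditioner (illustrative only)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.5 * jax.random.normal(k1, (split_dim, hidden_dim)),
        "w2": 0.5 * jax.random.normal(k2, (hidden_dim, 2 * split_dim)),
    }


def coupling(params, x, reverse=False):
    """Affine coupling with a tanh-bounded log-scale."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    h = jnp.tanh(x1 @ params["w1"]) @ params["w2"]
    log_scale, shift = jnp.split(h, 2, axis=-1)
    log_scale = jnp.tanh(log_scale)
    if not reverse:
        x2_new = x2 * jnp.exp(log_scale) + shift
        log_det = jnp.sum(log_scale, axis=-1)
    else:
        x2_new = (x2 - shift) * jnp.exp(-log_scale)
        log_det = -jnp.sum(log_scale, axis=-1)
    return jnp.concatenate([x1, x2_new], axis=-1), log_det


features = 6
perm = jnp.roll(jnp.arange(features), 2)
inv_perm = jnp.argsort(perm)  # undoes x[:, perm]

key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
params = make_params(key_p, features // 2)
x = jax.random.normal(key_x, (4, features))

# Forward: coupling, then permutation
y, log_det_fwd = coupling(params, x)
y = y[:, perm]

# Inverse: undo the permutation first, then the coupling
x_rec, log_det_inv = coupling(params, y[:, inv_perm], reverse=True)

# Autodiff check of the log-determinant for one sample
jac = jax.jacfwd(lambda v: coupling(params, v[None])[0][0])(x[0])
_, logabsdet = jnp.linalg.slogdet(jac)

print("reconstruction error:", jnp.max(jnp.abs(x - x_rec)))
print("log-det sum (should be ~0):", jnp.max(jnp.abs(log_det_fwd + log_det_inv)))
print("analytic vs autodiff log-det:", log_det_fwd[0], logabsdet)
```

Note that the inverse permutation comes from `jnp.argsort(perm)`. Reapplying `perm` itself only works for permutations that are their own inverse, such as a full reversal.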

Conditional Flows¤

Flows can be conditioned on additional information for controlled generation.
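Conditioning leaves the change-of-variables formula intact: the condition $c$ becomes an extra input to every transformation. With $f_\theta(\cdot\,; c)$ denoting the flow from base space to data space and $p_Z$ the base density (notation assumed for this sketch):

```latex
\log p_\theta(x \mid c)
= \log p_Z\bigl(f_\theta^{-1}(x; c)\bigr)
+ \log\left|\det \frac{\partial f_\theta^{-1}(x; c)}{\partial x}\right| .
```

The flow stays invertible in $x$ for every fixed $c$, so the same condition must be supplied for both sampling and density evaluation.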

Class-Conditional Flow¤

class ConditionalCouplingLayer(nnx.Module):
    """Coupling layer with class conditioning."""

    def __init__(
        self,
        features: int,
        num_classes: int,
        hidden_dim: int = 256,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.features = features
        self.split_dim = features // 2

        # Class embedding
        self.class_embedding = nnx.Embed(
            num_embeddings=num_classes,
            features=hidden_dim,
            rngs=rngs,
        )

        # Conditioned transformation network
        self.transform_net = nnx.Sequential(
            nnx.Linear(self.split_dim + hidden_dim, hidden_dim * 2, rngs=rngs),
            nnx.relu,  # nnx.Sequential accepts plain callables
            nnx.Linear(hidden_dim * 2, hidden_dim, rngs=rngs),
            nnx.relu,
            nnx.Linear(hidden_dim, self.split_dim * 2, rngs=rngs),
        )

    def __call__(
        self,
        x: jax.Array,
        class_labels: jax.Array,
        reverse: bool = False,
    ) -> tuple[jax.Array, jax.Array]:
        """
        Conditional transformation.

        Args:
            x: Input [batch, features]
            class_labels: Class indices [batch]
            reverse: Forward or inverse

        Returns:
            Tuple of (output, log_det)
        """
        # Split
        x1, x2 = jnp.split(x, [self.split_dim], axis=-1)

        # Embed class
        class_embed = self.class_embedding(class_labels)

        # Concatenate x1 and class embedding
        x1_conditioned = jnp.concatenate([x1, class_embed], axis=-1)

        # Get transformation parameters
        params = self.transform_net(x1_conditioned)
        log_scale, shift = jnp.split(params, 2, axis=-1)

        log_scale = jnp.tanh(log_scale)

        if not reverse:
            x2_new = x2 * jnp.exp(log_scale) + shift
            log_det = jnp.sum(log_scale, axis=-1)
        else:
            x2_new = (x2 - shift) * jnp.exp(-log_scale)
            log_det = -jnp.sum(log_scale, axis=-1)

        output = jnp.concatenate([x1, x2_new], axis=-1)
        return output, log_det


class ConditionalNormalizingFlow(nnx.Module):
    """Full conditional normalizing flow model."""

    def __init__(
        self,
        features: int,
        num_classes: int,
        num_layers: int = 8,
        hidden_dim: int = 256,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.features = features
        self.num_classes = num_classes

        # Stack of conditional coupling layers
        self.layers = []
        for i in range(num_layers):
            layer = ConditionalCouplingLayer(
                features=features,
                num_classes=num_classes,
                hidden_dim=hidden_dim,
                rngs=rngs,
            )
            self.layers.append(layer)

    def __call__(
        self,
        x: jax.Array,
        class_labels: jax.Array,
        reverse: bool = False,
    ) -> tuple[jax.Array, jax.Array]:
        """Forward or inverse pass."""

        log_det_total = jnp.zeros(x.shape[0])

        layers = self.layers if not reverse else reversed(self.layers)

        for layer in layers:
            x, log_det = layer(x, class_labels, reverse=reverse)
            log_det_total += log_det

        return x, log_det_total

    def log_prob(self, x: jax.Array, class_labels: jax.Array) -> jax.Array:
        """Compute conditional log probability."""

        z, log_det = self(x, class_labels, reverse=True)

        base_log_prob = -0.5 * jnp.sum(z ** 2, axis=-1) - 0.5 * self.features * jnp.log(2 * jnp.pi)

        return base_log_prob + log_det

    def sample(
        self,
        num_samples: int,
        class_labels: jax.Array,
        *,
        rngs: nnx.Rngs,
    ) -> jax.Array:
        """Sample conditioned on classes."""

        z = jax.random.normal(rngs.sample(), (num_samples, self.features))

        x, _ = self(z, class_labels, reverse=False)

        return x


def train_conditional_flow(
    model: ConditionalNormalizingFlow,
    train_data: jnp.ndarray,
    train_labels: jnp.ndarray,
    num_epochs: int = 100,
):
    """Train conditional flow."""

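    # Newer Flax releases may instead require nnx.Optimizer(model, tx, wrt=nnx.Param)
    # and optimizer.update(model, grads); adjust to the installed version.
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))

    # Negative log likelihood as a function of the model
    def loss_fn(model, batch, labels):
        return -jnp.mean(model.log_prob(batch, labels))

    for epoch in range(num_epochs):
        for batch_idx in range(0, len(train_data), 128):
            batch = train_data[batch_idx:batch_idx + 128]
            labels = train_labels[batch_idx:batch_idx + 128]

            # Loss and gradients with respect to the model parameters
            loss, grads = nnx.value_and_grad(loss_fn)(model, batch, labels)

            # Update
            optimizer.update(grads)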

        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{num_epochs}, Loss: {loss:.4f}")

    return model


# Generate samples for specific class
def generate_class_samples(
    model: ConditionalNormalizingFlow,
    class_id: int,
    num_samples: int = 16,
    *,
    rngs: nnx.Rngs,
) -> jax.Array:
    """Generate samples for a specific class."""

    class_labels = jnp.full(num_samples, class_id)
    return model.sample(num_samples, class_labels, rngs=rngs)
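To see what the class embedding contributes, it helps to strip the layer down to plain JAX. The sketch below is a simplified, hypothetical stand-in for `ConditionalCouplingLayer` (a single linear conditioner, random weights, made-up sizes). It shows that the same latent maps to different outputs under different labels, and that inversion only recovers the latent when the matching label is supplied.

```python
import jax
import jax.numpy as jnp


def init_params(key, split_dim, num_classes, embed_dim=8):
    """Random embedding table and conditioner weights (illustrative only)."""
    k1, k2 = jax.random.split(key)
    return {
        "embed": jax.random.normal(k1, (num_classes, embed_dim)),
        "w": 0.3 * jax.random.normal(k2, (split_dim + embed_dim, 2 * split_dim)),
    }


def conditional_coupling(params, x, labels, reverse=False):
    x1, x2 = jnp.split(x, 2, axis=-1)
    class_embed = params["embed"][labels]                 # [batch, embed_dim]
    h = jnp.concatenate([x1, class_embed], axis=-1) @ params["w"]
    log_scale, shift = jnp.split(h, 2, axis=-1)
    log_scale = jnp.tanh(log_scale)                       # bounded for stability
    if not reverse:
        x2_new = x2 * jnp.exp(log_scale) + shift
        log_det = jnp.sum(log_scale, axis=-1)
    else:
        x2_new = (x2 - shift) * jnp.exp(-log_scale)
        log_det = -jnp.sum(log_scale, axis=-1)
    return jnp.concatenate([x1, x2_new], axis=-1), log_det


key_p, key_z = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key_p, split_dim=2, num_classes=3)
z = jax.random.normal(key_z, (5, 4))

labels_a = jnp.full(5, 0)
labels_b = jnp.full(5, 2)

# Same latent, different class labels
x_a, _ = conditional_coupling(params, z, labels_a)
x_b, _ = conditional_coupling(params, z, labels_b)

# Inversion with the matching label recovers z; a mismatched label does not
z_rec, _ = conditional_coupling(params, x_a, labels_a, reverse=True)
z_wrong, _ = conditional_coupling(params, x_a, labels_b, reverse=True)

print("class changes output:", not bool(jnp.allclose(x_a, x_b)))
print("matching label error:  ", jnp.max(jnp.abs(z - z_rec)))
print("mismatched label error:", jnp.max(jnp.abs(z - z_wrong)))
```

This is also why `log_prob` takes `class_labels`: evaluating a sample under the wrong label sends it to a different point in base space and therefore yields a different likelihood.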

Best Practices¤

DO

  • Use FFJORD for high-dimensional data (more efficient)
  • Add residual connections in coupling networks
  • Use attention for long-range dependencies
  • Monitor both NLL and sample quality
  • Use adaptive ODE solvers for CNF
  • Implement gradient clipping for training stability

DON'T

  • Don't use too few ODE steps (<20 for CNF) with a fixed-step solver
  • Don't forget to alternate coupling directions
  • Don't use unbounded activations in scale networks
  • Don't skip permutation layers between couplings
  • Don't train with learning rates >1e-3
  • Don't use batch norm in flow transformations

Troubleshooting¤

| Issue | Cause | Solution |
|---|---|---|
| Training instability | Unbounded scales | Use tanh or bounded activations for log_scale |
| Slow ODE integration | Too many steps | Use adaptive solvers, reduce steps |
| Poor sample quality | Insufficient coupling | Add more layers, use attention |
| NaN in training | Exploding gradients | Add gradient clipping, reduce learning rate |
| High memory usage | Full Jacobian computation | Use FFJORD with Hutchinson estimator |
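For slow ODE integration, note that `jax.experimental.ode.odeint` is already an adaptive Dormand-Prince solver, and its `rtol` and `atol` arguments control the trade-off between accuracy and cost. A minimal sketch on a toy ODE with a known solution (the dynamics and tolerance values are illustrative assumptions):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(z, t):
    # dz/dt = -z has the closed-form solution z(t) = z0 * exp(-t)
    return -z


z0 = jnp.array([1.0, 2.0])
ts = jnp.array([0.0, 1.0])
exact = z0 * jnp.exp(-1.0)

# Looser tolerances let the solver take fewer, larger internal steps
loose = odeint(decay, z0, ts, rtol=1e-2, atol=1e-2)[-1]
tight = odeint(decay, z0, ts, rtol=1e-6, atol=1e-6)[-1]

err_loose = jnp.max(jnp.abs(loose - exact))
err_tight = jnp.max(jnp.abs(tight - exact))

print("error with loose tolerances:", err_loose)
print("error with tight tolerances:", err_tight)
```

Loosening the tolerances during early training and tightening them for final evaluation is one way to reduce integration cost. How loose is acceptable depends on the model and should be checked against the resulting likelihoods.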

Summary¤

We covered four advanced normalizing flow techniques:

  1. Continuous Normalizing Flows: Flexible continuous-time transformations with Neural ODEs
  2. FFJORD: Efficient CNF with Hutchinson's trace estimator
  3. Advanced Coupling: Attention and residual connections for expressiveness
  4. Conditional Flows: Class or context-conditional generation

Key Takeaways:

  • CNF provides more flexibility than discrete flows
  • FFJORD makes CNF scalable to high dimensions
  • Attention and residual connections improve coupling flows
  • Conditional flows enable controlled generation

Next Steps¤