Custom Architectures¤

Build custom model architectures using Flax NNX in Workshop. This guide covers advanced architectural patterns, custom layers, and integration with Workshop's training and evaluation systems.

  • Custom Layers: Create custom neural network layers with Flax NNX
  • Custom Models: Build complete custom generative models
  • Architecture Patterns: Common architectural patterns and best practices
  • Integration: Integrate custom models with Workshop's systems

Overview¤

Workshop provides flexibility to create custom architectures while maintaining compatibility with the training, evaluation, and deployment infrastructure.

Why Custom Architectures?¤

Build custom architectures when:

  • Research: Implementing novel architectural ideas
  • Domain-Specific: Specialized requirements (proteins, molecules, etc.)
  • Optimization: Custom operations for performance
  • Experimentation: Rapid prototyping of new ideas

Custom Layers¤

Create custom neural network layers using Flax NNX.

Basic Custom Layer¤

import jax
import jax.numpy as jnp
from flax import nnx

class CustomLinear(nnx.Module):
    """Custom linear layer with additional features."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        weight_init: callable = nnx.initializers.lecun_normal(),
        bias_init: callable = nnx.initializers.zeros_init(),
        rngs: nnx.Rngs,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias

        # Initialize weight
        self.weight = nnx.Param(
            weight_init(rngs.params(), (in_features, out_features), dtype)
        )

        # Initialize bias if needed
        if use_bias:
            self.bias = nnx.Param(
                bias_init(rngs.params(), (out_features,), dtype)
            )
        else:
            self.bias = None

    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass.

        Args:
            x: Input tensor (..., in_features)

        Returns:
            Output tensor (..., out_features)
        """
        # Matrix multiplication
        output = x @ self.weight.value

        # Add bias
        if self.use_bias:
            output = output + self.bias.value

        return output


# Usage
layer = CustomLinear(
    in_features=784,
    out_features=256,
    rngs=nnx.Rngs(0)
)

x = jnp.ones((32, 784))
output = layer(x)
print(f"Output shape: {output.shape}")  # (32, 256)

Advanced Custom Layer with Regularization¤

class RegularizedLinear(nnx.Module):
    """Linear layer with built-in regularization."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        dropout_rate: float = 0.0,
        weight_decay: float = 0.0,
        spectral_norm: bool = False,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.weight_decay = weight_decay
        self.spectral_norm = spectral_norm

        # Weight initialization
        self.weight = nnx.Param(
            nnx.initializers.lecun_normal()(
                rngs.params(),
                (in_features, out_features)
            )
        )

        self.bias = nnx.Param(jnp.zeros(out_features))

        # Dropout
        if dropout_rate > 0:
            self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        else:
            self.dropout = None

    def _apply_spectral_norm(self, weight: jax.Array) -> jax.Array:
        """Apply spectral normalization to the weight matrix."""
        # Normalize by the largest singular value (no need for the full SVD factors)
        sigma = jnp.linalg.svd(weight, compute_uv=False)[0]
        return weight / sigma

    def __call__(
        self,
        x: jax.Array,
        *,
        deterministic: bool = False,
    ) -> jax.Array:
        """Forward pass with regularization.

        Args:
            x: Input tensor
            deterministic: If True, disable dropout

        Returns:
            Output tensor
        """
        # Get weight
        weight = self.weight.value

        # Apply spectral normalization
        if self.spectral_norm:
            weight = self._apply_spectral_norm(weight)

        # Linear transformation
        output = x @ weight + self.bias.value

        # Apply dropout
        if self.dropout is not None and not deterministic:
            output = self.dropout(output)

        return output

    def get_regularization_loss(self) -> jax.Array:
        """Compute the L2 (weight decay) regularization loss for this layer."""
        if self.weight_decay > 0:
            return self.weight_decay * jnp.sum(self.weight.value ** 2)
        return jnp.array(0.0)


# Usage
layer = RegularizedLinear(
    in_features=784,
    out_features=256,
    dropout_rate=0.1,
    weight_decay=1e-4,
    spectral_norm=True,
    rngs=nnx.Rngs(0)
)

# Forward pass
x = jnp.ones((32, 784))
output = layer(x, deterministic=False)

# Get regularization loss
reg_loss = layer.get_regularization_loss()

Attention Layer¤

class MultiHeadAttention(nnx.Module):
    """Multi-head attention layer."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        *,
        dropout_rate: float = 0.0,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q, K, V projections
        self.q_proj = nnx.Linear(hidden_size, hidden_size, rngs=rngs)
        self.k_proj = nnx.Linear(hidden_size, hidden_size, rngs=rngs)
        self.v_proj = nnx.Linear(hidden_size, hidden_size, rngs=rngs)

        # Output projection
        self.out_proj = nnx.Linear(hidden_size, hidden_size, rngs=rngs)

        # Dropout
        if dropout_rate > 0:
            self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        else:
            self.dropout = None

    def __call__(
        self,
        x: jax.Array,
        mask: jax.Array | None = None,
        *,
        deterministic: bool = False,
    ) -> jax.Array:
        """Multi-head attention forward pass.

        Args:
            x: Input tensor (batch, seq_len, hidden_size)
            mask: Optional attention mask (batch, seq_len, seq_len)
            deterministic: If True, disable dropout

        Returns:
            Output tensor (batch, seq_len, hidden_size)
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        q = self.q_proj(x)  # (batch, seq_len, hidden_size)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape for multi-head attention
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose: (batch, num_heads, seq_len, head_dim)
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))

        # Scaled dot-product attention
        scale = jnp.sqrt(self.head_dim)
        scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / scale

        # Apply mask if provided
        if mask is not None:
            # Expand mask for heads: (batch, 1, seq_len, seq_len)
            mask = mask[:, None, :, :]
            scores = jnp.where(mask, scores, -1e9)

        # Softmax
        attention_weights = nnx.softmax(scores, axis=-1)

        # Apply dropout
        if self.dropout is not None and not deterministic:
            attention_weights = self.dropout(attention_weights)

        # Attend to values
        context = jnp.einsum("bhqk,bhkd->bhqd", attention_weights, v)

        # Reshape back: (batch, seq_len, hidden_size)
        context = jnp.transpose(context, (0, 2, 1, 3))
        context = context.reshape(batch_size, seq_len, self.hidden_size)

        # Output projection
        output = self.out_proj(context)

        return output


# Usage
attention = MultiHeadAttention(
    hidden_size=512,
    num_heads=8,
    dropout_rate=0.1,
    rngs=nnx.Rngs(0)
)

x = jnp.ones((2, 10, 512))  # (batch=2, seq_len=10, hidden=512)
output = attention(x, deterministic=False)
print(f"Output shape: {output.shape}")  # (2, 10, 512)

Residual Block¤

class ResidualBlock(nnx.Module):
    """Residual block with normalization and activation."""

    def __init__(
        self,
        channels: int,
        *,
        stride: int = 1,
        downsample: bool = False,
        activation: callable = nnx.relu,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.activation = activation

        # Main path
        self.conv1 = nnx.Conv(
            in_features=channels,
            out_features=channels,
            kernel_size=(3, 3),
            strides=(stride, stride),
            padding="SAME",
            rngs=rngs,
        )
        self.bn1 = nnx.BatchNorm(num_features=channels, rngs=rngs)

        self.conv2 = nnx.Conv(
            in_features=channels,
            out_features=channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="SAME",
            rngs=rngs,
        )
        self.bn2 = nnx.BatchNorm(num_features=channels, rngs=rngs)

        # Shortcut path
        if downsample:
            self.shortcut = nnx.Conv(
                in_features=channels,
                out_features=channels,
                kernel_size=(1, 1),
                strides=(stride, stride),
                padding="VALID",
                rngs=rngs,
            )
            self.shortcut_bn = nnx.BatchNorm(num_features=channels, rngs=rngs)
        else:
            self.shortcut = None

    def __call__(
        self,
        x: jax.Array,
        *,
        use_running_average: bool = False,
    ) -> jax.Array:
        """Forward pass through residual block.

        Args:
            x: Input tensor (batch, height, width, channels)
            use_running_average: Use running stats for batch norm

        Returns:
            Output tensor (batch, height, width, channels)
        """
        # Save input for residual connection
        identity = x

        # Main path
        out = self.conv1(x)
        out = self.bn1(out, use_running_average=use_running_average)
        out = self.activation(out)

        out = self.conv2(out)
        out = self.bn2(out, use_running_average=use_running_average)

        # Shortcut path
        if self.shortcut is not None:
            identity = self.shortcut(identity)
            identity = self.shortcut_bn(
                identity,
                use_running_average=use_running_average
            )

        # Residual connection
        out = out + identity
        out = self.activation(out)

        return out


# Usage
block = ResidualBlock(
    channels=64,
    stride=2,
    downsample=True,
    rngs=nnx.Rngs(0)
)

x = jnp.ones((2, 32, 32, 64))
output = block(x, use_running_average=False)
print(f"Output shape: {output.shape}")  # (2, 16, 16, 64)

Custom Models¤

Build complete custom generative models.

Custom VAE Architecture¤

from flax import nnx
import jax
import jax.numpy as jnp

class CustomVAE(nnx.Module):
    """Custom VAE with flexible architecture."""

    def __init__(
        self,
        input_shape: tuple,
        latent_dim: int,
        encoder_layers: list[int],
        decoder_layers: list[int],
        *,
        activation: callable = nnx.relu,
        use_batch_norm: bool = True,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.input_shape = input_shape
        self.latent_dim = latent_dim
        self.activation = activation

        # Flatten input size
        self.input_dim = int(jnp.prod(jnp.array(input_shape)))

        # Encoder
        self.encoder = self._build_encoder(
            encoder_layers,
            use_batch_norm,
            rngs
        )

        # Latent projections
        self.mean_layer = nnx.Linear(
            in_features=encoder_layers[-1],
            out_features=latent_dim,
            rngs=rngs,
        )
        self.logvar_layer = nnx.Linear(
            in_features=encoder_layers[-1],
            out_features=latent_dim,
            rngs=rngs,
        )

        # Decoder
        self.decoder = self._build_decoder(
            decoder_layers,
            use_batch_norm,
            rngs
        )

        # Output layer
        self.output_layer = nnx.Linear(
            in_features=decoder_layers[-1],
            out_features=self.input_dim,
            rngs=rngs,
        )

    def _build_encoder(
        self,
        layers: list[int],
        use_batch_norm: bool,
        rngs: nnx.Rngs,
    ) -> list:
        """Build encoder layers."""
        encoder_layers = []

        # Input layer
        encoder_layers.append(nnx.Linear(self.input_dim, layers[0], rngs=rngs))
        if use_batch_norm:
            encoder_layers.append(nnx.BatchNorm(layers[0], rngs=rngs))

        # Hidden layers
        for i in range(len(layers) - 1):
            encoder_layers.append(
                nnx.Linear(layers[i], layers[i + 1], rngs=rngs)
            )
            if use_batch_norm:
                encoder_layers.append(nnx.BatchNorm(layers[i + 1], rngs=rngs))

        return encoder_layers

    def _build_decoder(
        self,
        layers: list[int],
        use_batch_norm: bool,
        rngs: nnx.Rngs,
    ) -> list:
        """Build decoder layers."""
        decoder_layers = []

        # Input layer (from latent)
        decoder_layers.append(nnx.Linear(self.latent_dim, layers[0], rngs=rngs))
        if use_batch_norm:
            decoder_layers.append(nnx.BatchNorm(layers[0], rngs=rngs))

        # Hidden layers
        for i in range(len(layers) - 1):
            decoder_layers.append(
                nnx.Linear(layers[i], layers[i + 1], rngs=rngs)
            )
            if use_batch_norm:
                decoder_layers.append(nnx.BatchNorm(layers[i + 1], rngs=rngs))

        return decoder_layers

    def encode(
        self,
        x: jax.Array,
        *,
        use_running_average: bool = False,
    ) -> dict[str, jax.Array]:
        """Encode input to latent distribution.

        Args:
            x: Input tensor (batch, *input_shape)
            use_running_average: Use running stats for batch norm

        Returns:
            Dictionary with 'mean' and 'logvar'
        """
        # Flatten input
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        # Forward through encoder (activation after each linear layer)
        for layer in self.encoder:
            if isinstance(layer, nnx.BatchNorm):
                x = layer(x, use_running_average=use_running_average)
            else:
                x = layer(x)
                x = self.activation(x)

        # Latent parameters
        mean = self.mean_layer(x)
        logvar = self.logvar_layer(x)

        return {"mean": mean, "logvar": logvar}

    def reparameterize(
        self,
        mean: jax.Array,
        logvar: jax.Array,
        *,
        rngs: nnx.Rngs | None = None,
    ) -> jax.Array:
        """Reparameterization trick.

        Args:
            mean: Mean of latent distribution
            logvar: Log variance of latent distribution
            rngs: RNG for sampling

        Returns:
            Sampled latent vector
        """
        if rngs is not None and "sample" in rngs:
            key = rngs.sample()
        else:
            key = jax.random.key(0)

        std = jnp.exp(0.5 * logvar)
        eps = jax.random.normal(key, mean.shape)
        z = mean + eps * std

        return z

    def decode(
        self,
        z: jax.Array,
        *,
        use_running_average: bool = False,
    ) -> jax.Array:
        """Decode latent vector to reconstruction.

        Args:
            z: Latent vector (batch, latent_dim)
            use_running_average: Use running stats for batch norm

        Returns:
            Reconstruction (batch, *input_shape)
        """
        x = z

        # Forward through decoder (activation after each linear layer)
        for layer in self.decoder:
            if isinstance(layer, nnx.BatchNorm):
                x = layer(x, use_running_average=use_running_average)
            else:
                x = layer(x)
                x = self.activation(x)

        # Output layer
        x = self.output_layer(x)
        x = nnx.sigmoid(x)  # Normalize to [0, 1]

        # Reshape to input shape
        batch_size = z.shape[0]
        x = x.reshape(batch_size, *self.input_shape)

        return x

    def __call__(
        self,
        x: jax.Array,
        *,
        rngs: nnx.Rngs | None = None,
        use_running_average: bool = False,
    ) -> dict[str, jax.Array]:
        """Full forward pass (encode-reparameterize-decode).

        Args:
            x: Input tensor (batch, *input_shape)
            rngs: RNG for sampling
            use_running_average: Use running stats for batch norm

        Returns:
            Dictionary with 'reconstruction', 'mean', 'logvar', 'latent'
        """
        # Encode
        latent_params = self.encode(x, use_running_average=use_running_average)

        # Reparameterize
        z = self.reparameterize(
            latent_params["mean"],
            latent_params["logvar"],
            rngs=rngs
        )

        # Decode
        reconstruction = self.decode(z, use_running_average=use_running_average)

        return {
            "reconstruction": reconstruction,
            "mean": latent_params["mean"],
            "logvar": latent_params["logvar"],
            "latent": z,
        }


# Create custom VAE
model = CustomVAE(
    input_shape=(28, 28, 1),  # MNIST-like
    latent_dim=20,
    encoder_layers=[512, 256, 128],
    decoder_layers=[128, 256, 512],
    use_batch_norm=True,
    rngs=nnx.Rngs(0),
)

# Forward pass
x = jnp.ones((32, 28, 28, 1))
output = model(x, rngs=nnx.Rngs(1))

print(f"Reconstruction shape: {output['reconstruction'].shape}")  # (32, 28, 28, 1)
print(f"Latent shape: {output['latent'].shape}")  # (32, 20)

Custom GAN with Advanced Techniques¤

class CustomGenerator(nnx.Module):
    """Custom generator with self-attention."""

    def __init__(
        self,
        latent_dim: int,
        output_shape: tuple,
        hidden_dims: list[int],
        *,
        use_attention: bool = True,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.output_shape = output_shape
        self.use_attention = use_attention

        # Build generator network
        self.layers = []
        self.attention = None

        # Initial projection
        self.layers.append(nnx.Linear(latent_dim, hidden_dims[0], rngs=rngs))
        self.layers.append(nnx.BatchNorm(hidden_dims[0], rngs=rngs))

        # Hidden layers
        for i in range(len(hidden_dims) - 1):
            self.layers.append(
                nnx.Linear(hidden_dims[i], hidden_dims[i + 1], rngs=rngs)
            )
            self.layers.append(nnx.BatchNorm(hidden_dims[i + 1], rngs=rngs))

            # Add self-attention after the middle block
            if use_attention and i == len(hidden_dims) // 2:
                self.attention = MultiHeadAttention(
                    hidden_size=hidden_dims[i + 1],
                    num_heads=4,
                    rngs=rngs,
                )
                # Remember which layer's output the attention should see
                self.attention_index = len(self.layers) - 1

        # Output layer
        output_dim = int(jnp.prod(jnp.array(output_shape)))
        self.output_layer = nnx.Linear(hidden_dims[-1], output_dim, rngs=rngs)

    def __call__(
        self,
        z: jax.Array,
        *,
        use_running_average: bool = False,
    ) -> jax.Array:
        """Generate samples from noise.

        Args:
            z: Noise vector (batch, latent_dim)
            use_running_average: Use running stats for batch norm

        Returns:
            Generated samples (batch, *output_shape)
        """
        x = z

        # Forward through layers
        for i, layer in enumerate(self.layers):
            if isinstance(layer, nnx.BatchNorm):
                x = layer(x, use_running_average=use_running_average)
            else:
                x = layer(x)
                x = nnx.relu(x)

            # Apply self-attention after the layer it was built for
            if self.attention is not None and i == self.attention_index:
                # Reshape for attention (add a length-1 sequence dimension)
                batch_size = x.shape[0]
                x = x.reshape(batch_size, 1, -1)
                x = self.attention(x)
                x = x.reshape(batch_size, -1)

        # Output layer with tanh activation
        x = self.output_layer(x)
        x = nnx.tanh(x)

        # Reshape to output shape
        batch_size = z.shape[0]
        x = x.reshape(batch_size, *self.output_shape)

        return x


class CustomDiscriminator(nnx.Module):
    """Custom discriminator with spectral normalization."""

    def __init__(
        self,
        input_shape: tuple,
        hidden_dims: list[int],
        *,
        spectral_norm: bool = True,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.input_shape = input_shape
        self.spectral_norm = spectral_norm

        input_dim = int(jnp.prod(jnp.array(input_shape)))

        # Build discriminator network
        self.layers = []

        # Input layer
        if spectral_norm:
            self.layers.append(
                RegularizedLinear(
                    input_dim,
                    hidden_dims[0],
                    spectral_norm=True,
                    rngs=rngs,
                )
            )
        else:
            self.layers.append(nnx.Linear(input_dim, hidden_dims[0], rngs=rngs))

        # Hidden layers
        for i in range(len(hidden_dims) - 1):
            if spectral_norm:
                self.layers.append(
                    RegularizedLinear(
                        hidden_dims[i],
                        hidden_dims[i + 1],
                        spectral_norm=True,
                        rngs=rngs,
                    )
                )
            else:
                self.layers.append(
                    nnx.Linear(hidden_dims[i], hidden_dims[i + 1], rngs=rngs)
                )

        # Output layer (binary classification)
        if spectral_norm:
            self.output_layer = RegularizedLinear(
                hidden_dims[-1], 1,
                spectral_norm=True,
                rngs=rngs,
            )
        else:
            self.output_layer = nnx.Linear(hidden_dims[-1], 1, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Discriminate real vs fake samples.

        Args:
            x: Input samples (batch, *input_shape)

        Returns:
            Logits (batch, 1)
        """
        # Flatten input
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        # Forward through layers
        for layer in self.layers:
            x = layer(x)
            x = nnx.leaky_relu(x, negative_slope=0.2)

        # Output layer (no activation, return logits)
        logits = self.output_layer(x)

        return logits


# Create custom GAN
generator = CustomGenerator(
    latent_dim=100,
    output_shape=(28, 28, 1),
    hidden_dims=[256, 512, 1024],
    use_attention=True,
    rngs=nnx.Rngs(0),
)

discriminator = CustomDiscriminator(
    input_shape=(28, 28, 1),
    hidden_dims=[1024, 512, 256],
    spectral_norm=True,
    rngs=nnx.Rngs(1),
)

# Generate samples
z = jax.random.normal(jax.random.key(0), (32, 100))
fake_samples = generator(z)
print(f"Generated shape: {fake_samples.shape}")  # (32, 28, 28, 1)

# Discriminate
real_samples = jnp.ones((32, 28, 28, 1))
real_logits = discriminator(real_samples)
fake_logits = discriminator(fake_samples)
print(f"Real logits shape: {real_logits.shape}")  # (32, 1)

Architecture Patterns¤

Common architectural patterns and best practices.

Residual Connections¤

class ResidualNetwork(nnx.Module):
    """Network with residual connections."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_blocks: int,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Input projection
        self.input_proj = nnx.Linear(input_dim, hidden_dim, rngs=rngs)

        # Residual blocks
        self.blocks = [
            self._create_residual_block(hidden_dim, rngs)
            for _ in range(num_blocks)
        ]

        # Output projection
        self.output_proj = nnx.Linear(hidden_dim, input_dim, rngs=rngs)

    def _create_residual_block(
        self,
        hidden_dim: int,
        rngs: nnx.Rngs,
    ) -> list:
        """Create a single residual block."""
        return [
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs),
            nnx.LayerNorm(hidden_dim, rngs=rngs),
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs),
            nnx.LayerNorm(hidden_dim, rngs=rngs),
        ]

    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass with residual connections."""
        # Input projection
        x = self.input_proj(x)

        # Residual blocks
        for block_layers in self.blocks:
            residual = x

            # Forward through block layers
            x = block_layers[0](x)  # Linear
            x = nnx.relu(x)
            x = block_layers[1](x)  # LayerNorm

            x = block_layers[2](x)  # Linear
            x = block_layers[3](x)  # LayerNorm

            # Residual connection
            x = x + residual
            x = nnx.relu(x)

        # Output projection
        x = self.output_proj(x)

        return x
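
A quick usage check in the style of the earlier examples (the variable names are illustrative); the output has the same dimension as the input because of the final projection:

# Usage
net = ResidualNetwork(
    input_dim=128,
    hidden_dim=256,
    num_blocks=3,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((16, 128))
output = net(x)
print(f"Output shape: {output.shape}")  # (16, 128)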

Skip Connections (U-Net Style)¤

class UNetEncoder(nnx.Module):
    """U-Net style encoder with skip connections."""

    def __init__(
        self,
        channels_list: list[int],
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.down_blocks = []

        for i in range(len(channels_list) - 1):
            in_channels = channels_list[i]
            out_channels = channels_list[i + 1]

            # Downsampling block
            block = [
                nnx.Conv(
                    in_features=in_channels,
                    out_features=out_channels,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    padding="SAME",
                    rngs=rngs,
                ),
                nnx.BatchNorm(out_channels, rngs=rngs),
            ]
            self.down_blocks.append(block)

    def __call__(
        self,
        x: jax.Array,
        *,
        use_running_average: bool = False,
    ) -> tuple[jax.Array, list[jax.Array]]:
        """Forward pass with skip connections.

        Args:
            x: Input tensor
            use_running_average: Use running stats for batch norm

        Returns:
            (encoded, skip_connections)
        """
        skip_connections = []

        for block in self.down_blocks:
            # Save for skip connection
            skip_connections.append(x)

            # Downsample
            x = block[0](x)  # Conv
            x = block[1](x, use_running_average=use_running_average)  # BatchNorm
            x = nnx.relu(x)

        return x, skip_connections


class UNetDecoder(nnx.Module):
    """U-Net style decoder with skip connections."""

    def __init__(
        self,
        channels_list: list[int],
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.up_blocks = []

        for i in range(len(channels_list) - 1):
            # After the first block, the input also carries the concatenated skip
            # connection; with a mirrored encoder this doubles the channel count
            in_channels = channels_list[i] if i == 0 else channels_list[i] * 2
            out_channels = channels_list[i + 1]

            # Upsampling block
            block = [
                nnx.ConvTranspose(
                    in_features=in_channels,
                    out_features=out_channels,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    padding="SAME",
                    rngs=rngs,
                ),
                nnx.BatchNorm(out_channels, rngs=rngs),
            ]
            self.up_blocks.append(block)

    def __call__(
        self,
        x: jax.Array,
        skip_connections: list[jax.Array],
        *,
        use_running_average: bool = False,
    ) -> jax.Array:
        """Forward pass with skip connections.

        Args:
            x: Encoded tensor
            skip_connections: Skip connections from encoder
            use_running_average: Use running stats for batch norm

        Returns:
            Decoded tensor
        """
        for i, block in enumerate(self.up_blocks):
            # Upsample
            x = block[0](x)  # ConvTranspose
            x = block[1](x, use_running_average=use_running_average)  # BatchNorm
            x = nnx.relu(x)

            # Add skip connection
            if i < len(skip_connections):
                skip = skip_connections[-(i + 1)]
                x = jnp.concatenate([x, skip], axis=-1)

        return x
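
A round-trip sketch with a mirrored, illustrative channel list, assuming the skip-aware channel handling noted in the decoder above; note that the decoder output keeps the extra channels added by the final skip concatenation:

# Usage
encoder = UNetEncoder(channels_list=[1, 32, 64], rngs=nnx.Rngs(0))
decoder = UNetDecoder(channels_list=[64, 32, 1], rngs=nnx.Rngs(1))

x = jnp.ones((2, 32, 32, 1))
encoded, skips = encoder(x)
print(f"Encoded shape: {encoded.shape}")  # (2, 8, 8, 64)

decoded = decoder(encoded, skips)
print(f"Decoded shape: {decoded.shape}")  # (2, 32, 32, 2) - last skip adds the input channels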

Dense Connections (DenseNet Style)¤

class DenseBlock(nnx.Module):
    """Dense block with concatenated connections."""

    def __init__(
        self,
        in_channels: int,
        growth_rate: int,
        num_layers: int,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.growth_rate = growth_rate

        self.layers = []
        for i in range(num_layers):
            layer_in_channels = in_channels + i * growth_rate

            layer = [
                nnx.BatchNorm(layer_in_channels, rngs=rngs),
                nnx.Conv(
                    in_features=layer_in_channels,
                    out_features=growth_rate,
                    kernel_size=(3, 3),
                    padding="SAME",
                    rngs=rngs,
                ),
            ]
            self.layers.append(layer)

    def __call__(
        self,
        x: jax.Array,
        *,
        use_running_average: bool = False,
    ) -> jax.Array:
        """Forward pass with dense connections.

        Args:
            x: Input tensor
            use_running_average: Use running stats for batch norm

        Returns:
            Output tensor with all features concatenated
        """
        features = [x]

        for layer in self.layers:
            # BatchNorm + ReLU + Conv
            out = layer[0](x, use_running_average=use_running_average)
            out = nnx.relu(out)
            out = layer[1](out)

            # Concatenate with previous features
            features.append(out)
            x = jnp.concatenate(features, axis=-1)

        return x
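
A usage sketch showing how the channel count grows by growth_rate per layer (the values here are illustrative):

# Usage
block = DenseBlock(
    in_channels=64,
    growth_rate=32,
    num_layers=4,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 8, 8, 64))
output = block(x, use_running_average=False)
print(f"Output shape: {output.shape}")  # (2, 8, 8, 192) = 64 + 4 * 32 channels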

Workshop Integration¤

Integrate custom models with Workshop's systems.

Implementing the GenerativeModel Protocol¤

from workshop.generative_models.core.protocols import GenerativeModel
from flax import nnx
import jax
import jax.numpy as jnp

class MyCustomGenerativeModel(nnx.Module):
    """Custom model implementing Workshop's GenerativeModel protocol."""

    def __init__(
        self,
        config: dict,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Extract config
        self.latent_dim = config.get("latent_dim", 20)
        self.input_shape = config.get("input_shape", (28, 28, 1))

        # Build the architecture (your custom design)
        self.vae = CustomVAE(
            input_shape=self.input_shape,
            latent_dim=self.latent_dim,
            encoder_layers=[512, 256],
            decoder_layers=[256, 512],
            rngs=rngs,
        )

    def __call__(
        self,
        x: jax.Array,
        *,
        rngs: nnx.Rngs | None = None,
        **kwargs
    ) -> dict[str, jax.Array]:
        """Forward pass returning Workshop-compatible output.

        Must return dictionary with at least:
        - 'loss': scalar loss for training
        - Model-specific outputs (e.g., 'reconstruction', 'samples')
        """
        # Forward pass through the underlying VAE
        output = self.vae(x, rngs=rngs, **kwargs)

        # Compute loss
        reconstruction_loss = jnp.mean((x - output["reconstruction"]) ** 2)
        kl_loss = -0.5 * jnp.mean(
            1 + output["logvar"] - output["mean"] ** 2 - jnp.exp(output["logvar"])
        )
        total_loss = reconstruction_loss + kl_loss

        # Return Workshop-compatible output
        return {
            "loss": total_loss,
            "reconstruction": output["reconstruction"],
            "latent": output["latent"],
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

    def sample(
        self,
        num_samples: int,
        *,
        rngs: nnx.Rngs | None = None,
        **kwargs
    ) -> jax.Array:
        """Generate samples (required for GenerativeModel protocol)."""
        if rngs is not None and "sample" in rngs:
            key = rngs.sample()
        else:
            key = jax.random.key(0)

        # Sample from prior
        z = jax.random.normal(key, (num_samples, self.latent_dim))

        # Decode samples from the prior
        samples = self.vae.decode(z, **kwargs)

        return samples


# Usage with Workshop
from workshop.generative_models.training.trainer import Trainer

model = MyCustomGenerativeModel(
    config={"latent_dim": 20, "input_shape": (28, 28, 1)},
    rngs=nnx.Rngs(0),
)

# Integrate with Workshop's trainer
trainer = Trainer(
    model=model,
    # ... other config
)

# Training and evaluation work automatically
# trainer.train(train_dataset, val_dataset)

Custom Loss Functions¤

def custom_vae_loss(
    model: nnx.Module,
    batch: dict[str, jax.Array],
    *,
    beta: float = 1.0,
    **kwargs
) -> tuple[jax.Array, dict]:
    """Custom VAE loss with β-weighting.

    Args:
        model: The VAE model
        batch: Batch dictionary with 'data' key
        beta: Weight for KL divergence term

    Returns:
        (total_loss, metrics_dict)
    """
    # Forward pass
    output = model(batch["data"], **kwargs)

    # Reconstruction loss
    recon_loss = jnp.mean(
        (batch["data"] - output["reconstruction"]) ** 2
    )

    # KL divergence
    kl_loss = -0.5 * jnp.mean(
        1 + output["logvar"]
        - output["mean"] ** 2
        - jnp.exp(output["logvar"])
    )

    # Total loss with β-weighting
    total_loss = recon_loss + beta * kl_loss

    # Metrics for logging
    metrics = {
        "loss": total_loss,
        "reconstruction_loss": recon_loss,
        "kl_loss": kl_loss,
        "beta": beta,
    }

    return total_loss, metrics


# Use the custom loss in a training step. Assumes `model_graphdef` (from
# nnx.split(model)) and an optax `optimizer` are already defined in the
# enclosing scope.
import optax

@jax.jit
def train_step(model_state, batch, optimizer_state):
    """Training step with the custom loss."""
    model = nnx.merge(model_graphdef, model_state)

    # Compute loss and gradients
    (loss, metrics), grads = nnx.value_and_grad(
        lambda m: custom_vae_loss(m, batch, beta=2.0),
        has_aux=True,
    )(model)

    # Update parameters
    updates, optimizer_state = optimizer.update(grads, optimizer_state)
    model_state = optax.apply_updates(model_state, updates)

    return model_state, optimizer_state, metrics

Best Practices¤

DO¤

  • Follow Flax NNX patterns - use nnx.Module, nnx.Param
  • Call super().__init__() - always in module constructors
  • Use proper RNG handling - check if a key exists and provide a fallback (see the sketch after this list)
  • Implement protocols - match Workshop's interface expectations
  • Return dictionaries - structured outputs for logging
  • Use type hints - document input/output shapes
  • Test components separately - unit test layers before integration
  • Profile performance - measure speed and memory
  • Document architecture - explain design choices
  • Version your models - track architectural changes
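
A minimal sketch combining two of these habits, the RNG-fallback pattern and dictionary outputs, using only patterns already shown above (TinyModel and its fields are illustrative, not part of Workshop):

class TinyModel(nnx.Module):
    """Toy module showing RNG fallback and structured outputs."""

    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        super().__init__()
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(
        self,
        x: jax.Array,
        *,
        rngs: nnx.Rngs | None = None,
    ) -> dict[str, jax.Array]:
        # RNG handling: use the provided stream if present, else a fixed fallback
        if rngs is not None and "sample" in rngs:
            key = rngs.sample()
        else:
            key = jax.random.key(0)

        noise = jax.random.normal(key, x.shape)
        output = self.proj(x + noise)

        # Return a dictionary so downstream logging stays structured
        return {"output": output, "loss": jnp.mean(output ** 2)}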

DON'T¤

  • Don't use Flax Linen - only use Flax NNX
  • Don't forget super().__init__() - causes initialization issues
  • Don't use numpy inside modules - use jax.numpy instead
  • Don't mix PyTorch/TensorFlow - stay in JAX ecosystem
  • Don't hardcode shapes - make them configurable
  • Don't skip validation - verify outputs are correct
  • Don't ignore memory - monitor GPU usage
  • Don't over-engineer - start simple, add complexity as needed
  • Don't skip documentation - explain architecture decisions
  • Don't forget batch dimensions - always handle batched inputs

Summary¤

Custom architectures in Workshop:

  1. Custom Layers: Build reusable components with Flax NNX
  2. Custom Models: Create complete generative models
  3. Architecture Patterns: Residual, skip, dense connections
  4. Workshop Integration: Implement protocols for seamless integration

Key principles:

  • Use Flax NNX exclusively
  • Follow Workshop's protocol interfaces
  • Return structured outputs (dictionaries)
  • Document architecture choices
  • Test and profile before deploying

Next Steps¤