Model Parallelism¤
Model parallelism techniques for training large models that don't fit on a single device. Workshop supports tensor parallelism, pipeline parallelism, and hybrid strategies using JAX's sharding capabilities.
- Tensor Parallelism: Split weight matrices across devices within a single layer
- Pipeline Parallelism: Distribute model layers across devices in a pipeline
- Hybrid Strategies: Combine multiple parallelism techniques for maximum scale
- Memory Optimization: Techniques to maximize model size on limited memory
Overview¤
Model parallelism becomes necessary when:
- Model Too Large: Parameters exceed single device memory
- Activation Memory: Forward/backward activations don't fit
- Batch Size Constraints: Can't reduce batch size further
- Extreme Scale: Training models with billions of parameters
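As a rough illustration of the first point, a back-of-envelope estimate shows when training state alone overflows a device; the 4-bytes-per-parameter and Adam-moment assumptions in this sketch are ours, not fixed requirements.

def estimated_state_gb(num_params: float, bytes_per_param: int = 4) -> float:
    """Rough training-state size: parameters + gradients + 2x Adam moments."""
    return num_params * bytes_per_param * 4 / 1e9

# A 7B-parameter model in fp32 needs roughly 112 GB of training state,
# far more than a single 40 GB or 80 GB accelerator can hold.
print(f"{estimated_state_gb(7e9):.0f} GB")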
Parallelism Strategies Comparison¤
| Strategy | When to Use | Pros | Cons |
|---|---|---|---|
| Data Parallel | Model fits on one device | Simple, efficient | Limited by model size |
| Tensor Parallel | Large layers, fast interconnect | Balances compute | Communication overhead |
| Pipeline Parallel | Many layers, slower interconnect | Minimal communication | Pipeline bubbles |
| Hybrid | Extreme scale | Maximum efficiency | Complex to implement |
Tensor Parallelism¤
Tensor parallelism splits individual weight matrices across multiple devices.
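The arithmetic behind the two splitting schemes used below can be checked directly on small arrays; this is a standalone sketch, independent of any device mesh.

import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 8))   # (batch, in_features)
w = jax.random.normal(key_w, (8, 6))   # (in_features, out_features)

# Column parallelism: split W along the output dimension,
# then concatenate the partial outputs
w1, w2 = jnp.split(w, 2, axis=1)
y_col = jnp.concatenate([x @ w1, x @ w2], axis=1)

# Row parallelism: split W along the input dimension (and X to match),
# then sum the partial outputs
wa, wb = jnp.split(w, 2, axis=0)
xa, xb = jnp.split(x, 2, axis=1)
y_row = xa @ wa + xb @ wb

assert jnp.allclose(y_col, x @ w, atol=1e-5)
assert jnp.allclose(y_row, x @ w, atol=1e-5)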
Megatron-Style Tensor Parallelism¤
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from flax import nnx
class ColumnParallelLinear(nnx.Module):
"""Linear layer with column-parallel weight matrix.
Splits weight matrix along output dimension:
Y = X @ W where W is split as [W1, W2]
Y = [X @ W1, X @ W2] (concatenate outputs)
"""
def __init__(
self,
in_features: int,
out_features: int,
*,
rngs: nnx.Rngs,
mesh: Mesh,
gather_output: bool = True,
):
super().__init__()
self.gather_output = gather_output
# Create weight sharded along output dimension
self.weight = nnx.Param(
nnx.initializers.lecun_normal()(
rngs.params(),
(in_features, out_features)
)
)
# Shard along columns (output dimension)
weight_sharding = NamedSharding(mesh, P(None, "model"))
self.weight.value = jax.device_put(self.weight.value, weight_sharding)
# Bias sharded same way
self.bias = nnx.Param(jnp.zeros(out_features))
bias_sharding = NamedSharding(mesh, P("model"))
self.bias.value = jax.device_put(self.bias.value, bias_sharding)
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with column parallelism.
Args:
x: Input activations (batch, in_features)
Returns:
Output activations (batch, out_features)
"""
        # Input is replicated across model-parallel devices; the matmul
        # produces an output sharded along the out_features dimension
        output = x @ self.weight.value + self.bias.value
        if self.gather_output:
            # Gather the sharded outputs across model-parallel devices.
            # Named-axis collectives like all_gather only resolve the "model"
            # axis when this call runs inside shard_map/pmap over that mesh axis.
            output = jax.lax.all_gather(
                output,
                axis_name="model",
                axis=output.ndim - 1,  # gather along the feature (last) dimension
                tiled=True,
            )
return output
class RowParallelLinear(nnx.Module):
"""Linear layer with row-parallel weight matrix.
Splits weight matrix along input dimension:
Y = X @ W where W is split as [W1; W2]
Y = X1 @ W1 + X2 @ W2 (sum partial results)
"""
def __init__(
self,
in_features: int,
out_features: int,
*,
rngs: nnx.Rngs,
mesh: Mesh,
input_is_parallel: bool = True,
):
super().__init__()
self.input_is_parallel = input_is_parallel
# Create weight sharded along input dimension
self.weight = nnx.Param(
nnx.initializers.lecun_normal()(
rngs.params(),
(in_features, out_features)
)
)
# Shard along rows (input dimension)
weight_sharding = NamedSharding(mesh, P("model", None))
self.weight.value = jax.device_put(self.weight.value, weight_sharding)
# Bias replicated (only added once after reduce)
self.bias = nnx.Param(jnp.zeros(out_features))
bias_sharding = NamedSharding(mesh, P())
self.bias.value = jax.device_put(self.bias.value, bias_sharding)
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with row parallelism.
Args:
x: Input activations (batch, in_features)
Either sharded or replicated depending on input_is_parallel
Returns:
Output activations (batch, out_features), replicated
"""
        if not self.input_is_parallel:
            # Keep only this device's slice of the input features
            # (there is no jax.lax.all_split; slice locally using the device's
            # position along the "model" axis instead)
            axis_size = jax.lax.psum(1, axis_name="model")
            axis_index = jax.lax.axis_index(axis_name="model")
            shard_width = x.shape[-1] // axis_size
            x = jax.lax.dynamic_slice_in_dim(
                x, axis_index * shard_width, shard_width, axis=-1
            )
        # Matrix multiplication (each device computes a partial result)
        partial_output = x @ self.weight.value
        # All-reduce to sum the partial results across model-parallel devices
        output = jax.lax.psum(partial_output, axis_name="model")
        # Add bias once, after the reduction
        output = output + self.bias.value
        return output
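The named-axis collectives used above (psum, all_gather) are only defined when the computation is explicitly mapped over the mesh, for example with shard_map. A minimal sketch of that wiring, using a plain matmul plus psum rather than the layer classes, and assuming 4 local devices:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()).reshape(4), axis_names=("model",))

def row_parallel_matmul(x_shard, w_shard):
    # Each device holds a slice of the input features and the matching rows
    # of W; psum sums the partial products across the "model" axis.
    return jax.lax.psum(x_shard @ w_shard, axis_name="model")

x = jnp.ones((2, 16))   # (batch, in_features)
w = jnp.ones((16, 8))   # (in_features, out_features)

y = shard_map(
    row_parallel_matmul,
    mesh=mesh,
    in_specs=(P(None, "model"), P("model", None)),  # shard the contracted dimension
    out_specs=P(),                                  # result is fully replicated
)(x, w)

print(y.shape)  # (2, 8)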
Transformer with Tensor Parallelism¤
class ParallelTransformerLayer(nnx.Module):
"""Transformer layer with Megatron-style tensor parallelism."""
def __init__(
self,
hidden_size: int,
num_heads: int,
ffn_hidden_size: int,
*,
rngs: nnx.Rngs,
mesh: Mesh,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
# Attention: column-parallel for Q, K, V projections
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * hidden_size, # Q, K, V concatenated
rngs=rngs,
mesh=mesh,
gather_output=False, # Keep sharded for attention
)
# Attention: row-parallel for output projection
self.output_proj = RowParallelLinear(
hidden_size,
hidden_size,
rngs=rngs,
mesh=mesh,
input_is_parallel=True, # Input from sharded attention
)
# Feed-forward: column-parallel for first layer
self.ffn_layer1 = ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
rngs=rngs,
mesh=mesh,
gather_output=False, # Keep sharded
)
# Feed-forward: row-parallel for second layer
self.ffn_layer2 = RowParallelLinear(
ffn_hidden_size,
hidden_size,
rngs=rngs,
mesh=mesh,
input_is_parallel=True,
)
# Layer norms (replicated)
self.ln1 = nnx.LayerNorm(hidden_size, rngs=rngs)
self.ln2 = nnx.LayerNorm(hidden_size, rngs=rngs)
def attention(self, x: jax.Array) -> jax.Array:
"""Multi-head attention with tensor parallelism.
Args:
x: (batch, seq_len, hidden_size)
Returns:
(batch, seq_len, hidden_size)
"""
batch_size, seq_len, hidden_size = x.shape
head_dim = hidden_size // self.num_heads
        # Q, K, V projection (column-parallel). Note: for the local split
        # below to line up, Q, K, and V must each be sharded head-wise across
        # the tensor-parallel ranks, which this sketch glosses over.
        qkv = self.qkv_proj(x)  # local shard of (batch, seq_len, 3 * hidden_size)
        # Split the local shard into Q, K, V
        q, k, v = jnp.split(qkv, 3, axis=-1)
# Reshape for multi-head attention
# Each device has subset of heads
q = q.reshape(batch_size, seq_len, -1, head_dim)
k = k.reshape(batch_size, seq_len, -1, head_dim)
v = v.reshape(batch_size, seq_len, -1, head_dim)
# Attention computation (parallelized across heads)
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
attention_weights = nnx.softmax(scores, axis=-1)
context = jnp.einsum("bhqk,bkhd->bqhd", attention_weights, v)
# Reshape back
context = context.reshape(batch_size, seq_len, -1)
# Output projection (row-parallel)
output = self.output_proj(context) # All-reduce happens here
return output
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass through transformer layer.
Args:
x: (batch, seq_len, hidden_size)
Returns:
(batch, seq_len, hidden_size)
"""
# Attention block with residual
residual = x
x = self.ln1(x)
x = self.attention(x)
x = residual + x
# Feed-forward block with residual
residual = x
x = self.ln2(x)
x = self.ffn_layer1(x)
x = nnx.gelu(x)
x = self.ffn_layer2(x)
x = residual + x
return x
# Create model with tensor parallelism (assumes 4 local devices)
import numpy as np

devices = np.array(jax.devices()).reshape(1, 4)
mesh = Mesh(devices, axis_names=("data", "model"))
with mesh:
# Create transformer layer parallelized across 4 devices
layer = ParallelTransformerLayer(
hidden_size=768,
num_heads=12,
ffn_hidden_size=3072,
rngs=nnx.Rngs(0),
mesh=mesh,
)
    # Forward pass (the named-axis collectives above assume this call is
    # mapped over the "model" mesh axis, e.g. via shard_map)
x = jnp.ones((2, 128, 768)) # (batch=2, seq_len=128, hidden=768)
output = layer(x)
print(f"Output shape: {output.shape}") # (2, 128, 768)
Sequence Parallelism¤
For long sequences, also shard along sequence dimension:
class SequenceParallelTransformerLayer(nnx.Module):
"""Transformer with both tensor and sequence parallelism."""
def __init__(
self,
hidden_size: int,
num_heads: int,
*,
rngs: nnx.Rngs,
mesh: Mesh,
):
        super().__init__()
        self.hidden_size = hidden_size
        # Same projections as before, plus a layer norm; sequence-parallel
        # behaviour comes from how activations are gathered and split in __call__
        self.qkv_proj = ColumnParallelLinear(
            hidden_size, 3 * hidden_size,
            rngs=rngs, mesh=mesh, gather_output=False
        )
        self.output_proj = RowParallelLinear(
            hidden_size, hidden_size,
            rngs=rngs, mesh=mesh, input_is_parallel=True
        )
        self.ln = nnx.LayerNorm(hidden_size, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with sequence parallelism.
Args:
x: (batch, seq_len, hidden_size)
Sharded along both seq_len and hidden_size dimensions
Returns:
(batch, seq_len, hidden_size) with same sharding
"""
        # Layer norm is computed independently on each local sequence chunk
        x_norm = self.ln(x)
        # Attention needs the full sequence: all-gather the chunks along the
        # sequence dimension (requires running inside shard_map/pmap over the
        # "sequence" mesh axis)
        x_gathered = jax.lax.all_gather(
            x_norm,
            axis_name="sequence",
            axis=1,  # gather along the sequence dimension
            tiled=True,
        )
        # Compute attention on the full sequence (an `attention` method as in
        # ParallelTransformerLayer above is assumed)
        attn_output = self.attention(x_gathered)
        # Keep only this device's sequence chunk again
        seq_size = jax.lax.psum(1, axis_name="sequence")
        seq_index = jax.lax.axis_index(axis_name="sequence")
        chunk_len = attn_output.shape[1] // seq_size
        attn_output = jax.lax.dynamic_slice_in_dim(
            attn_output, seq_index * chunk_len, chunk_len, axis=1
        )
# Residual connection
output = x + attn_output
return output
# Create a mesh with a sequence-parallel axis (assumes 8 local devices)
devices = np.array(jax.devices()).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("data", "model", "sequence"))
# Shard the input: batch over "data", sequence over "sequence", hidden over "model"
sharding = NamedSharding(mesh, P("data", "sequence", "model"))
x = jnp.ones((4, 1024, 768))  # (batch, seq_len, hidden)
x = jax.device_put(x, sharding)
Pipeline Parallelism¤
Pipeline parallelism splits model layers across devices and pipelines microbatches.
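The cost of pipelining is the "bubble" during fill and drain: with S stages and M microbatches, roughly (S - 1) / (M + S - 1) of the schedule is idle. A quick calculation using that standard approximation:

def pipeline_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of a GPipe-style schedule spent filling and draining."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)

for m in (4, 8, 32):
    print(f"4 stages, {m:>2} microbatches -> "
          f"{pipeline_bubble_fraction(4, m):.0%} bubble")
# 4 stages,  4 microbatches -> 43% bubble
# 4 stages,  8 microbatches -> 27% bubble
# 4 stages, 32 microbatches -> 9% bubble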
GPipe-Style Pipeline¤
from typing import Callable
import jax
import jax.numpy as jnp
from flax import nnx
class PipelineParallelModel(nnx.Module):
"""Model with GPipe-style pipeline parallelism."""
def __init__(
self,
layer_configs: list[dict],
num_microbatches: int = 4,
*,
rngs: nnx.Rngs,
):
super().__init__()
self.num_microbatches = num_microbatches
# Create stages (groups of layers)
self.stages = []
for config in layer_configs:
stage = self._create_stage(config, rngs)
self.stages.append(stage)
def _create_stage(self, config: dict, rngs: nnx.Rngs) -> nnx.Module:
"""Create a pipeline stage from config."""
layers = []
for layer_spec in config["layers"]:
if layer_spec["type"] == "linear":
layer = nnx.Linear(
in_features=layer_spec["in_features"],
out_features=layer_spec["out_features"],
rngs=rngs,
)
elif layer_spec["type"] == "conv":
layer = nnx.Conv(
in_features=layer_spec["in_channels"],
out_features=layer_spec["out_channels"],
kernel_size=layer_spec["kernel_size"],
rngs=rngs,
)
layers.append(layer)
# Wrap layers in a sequential module
return nnx.Sequential(*layers)
def forward_stage(self, stage_id: int, x: jax.Array) -> jax.Array:
"""Forward pass through one pipeline stage."""
return self.stages[stage_id](x)
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with pipeline parallelism."""
return self._pipeline_forward(x)
def _pipeline_forward(self, x: jax.Array) -> jax.Array:
"""Execute pipeline forward pass with microbatching."""
batch_size = x.shape[0]
microbatch_size = batch_size // self.num_microbatches
num_stages = len(self.stages)
# Split input into microbatches
microbatches = [
x[i * microbatch_size:(i + 1) * microbatch_size]
for i in range(self.num_microbatches)
]
# Pipeline execution buffer
# buffer[stage_id] holds the activation for that stage
buffer = [None] * num_stages
outputs = []
# Pipeline schedule: fill, steady-state, drain
for time_step in range(num_stages + self.num_microbatches - 1):
# Process each stage at this time step
for stage_id in range(num_stages - 1, -1, -1):
microbatch_id = time_step - stage_id
if 0 <= microbatch_id < self.num_microbatches:
# Get input for this stage
if stage_id == 0:
stage_input = microbatches[microbatch_id]
else:
stage_input = buffer[stage_id - 1]
# Compute this stage
stage_output = self.forward_stage(stage_id, stage_input)
# Store in buffer or output
if stage_id == num_stages - 1:
outputs.append(stage_output)
else:
buffer[stage_id] = stage_output
# Concatenate outputs
return jnp.concatenate(outputs, axis=0)
# Create pipeline model
layer_configs = [
# Stage 0: Input layers
{
"layers": [
{"type": "linear", "in_features": 784, "out_features": 512},
{"type": "linear", "in_features": 512, "out_features": 512},
]
},
# Stage 1: Middle layers
{
"layers": [
{"type": "linear", "in_features": 512, "out_features": 256},
{"type": "linear", "in_features": 256, "out_features": 256},
]
},
# Stage 2: Output layers
{
"layers": [
{"type": "linear", "in_features": 256, "out_features": 128},
{"type": "linear", "in_features": 128, "out_features": 10},
]
},
]
model = PipelineParallelModel(
layer_configs=layer_configs,
num_microbatches=4,
rngs=nnx.Rngs(0),
)
# Forward pass with pipeline parallelism
x = jnp.ones((32, 784)) # Batch of 32
output = model(x)
print(f"Output shape: {output.shape}") # (32, 10)
1F1B (One Forward One Backward) Schedule¤
More memory-efficient pipeline schedule:
class OneFOneBPipeline(nnx.Module):
"""Pipeline with 1F1B (One Forward One Backward) schedule."""
def __init__(
self,
stages: list[nnx.Module],
num_microbatches: int = 4,
):
super().__init__()
self.stages = stages
self.num_microbatches = num_microbatches
self.num_stages = len(stages)
def forward_backward_step(
self,
stage_id: int,
forward_input: jax.Array | None,
backward_grad: jax.Array | None,
) -> tuple:
"""Perform one forward and one backward step for a stage."""
outputs = {}
        # Forward pass if an input is available
        if forward_input is not None:
            # Compute the forward output and keep the VJP function so the
            # backward pass can be replayed later
            outputs["forward_output"], outputs["forward_vjp"] = jax.vjp(
                self.stages[stage_id], forward_input
            )
# Backward pass if gradient available
if backward_grad is not None and "forward_vjp" in outputs:
# Compute gradients
outputs["backward_grad"] = outputs["forward_vjp"](backward_grad)[0]
return outputs
def __call__(self, x: jax.Array) -> tuple[jax.Array, dict]:
"""Forward-backward pass with 1F1B schedule."""
batch_size = x.shape[0]
microbatch_size = batch_size // self.num_microbatches
# Split into microbatches
microbatches = [
x[i * microbatch_size:(i + 1) * microbatch_size]
for i in range(self.num_microbatches)
]
# Execution state
forward_cache = [[None] * self.num_microbatches
for _ in range(self.num_stages)]
backward_grads = [[None] * self.num_microbatches
for _ in range(self.num_stages)]
outputs = []
# 1F1B Schedule:
# 1. Warmup: Fill pipeline with forward passes
# 2. Steady state: Alternate forward and backward
# 3. Cooldown: Drain backward passes
        # This simple schedule assumes num_microbatches >= num_stages
        warmup_steps = self.num_stages - 1
        steady_steps = self.num_microbatches - warmup_steps
        # Warmup phase: run complete forward passes for the first
        # (num_stages - 1) microbatches so the pipeline is full before the
        # first backward pass (this sketch executes the schedule sequentially)
        for microbatch_id in range(warmup_steps):
            for stage_id in range(self.num_stages):
                if stage_id == 0:
                    stage_input = microbatches[microbatch_id]
                else:
                    stage_input = forward_cache[stage_id - 1][microbatch_id]
                result = self.forward_backward_step(stage_id, stage_input, None)
                forward_cache[stage_id][microbatch_id] = result["forward_output"]
# Steady state: 1 forward + 1 backward per step
for step in range(steady_steps):
microbatch_id = warmup_steps + step
# Forward pass for new microbatch
for stage_id in range(self.num_stages):
if stage_id == 0:
stage_input = microbatches[microbatch_id]
else:
stage_input = forward_cache[stage_id - 1][microbatch_id]
result = self.forward_backward_step(
stage_id, stage_input, None
)
forward_cache[stage_id][microbatch_id] = result["forward_output"]
            # Backward pass for the oldest microbatch still awaiting backward
            backward_microbatch_id = step
            for stage_id in range(self.num_stages - 1, -1, -1):
                if stage_id == self.num_stages - 1:
                    # Loss gradient (assume 1.0 for illustration)
                    grad = jnp.ones_like(
                        forward_cache[stage_id][backward_microbatch_id]
                    )
                else:
                    grad = backward_grads[stage_id + 1][backward_microbatch_id]
                # Recompute the stage forward from its *input* to obtain the VJP
                if stage_id == 0:
                    stage_input = microbatches[backward_microbatch_id]
                else:
                    stage_input = forward_cache[stage_id - 1][backward_microbatch_id]
                result = self.forward_backward_step(stage_id, stage_input, grad)
                if "backward_grad" in result:
                    backward_grads[stage_id][backward_microbatch_id] = result["backward_grad"]
        # Cooldown: drain the remaining backward passes
        for step in range(warmup_steps):
            backward_microbatch_id = steady_steps + step
            for stage_id in range(self.num_stages - 1, -1, -1):
                if stage_id == self.num_stages - 1:
                    grad = jnp.ones_like(
                        forward_cache[stage_id][backward_microbatch_id]
                    )
                else:
                    grad = backward_grads[stage_id + 1][backward_microbatch_id]
                if stage_id == 0:
                    stage_input = microbatches[backward_microbatch_id]
                else:
                    stage_input = forward_cache[stage_id - 1][backward_microbatch_id]
                result = self.forward_backward_step(stage_id, stage_input, grad)
                if "backward_grad" in result:
                    backward_grads[stage_id][backward_microbatch_id] = result["backward_grad"]
# Collect outputs
final_outputs = [
forward_cache[self.num_stages - 1][i]
for i in range(self.num_microbatches)
]
return jnp.concatenate(final_outputs, axis=0), backward_grads[0]
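The payoff of 1F1B over GPipe is bounded activation memory: a stage never has to hold more than num_stages microbatches of activations, rather than all of them. A small comparison under that standard assumption:

def max_inflight_microbatches(schedule: str, num_stages: int, num_microbatches: int) -> int:
    """Peak number of microbatch activations a stage keeps alive."""
    if schedule == "gpipe":
        return num_microbatches                    # all forwards finish before any backward
    if schedule == "1f1b":
        return min(num_stages, num_microbatches)   # bounded by pipeline depth
    raise ValueError(schedule)

print(max_inflight_microbatches("gpipe", num_stages=4, num_microbatches=32))  # 32
print(max_inflight_microbatches("1f1b", num_stages=4, num_microbatches=32))   # 4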
Hybrid Parallelism¤
Combine multiple parallelism strategies for maximum scale.
3D Parallelism¤
Combine data, tensor, and pipeline parallelism:
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from flax import nnx
class HybridParallelTransformer(nnx.Module):
"""Transformer with 3D parallelism (data + tensor + pipeline)."""
def __init__(
self,
num_layers: int,
hidden_size: int,
num_heads: int,
num_pipeline_stages: int,
*,
rngs: nnx.Rngs,
mesh: Mesh,
):
super().__init__()
# Divide layers into pipeline stages
layers_per_stage = num_layers // num_pipeline_stages
self.stages = []
for stage_id in range(num_pipeline_stages):
stage_layers = []
for _ in range(layers_per_stage):
# Each layer uses tensor parallelism
layer = ParallelTransformerLayer(
hidden_size=hidden_size,
num_heads=num_heads,
ffn_hidden_size=4 * hidden_size,
rngs=rngs,
mesh=mesh,
)
stage_layers.append(layer)
self.stages.append(nnx.Sequential(*stage_layers))
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with 3D parallelism.
Args:
x: (batch, seq_len, hidden_size)
Sharded along batch (data parallel),
hidden_size (tensor parallel),
and layers (pipeline parallel)
Returns:
(batch, seq_len, hidden_size) with same sharding
"""
        # Run the stages in order (a real deployment would combine this with
        # the microbatched pipeline schedule shown earlier)
for stage in self.stages:
x = stage(x)
return x
# Create a 3D parallel mesh (assumes 16 local devices)
devices = np.array(jax.devices()).reshape(2, 4, 2)  # (data, model, pipeline)
mesh = Mesh(devices, axis_names=("data", "model", "pipeline"))
# Create hybrid parallel model
with mesh:
model = HybridParallelTransformer(
num_layers=24,
hidden_size=1024,
num_heads=16,
num_pipeline_stages=2, # 12 layers per stage
rngs=nnx.Rngs(0),
mesh=mesh,
)
# Define input sharding
input_sharding = NamedSharding(mesh, P("data", None, "model"))
# Forward pass
x = jnp.ones((16, 512, 1024)) # (batch, seq_len, hidden)
x = jax.device_put(x, input_sharding)
output = model(x)
Automatic Parallelism Selection¤
Choose parallelism strategy based on model size and available devices:
def select_parallelism_strategy(
model_params: int,
available_devices: int,
device_memory_gb: float,
) -> dict:
"""Select optimal parallelism strategy.
Args:
model_params: Number of model parameters (billions)
available_devices: Number of available devices
device_memory_gb: Memory per device (GB)
Returns:
Dictionary with parallelism configuration
"""
    # Estimate memory requirements (rough approximation)
    # Parameters: 4 bytes per param (fp32) or 2 bytes (fp16)
    # Gradients: same size as the parameters
    # Optimizer states: 2x parameters (Adam)
    # Activations: depend on batch size; assume roughly 1x parameters here
    # Total: ~5x the parameter memory
    memory_per_param_bytes = 2  # fp16
    # model_params is in billions, so params * bytes * 5 is already in GB
    total_memory_gb = model_params * memory_per_param_bytes * 5
    params_per_device = model_params / available_devices
if total_memory_gb <= device_memory_gb:
# Model fits on one device: use data parallelism
return {
"strategy": "data_parallel",
"data_parallel_size": available_devices,
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"mesh_shape": (available_devices,),
"mesh_axis_names": ("data",),
}
    elif params_per_device * memory_per_param_bytes * 5 <= device_memory_gb:
        # Fits if parameters and optimizer state are sharded across the
        # data-parallel devices (FSDP/ZeRO-style data parallelism)
return {
"strategy": "data_parallel",
"data_parallel_size": available_devices,
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"mesh_shape": (available_devices,),
"mesh_axis_names": ("data",),
}
else:
# Need model parallelism
# Prefer tensor parallelism for fast interconnect
# Fall back to pipeline for slower interconnect
# Heuristic: Use tensor parallelism up to 8 devices
# Then add pipeline parallelism
if available_devices <= 8:
return {
"strategy": "tensor_parallel",
"data_parallel_size": 1,
"tensor_parallel_size": available_devices,
"pipeline_parallel_size": 1,
"mesh_shape": (1, available_devices),
"mesh_axis_names": ("data", "model"),
}
        else:
            # 3D parallelism: split the devices between data, tensor, and
            # pipeline axes with a simple heuristic
            tensor_size = 4  # a commonly effective tensor-parallel width
if available_devices % tensor_size == 0:
remaining = available_devices // tensor_size
# Split remaining between data and pipeline
data_size = max(1, remaining // 2)
pipeline_size = remaining // data_size
return {
"strategy": "hybrid_3d",
"data_parallel_size": data_size,
"tensor_parallel_size": tensor_size,
"pipeline_parallel_size": pipeline_size,
"mesh_shape": (data_size, tensor_size, pipeline_size),
"mesh_axis_names": ("data", "model", "pipeline"),
}
# Fallback: Just use available devices for tensor parallelism
return {
"strategy": "tensor_parallel",
"data_parallel_size": 1,
"tensor_parallel_size": available_devices,
"pipeline_parallel_size": 1,
"mesh_shape": (1, available_devices),
"mesh_axis_names": ("data", "model"),
}
# Example usage
strategy = select_parallelism_strategy(
model_params=175, # 175B parameters (GPT-3 scale)
available_devices=64,
device_memory_gb=40, # A100 40GB
)
print(f"Selected strategy: {strategy['strategy']}")
print(f"Data parallel: {strategy['data_parallel_size']}")
print(f"Tensor parallel: {strategy['tensor_parallel_size']}")
print(f"Pipeline parallel: {strategy['pipeline_parallel_size']}")
print(f"Mesh shape: {strategy['mesh_shape']}")
Memory Optimization¤
Techniques to maximize model size on limited memory.
Activation Checkpointing¤
Trade computation for memory by recomputing activations:
from jax.ad_checkpoint import checkpoint as jax_checkpoint
class CheckpointedTransformerLayer(nnx.Module):
"""Transformer layer with activation checkpointing."""
def __init__(
self,
hidden_size: int,
num_heads: int,
*,
rngs: nnx.Rngs,
use_remat: bool = True,
):
super().__init__()
self.use_remat = use_remat
        # Standard transformer layer components (MultiHeadAttention and
        # FeedForward are assumed to be defined elsewhere as ordinary
        # attention and MLP blocks)
        self.attention = MultiHeadAttention(hidden_size, num_heads, rngs=rngs)
        self.ffn = FeedForward(hidden_size, 4 * hidden_size, rngs=rngs)
self.ln1 = nnx.LayerNorm(hidden_size, rngs=rngs)
self.ln2 = nnx.LayerNorm(hidden_size, rngs=rngs)
def _forward_attention(self, x: jax.Array) -> jax.Array:
"""Attention block (may be checkpointed)."""
residual = x
x = self.ln1(x)
x = self.attention(x)
x = residual + x
return x
def _forward_ffn(self, x: jax.Array) -> jax.Array:
"""Feed-forward block (may be checkpointed)."""
residual = x
x = self.ln2(x)
x = self.ffn(x)
x = residual + x
return x
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with optional checkpointing."""
        if self.use_remat:
            # Checkpoint both attention and FFN: intermediate activations are
            # recomputed during the backward pass instead of being stored
            # (nnx.remat is the Flax NNX-native wrapper for the same idea)
            x = jax_checkpoint(self._forward_attention)(x)
            x = jax_checkpoint(self._forward_ffn)(x)
else:
# Standard forward pass
x = self._forward_attention(x)
x = self._forward_ffn(x)
return x
# Compare memory usage
def measure_memory(use_checkpointing: bool):
"""Measure memory usage with/without checkpointing."""
layer = CheckpointedTransformerLayer(
hidden_size=1024,
num_heads=16,
rngs=nnx.Rngs(0),
use_remat=use_checkpointing,
)
# Dummy forward-backward pass
x = jnp.ones((32, 512, 1024)) # Large batch and sequence
def loss_fn(layer, x):
output = layer(x)
return jnp.mean(output ** 2)
# Compute gradients (triggers backward pass)
loss, grads = nnx.value_and_grad(loss_fn)(layer, x)
return loss, grads
# Illustrative figures (actual numbers depend on hardware and shapes):
# - without checkpointing: ~10GB peak activation memory
# - with checkpointing: ~5GB peak memory (about a 50% reduction),
#   at the cost of roughly 30% extra compute from recomputation
Selective Checkpointing¤
Checkpoint only memory-intensive operations:
class SelectiveCheckpointedLayer(nnx.Module):
"""Layer with selective activation checkpointing."""
def __init__(
self,
hidden_size: int,
*,
rngs: nnx.Rngs,
checkpoint_attention: bool = True,
checkpoint_ffn: bool = False,
):
super().__init__()
self.checkpoint_attention = checkpoint_attention
self.checkpoint_ffn = checkpoint_ffn
self.attention = MultiHeadAttention(hidden_size, 16, rngs=rngs)
self.ffn = FeedForward(hidden_size, 4 * hidden_size, rngs=rngs)
def __call__(self, x: jax.Array) -> jax.Array:
"""Forward pass with selective checkpointing."""
# Attention: Large activations (seq_len^2), worth checkpointing
if self.checkpoint_attention:
x = jax_checkpoint(self.attention)(x)
else:
x = self.attention(x)
# FFN: Smaller activations, maybe don't checkpoint
if self.checkpoint_ffn:
x = jax_checkpoint(self.ffn)(x)
else:
x = self.ffn(x)
return x
# Rule of thumb:
# - Checkpoint attention (quadratic memory in sequence length)
# - Don't checkpoint FFN (linear memory, fast recompute)
# - Checkpoint every N layers in deep models
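jax.checkpoint also accepts a rematerialization policy, which gives finer-grained control than checkpointing whole blocks by hand. A sketch using one of the built-in policies; the block and shapes here are placeholders:

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint as jax_checkpoint

# Save the results of matmuls that have no batch dimension and recompute
# everything else during the backward pass
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable

w = jnp.ones((512, 512))

def block(x):
    # Stand-in for an MLP block: two matmuls with a nonlinearity in between
    h = jax.nn.relu(x @ w)
    return h @ w.T

checkpointed_block = jax_checkpoint(block, policy=policy)
grads = jax.grad(lambda x: checkpointed_block(x).sum())(jnp.ones((8, 512)))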
Gradient Accumulation¤
Simulate larger batches with gradient accumulation:
import optax

def train_step_with_accumulation(
    model_graphdef,      # nnx.GraphDef from nnx.split(model)
    model_state,
    optimizer,           # an optax.GradientTransformation
    optimizer_state,
    batch,
    num_accumulation_steps: int = 4,
):
    """Training step with gradient accumulation for memory efficiency.

    (jax.jit is left off for clarity; when jitting, keep the graphdef,
    optimizer, and num_accumulation_steps static or closed over.)
    """
    # Split the batch into sub-batches
    sub_batch_size = batch["data"].shape[0] // num_accumulation_steps
    # Initialize accumulated gradients
    accumulated_grads = None
    total_loss = 0.0
# Accumulate gradients over sub-batches
for i in range(num_accumulation_steps):
start_idx = i * sub_batch_size
end_idx = (i + 1) * sub_batch_size
sub_batch = {
"data": batch["data"][start_idx:end_idx]
}
        # Compute gradients for the sub-batch
        def loss_fn(state):
            model = nnx.merge(model_graphdef, state)
            output = model(sub_batch["data"])
            return output["loss"]  # assumes the model returns a dict with a "loss" entry

        loss, grads = jax.value_and_grad(loss_fn)(model_state)
# Accumulate gradients
if accumulated_grads is None:
accumulated_grads = grads
else:
accumulated_grads = jax.tree.map(
lambda acc, g: acc + g,
accumulated_grads,
grads
)
total_loss += loss
# Average accumulated gradients
accumulated_grads = jax.tree.map(
lambda g: g / num_accumulation_steps,
accumulated_grads
)
# Single optimizer update
updates, optimizer_state = optimizer.update(
accumulated_grads,
optimizer_state
)
model_state = optax.apply_updates(model_state, updates)
return model_state, optimizer_state, total_loss / num_accumulation_steps
# Activation memory: roughly 1/num_accumulation_steps of the full batch
# Gradient statistics: equivalent to the full batch (gradients are averaged)
# Throughput: roughly unchanged per sample; the optimizer is applied once
# every num_accumulation_steps sub-batches
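If the training loop already uses optax, gradient accumulation can instead be delegated to optax.MultiSteps, which accumulates gradients internally and applies the wrapped optimizer every k calls; the base optimizer and k below are illustrative:

import optax

base_optimizer = optax.adamw(learning_rate=3e-4)
# Updates are accumulated and applied once every 4 calls to .update()
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=4)

# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)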
Troubleshooting¤
Common issues and solutions in model parallelism.
Communication Overhead¤
Problem: Training slower than expected due to communication.
Solutions:
- Profile Communication:
import jax.profiler
# Profile training step
with jax.profiler.trace("/tmp/trace"):
train_step(model_state, batch)
# Look for:
# - All-reduce time (should be <20% of step time)
# - All-gather time
# - Point-to-point communication
- Reduce Tensor Parallelism:
# If communication > computation, reduce parallelism
# Bad: 8-way tensor parallel on a slow interconnect
mesh = Mesh(np.array(devices).reshape(1, 8), ("data", "model"))
# Good: 2-way tensor parallel, 4-way data parallel
mesh = Mesh(np.array(devices).reshape(4, 2), ("data", "model"))
- Use Larger Microbatches:
# Larger microbatches amortize communication overhead
# Bad: too many small microbatches
model = PipelineParallelModel(layer_configs, num_microbatches=16, rngs=rngs)  # high overhead
# Good: fewer, larger microbatches
model = PipelineParallelModel(layer_configs, num_microbatches=4, rngs=rngs)   # lower overhead
Load Imbalance¤
Problem: Some devices idle while others compute.
Solutions:
- Balance Pipeline Stages:
import time

# Profile each stage (block_until_ready accounts for JAX's async dispatch)
for stage_id, stage in enumerate(model.stages):
    start = time.time()
    jax.block_until_ready(stage(x))
    duration = time.time() - start
    print(f"Stage {stage_id}: {duration:.3f}s")
# Rebalance: move layers from slow stages to fast stages
- Adjust Parallelism Dimensions:
# If the per-device load is uneven, adjust how devices are split across axes
# Before:
mesh = Mesh(np.array(devices).reshape(2, 4), ("data", "model"))
# After: fewer tensor-parallel devices, more data-parallel replicas
mesh = Mesh(np.array(devices).reshape(4, 2), ("data", "model"))
Memory Fragmentation¤
Problem: Out of memory despite sufficient total memory.
Solutions:
- Use Gradient Checkpointing:
# Reduce peak memory with checkpointing
layer = CheckpointedTransformerLayer(
hidden_size=1024,
num_heads=16,
rngs=rngs,
use_remat=True, # Enable checkpointing
)
- Increase Pipeline Stages:
# More pipeline stages = less memory per device
# But more communication and pipeline bubbles
model = PipelineParallelModel(
layer_configs=layer_configs,
num_microbatches=8, # Increase microbatches too
rngs=rngs,
)
Best Practices¤
DO¤
- ✅ Start with data parallelism - simplest and most efficient
- ✅ Profile before optimizing - measure actual bottlenecks
- ✅ Use tensor parallelism for large layers - effective for transformers
- ✅ Use pipeline parallelism for many layers - good for deep models
- ✅ Combine strategies for extreme scale - 3D parallelism
- ✅ Use activation checkpointing - when memory-constrained
- ✅ Balance pipeline stages - equal computation per stage
- ✅ Match parallelism to interconnect - tensor needs fast links
- ✅ Test on small scale first - validate before scaling
- ✅ Monitor communication overhead - should be <20%
DON'T¤
- ❌ Don't over-parallelize - diminishing returns beyond 8-16 devices
- ❌ Don't mix strategies randomly - profile and measure
- ❌ Don't ignore load imbalance - causes bubbles and idle time
- ❌ Don't checkpoint everything - balance memory vs. compute
- ❌ Don't use pipeline for small models - overhead not worth it
- ❌ Don't use tensor parallel on slow interconnect - communication dominates
- ❌ Don't forget gradient averaging - affects convergence
- ❌ Don't assume linear scaling - measure actual speedup
- ❌ Don't ignore pipeline bubbles - can waste 20-30% of time
- ❌ Don't skip testing - parallelism bugs are subtle
Summary¤
Model parallelism enables training large models:
- Tensor Parallelism: Split weight matrices across devices
- Pipeline Parallelism: Distribute layers across devices
- Hybrid Parallelism: Combine strategies for extreme scale
- Memory Optimization: Checkpointing and gradient accumulation
Key trade-offs:
- Tensor parallel: High communication, good for large layers
- Pipeline parallel: Low communication, pipeline bubbles
- Hybrid: Scales to extreme sizes, complex to implement
Choose based on:
- Model size and architecture
- Available devices and interconnect
- Memory constraints
- Target throughput
Next Steps¤
- Checkpointing: Learn about gradient checkpointing and model checkpointing
- Custom Architectures: Build custom model architectures with parallelism
- Distributed Training: Return to distributed training overview
- Training Guide: Complete training documentation and best practices