Scaling & Distributed Training¤

Comprehensive tools for scaling generative model training across multiple devices and accelerators.

Overview¤

Artifex provides robust infrastructure for scaling model training from single-GPU experiments to multi-node distributed setups. The scaling module offers:

  • Device Mesh Management: create and optimize device meshes for different workloads
  • Sharding Strategies: data, tensor, FSDP, and pipeline parallelism
  • Multi-Dimensional Parallelism: combine strategies for optimal performance
  • Configuration: flexible configuration for complex setups

Quick Start¤

Basic Data Parallel Training¤

import jax
from artifex.generative_models.scaling import mesh_utils, sharding

# Get available devices
devices = jax.devices()
print(f"Available devices: {len(devices)}")

# Create sharding config for data parallelism
config = sharding.ShardingConfig.from_device_count(len(devices))

# Create parallelism config with mesh topology
parallel_config = sharding.ParallelismConfig.from_sharding_config(config)

# Create device mesh
mesh_manager = mesh_utils.DeviceMeshManager(
    mesh_shape=parallel_config.mesh_shape,
    axis_names=parallel_config.mesh_axis_names,
)
mesh = mesh_manager.create_mesh_from_config(parallel_config)

print(f"Mesh shape: {parallel_config.mesh_shape}")
print(f"Axis names: {parallel_config.mesh_axis_names}")

Tensor Parallel Setup¤

from artifex.generative_models.scaling.sharding import (
    ShardingConfig,
    ParallelismConfig,
    TensorParallelStrategy,
)

# Configure tensor parallelism for large models
config = ShardingConfig(
    data_parallel_size=2,
    tensor_parallel_size=4,  # 8 GPUs total
)

# Create tensor parallel strategy
tensor_strategy = TensorParallelStrategy(
    axis_name="model",
    mesh_axis=1,
    shard_dimension="out_features",
)

# Get partition specs for attention layers
qkv_spec = tensor_strategy.get_attention_qkv_spec()
output_spec = tensor_strategy.get_attention_output_spec()
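
The returned specs are jax.sharding.PartitionSpec objects, so they can be attached to arrays with standard JAX APIs. A minimal sketch, assuming a mesh that actually contains a "model" axis and a 2D (in_features, out_features) weight compatible with the returned spec:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding

# Hypothetical mesh: all local devices along the "model" axis
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))

# Place a QKV projection weight according to the strategy's spec
qkv_weight = jnp.ones((512, 3 * 512))
sharded_qkv = jax.device_put(qkv_weight, NamedSharding(mesh, qkv_spec))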

Device Mesh Management¤

The DeviceMeshManager provides utilities for creating and optimizing device meshes.

Creating a Device Mesh¤

from artifex.generative_models.scaling.mesh_utils import (
    DeviceMeshManager,
    create_device_mesh,
)

# Simple mesh creation
mesh = create_device_mesh(
    mesh_shape=(4, 2),        # 4 data parallel x 2 tensor parallel
    axis_names=("data", "model"),
)

# Using DeviceMeshManager for more control
manager = DeviceMeshManager(
    mesh_shape=(4, 2),
    axis_names=("data", "model"),
)

# Get optimal mesh shape for device count
optimal_shape = manager.get_optimal_mesh_shape(
    device_count=8,
    dimensions=2,
)
print(f"Optimal shape for 8 devices: {optimal_shape}")

Optimizing for Transformers¤

# Optimize mesh for transformer workloads
optimal_shape = manager.optimize_for_transformer(
    device_count=8,
    model_size="7B",
    sequence_length=2048,
)
print(f"Optimal shape for 7B model: {optimal_shape}")

# For larger models, more tensor parallelism
large_shape = manager.optimize_for_transformer(
    device_count=32,
    model_size="70B",
    sequence_length=4096,
)
print(f"Optimal shape for 70B model: {large_shape}")

Validation¤

# Validate mesh configuration before use
is_valid = manager.validate_mesh_config(
    mesh_shape=(4, 2),
    device_count=8,
)
print(f"Configuration valid: {is_valid}")

Sharding Strategies¤

Artifex provides multiple sharding strategies for different parallelism approaches.

Data Parallel Strategy¤

Shards data across devices while replicating model parameters.

from artifex.generative_models.scaling.sharding import DataParallelStrategy

strategy = DataParallelStrategy(axis_name="data", mesh_axis=0)

# Get partition spec for a batch of data
# Shape: (batch, sequence, hidden)
spec = strategy.get_partition_spec(("batch", "sequence", "hidden"))
# Result: PartitionSpec("data", None, None)

# Apply sharding to a batch (batch_data and mesh are created as in the Quick Start above)
sharded_data = strategy.apply_sharding(batch_data, mesh)

FSDP Strategy¤

Fully Sharded Data Parallel for memory-efficient training.

from artifex.generative_models.scaling.sharding import FSDPStrategy

strategy = FSDPStrategy(
    axis_name="fsdp",
    mesh_axis=0,
    min_weight_size=1024,  # Only shard weights >= 1024 in first dim
)

# Check whether a weight should be sharded (large_weight_matrix: any JAX array)
should_shard = strategy.should_shard_weight(large_weight_matrix)

# Apply FSDP sharding to model weights using an existing mesh
sharded_weights = strategy.apply_sharding(weights, mesh)

Tensor Parallel Strategy¤

Shards model computation across devices.

from artifex.generative_models.scaling.sharding import TensorParallelStrategy

strategy = TensorParallelStrategy(
    axis_name="model",
    mesh_axis=1,
    shard_dimension="out_features",
)

# Get specs for attention layers
qkv_spec = strategy.get_attention_qkv_spec()      # Shard output
output_spec = strategy.get_attention_output_spec() # Shard input

# Get specs for linear layers
linear_spec = strategy.get_linear_weight_spec()

Pipeline Parallel Strategy¤

Distributes model layers across devices.

from artifex.generative_models.scaling.sharding import PipelineParallelStrategy

strategy = PipelineParallelStrategy(
    axis_name="pipeline",
    mesh_axis=2,
    num_stages=4,
)

# Assign 24 transformer layers to 4 pipeline stages
layer_assignments = strategy.assign_layers_to_stages(num_layers=24)
# Result: [6, 6, 6, 6] - 6 layers per stage

# Get communication patterns
forward_pattern = strategy.get_forward_communication_pattern()
backward_pattern = strategy.get_backward_communication_pattern()

Multi-Dimensional Parallelism¤

Combine multiple strategies for optimal large-scale training.

from artifex.generative_models.scaling.sharding import (
    MultiDimensionalStrategy,
    DataParallelStrategy,
    TensorParallelStrategy,
    FSDPStrategy,
    ParallelismConfig,
    ShardingConfig,
)

# Create individual strategies
data_strategy = DataParallelStrategy(axis_name="data", mesh_axis=0)
tensor_strategy = TensorParallelStrategy(
    axis_name="model",
    mesh_axis=1,
    shard_dimension="out_features",
)
fsdp_strategy = FSDPStrategy(axis_name="data", mesh_axis=0)

# Combine into multi-dimensional strategy
config = ShardingConfig(
    data_parallel_size=4,
    tensor_parallel_size=2,
    fsdp_enabled=True,
)
parallel_config = ParallelismConfig.from_sharding_config(config)

multi_strategy = MultiDimensionalStrategy(
    strategies={
        "data": data_strategy,
        "tensor": tensor_strategy,
        "fsdp": fsdp_strategy,
    },
    config=parallel_config,
)

# Get combined partition spec for a tensor
combined_spec = multi_strategy.get_combined_partition_spec(
    tensor_name="attention.query",
    tensor_shape=("batch", "sequence", "hidden"),
)

Configuration¤

ShardingConfig¤

Defines parallelism dimensions.

from artifex.generative_models.scaling.sharding import ShardingConfig

# Manual configuration
config = ShardingConfig(
    data_parallel_size=4,
    tensor_parallel_size=2,
    pipeline_parallel_size=1,
    fsdp_enabled=True,
    fsdp_min_weight_size=1024,
)

# Auto-configure from device count
auto_config = ShardingConfig.from_device_count(device_count=8)

# Get total device requirement
total_devices = config.get_total_device_count()  # 4 * 2 * 1 = 8

ParallelismConfig¤

Complete parallelism configuration with mesh topology.

from artifex.generative_models.scaling.sharding import ParallelismConfig

# Create from sharding config
parallel_config = ParallelismConfig.from_sharding_config(config)

# Access mesh configuration
print(f"Mesh shape: {parallel_config.mesh_shape}")
print(f"Axis names: {parallel_config.mesh_axis_names}")

# Validate configuration
is_valid = parallel_config.is_valid()

Best Practices¤

Choosing Parallelism Strategy¤

Model Size      Devices    Recommended Strategy
< 1B params     1-8        Data Parallel
1B - 10B        8-32       Data + Tensor Parallel
10B - 100B      32-128     Data + Tensor + FSDP
> 100B          128+       All strategies + Pipeline
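
As a concrete illustration of the middle rows, the sketch below configures a roughly 7B-parameter model on 16 devices with combined data and tensor parallelism. The exact split is workload-dependent; the snippet only uses the ShardingConfig and ParallelismConfig APIs shown above.

from artifex.generative_models.scaling.sharding import ShardingConfig, ParallelismConfig

# 4-way data parallel x 4-way tensor parallel = 16 devices
config = ShardingConfig(
    data_parallel_size=4,
    tensor_parallel_size=4,
)
parallel_config = ParallelismConfig.from_sharding_config(config)
assert config.get_total_device_count() == 16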

Memory Optimization¤

# For memory-constrained setups, enable FSDP
config = ShardingConfig(
    data_parallel_size=4,
    fsdp_enabled=True,
    fsdp_min_weight_size=512,  # Shard smaller weights
)

# For very large models, combine with tensor parallel
config = ShardingConfig(
    data_parallel_size=2,
    tensor_parallel_size=4,
    fsdp_enabled=True,
)

Performance Tips¤

  1. Balance dimensions: Avoid extreme ratios in mesh shape
  2. Match workload: Use transformer-optimized shapes for transformers
  3. Validate configs: Always validate before creating meshes (see the sketch after this list)
  4. Monitor memory: Enable FSDP for memory-constrained scenarios
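
A minimal sketch combining tips 3 and 4: validate the configuration first, then create the mesh, and enable FSDP when parameters do not fit in per-device memory. It only uses the DeviceMeshManager and ShardingConfig APIs documented below.

from artifex.generative_models.scaling.mesh_utils import DeviceMeshManager
from artifex.generative_models.scaling.sharding import ShardingConfig

manager = DeviceMeshManager(mesh_shape=(4, 2), axis_names=("data", "model"))

# Tip 3: validate before creating the mesh
if not manager.validate_mesh_config(mesh_shape=(4, 2), device_count=8):
    raise ValueError("Mesh shape (4, 2) does not match the available devices")
mesh = manager.create_mesh(mesh_shape=(4, 2), axis_names=("data", "model"))

# Tip 4: enable FSDP for memory-constrained runs (illustrative flag)
memory_constrained = True
config = ShardingConfig(
    data_parallel_size=4,
    tensor_parallel_size=2,
    fsdp_enabled=memory_constrained,
)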

API Reference¤

Mesh Utilities¤

artifex.generative_models.scaling.mesh_utils ¤

Device mesh management utilities for scalable distributed training.

This module provides comprehensive device mesh management including:

  • Device mesh creation and optimization
  • Topology optimization for different workloads
  • Validation and configuration utilities
  • Hardware-aware mesh shape calculation

All implementations prioritize performance and follow JAX/Flax NNX patterns.

DeviceMeshManager ¤

DeviceMeshManager(
    mesh_shape: tuple[int, ...], axis_names: tuple[str, ...]
)

Device mesh management for distributed training optimization.

Provides utilities for creating and optimizing device meshes based on available hardware and workload characteristics.

mesh_shape instance-attribute ¤

mesh_shape = mesh_shape

axis_names instance-attribute ¤

axis_names = axis_names

create_mesh ¤

create_mesh(
    mesh_shape: tuple[int, ...], axis_names: tuple[str, ...]
) -> Mesh

Create device mesh with specified shape and axis names.

Parameters:

    mesh_shape (tuple[int, ...]): Shape of the device mesh. Required.
    axis_names (tuple[str, ...]): Names for each mesh axis. Required.

Returns:

    Mesh: JAX device mesh.

create_mesh_from_config ¤

create_mesh_from_config(config: ParallelismConfig) -> Mesh

Create mesh from parallelism configuration.

Parameters:

    config (ParallelismConfig): Parallelism configuration. Required.

Returns:

    Mesh: JAX device mesh.

get_optimal_mesh_shape ¤

get_optimal_mesh_shape(
    device_count: int, dimensions: int = 2
) -> tuple[int, ...]

Calculate optimal mesh shape for given device count.

Parameters:

    device_count (int): Number of available devices. Required.
    dimensions (int): Number of mesh dimensions. Default: 2.

Returns:

    tuple[int, ...]: Optimal mesh shape tuple.

optimize_for_transformer ¤

optimize_for_transformer(
    device_count: int, model_size: str, sequence_length: int
) -> tuple[int, ...]

Optimize mesh shape for transformer workloads.

Parameters:

    device_count (int): Number of available devices. Required.
    model_size (str): Model size (e.g., '7B', '13B', '70B'). Required.
    sequence_length (int): Input sequence length. Required.

Returns:

    tuple[int, ...]: Optimized mesh shape for transformer workloads.

validate_mesh_config ¤

validate_mesh_config(
    mesh_shape: tuple[int, ...], device_count: int
) -> bool

Validate mesh configuration.

Parameters:

    mesh_shape (tuple[int, ...]): Proposed mesh shape. Required.
    device_count (int): Available device count. Required.

Returns:

    bool: True if the configuration is valid.

create_device_mesh ¤

create_device_mesh(
    mesh_shape: tuple[int, ...], axis_names: tuple[str, ...]
) -> Mesh

Create device mesh with specified configuration.

Parameters:

    mesh_shape (tuple[int, ...]): Shape of the device mesh. Required.
    axis_names (tuple[str, ...]): Names for each mesh axis. Required.

Returns:

    Mesh: JAX device mesh.

get_optimal_mesh_shape ¤

get_optimal_mesh_shape(
    device_count: int, parallelism_config: ParallelismConfig
) -> tuple[int, ...]

Get optimal mesh shape for given configuration.

Parameters:

    device_count (int): Number of available devices. Required.
    parallelism_config (ParallelismConfig): Parallelism configuration. Required.

Returns:

    tuple[int, ...]: Optimal mesh shape.

Sharding¤

artifex.generative_models.scaling.sharding ¤

Sharding strategies and parallelism configuration for scalable training.

This module provides comprehensive sharding infrastructure including:

  • Abstract base class for sharding strategies
  • Concrete implementations for different parallelism types
  • Multi-dimensional parallelism support
  • Configuration management for complex sharding setups

All implementations prioritize performance and follow JAX/Flax NNX patterns.

ShardingConfig dataclass ¤

ShardingConfig(
    data_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    fsdp_enabled: bool = False,
    fsdp_min_weight_size: int = 1024,
)

Configuration for multi-dimensional parallelism setup.

Defines the parallelism dimensions and FSDP settings for a model.

data_parallel_size class-attribute instance-attribute ¤

data_parallel_size: int = 1

tensor_parallel_size class-attribute instance-attribute ¤

tensor_parallel_size: int = 1

pipeline_parallel_size class-attribute instance-attribute ¤

pipeline_parallel_size: int = 1

fsdp_enabled class-attribute instance-attribute ¤

fsdp_enabled: bool = False

fsdp_min_weight_size class-attribute instance-attribute ¤

fsdp_min_weight_size: int = 1024

get_total_device_count ¤

get_total_device_count() -> int

Calculate total devices needed for this configuration.

from_device_count classmethod ¤

from_device_count(device_count: int) -> ShardingConfig

Create optimal sharding config for given device count.

Uses heuristics to balance different parallelism dimensions.
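
A brief usage sketch; the exact split between dimensions is chosen by the library's heuristics, so the printed sizes below are not guaranteed:

from artifex.generative_models.scaling.sharding import ShardingConfig

config = ShardingConfig.from_device_count(8)
print(config.data_parallel_size, config.tensor_parallel_size, config.pipeline_parallel_size)
print(config.get_total_device_count())  # expected to equal the requested device count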

ParallelismConfig dataclass ¤

ParallelismConfig(
    mesh_shape: tuple[int, ...],
    mesh_axis_names: tuple[str, ...],
    sharding_config: ShardingConfig,
)

Complete parallelism configuration including mesh topology.

Combines sharding configuration with device mesh setup.

mesh_shape instance-attribute ¤

mesh_shape: tuple[int, ...]

mesh_axis_names instance-attribute ¤

mesh_axis_names: tuple[str, ...]

sharding_config instance-attribute ¤

sharding_config: ShardingConfig

is_valid ¤

is_valid() -> bool

Validate that mesh shape matches sharding configuration.

from_sharding_config classmethod ¤

from_sharding_config(
    config: ShardingConfig,
) -> ParallelismConfig

Create parallelism config from sharding configuration.

ShardingStrategy ¤

ShardingStrategy(axis_name: str, mesh_axis: int)

Bases: ABC

Abstract base class for sharding strategies.

Defines the interface that all sharding strategies must implement for consistent handling of different parallelism types.

Parameters:

    axis_name (str): Name of the mesh axis for this strategy. Required.
    mesh_axis (int): Index of the mesh axis. Required.

axis_name instance-attribute ¤

axis_name = axis_name

mesh_axis instance-attribute ¤

mesh_axis = mesh_axis

get_partition_spec abstractmethod ¤

get_partition_spec(
    tensor_shape: tuple[str, ...],
) -> PartitionSpec

Get partition specification for a tensor with given shape names.

Parameters:

    tensor_shape (tuple[str, ...]): Tuple of dimension names for the tensor. Required.

Returns:

    PartitionSpec: PartitionSpec defining how to shard the tensor.

apply_sharding abstractmethod ¤

apply_sharding(array: Array, mesh: Mesh) -> Array

Apply sharding to an array using the given mesh.

Parameters:

    array (Array): JAX array to shard. Required.
    mesh (Mesh): Device mesh for sharding. Required.

Returns:

    Array: Sharded array.

get_sharding_constraints ¤

get_sharding_constraints() -> dict[str, Any]

Get sharding constraints for this strategy.

Returns:

    dict[str, Any]: Dictionary of sharding constraints.
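
The interface above is small enough to sketch a custom strategy. The following is a hypothetical example (not part of the library) showing how a subclass could implement the two abstract methods with standard JAX sharding primitives; it assumes the mesh passed to apply_sharding has an axis named after axis_name and that tensors follow a (batch, sequence, ...) layout:

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from artifex.generative_models.scaling.sharding import ShardingStrategy

class SequenceParallelSketch(ShardingStrategy):
    """Hypothetical strategy that shards the 'sequence' dimension."""

    def get_partition_spec(self, tensor_shape: tuple[str, ...]) -> PartitionSpec:
        # Shard only the dimension named "sequence"; replicate everything else
        return PartitionSpec(
            *(self.axis_name if dim == "sequence" else None for dim in tensor_shape)
        )

    def apply_sharding(self, array: jax.Array, mesh: Mesh) -> jax.Array:
        # Assume layout (batch, sequence, ...) and shard the second dimension
        names = tuple("sequence" if i == 1 else f"dim{i}" for i in range(array.ndim))
        spec = self.get_partition_spec(names)
        return jax.device_put(array, NamedSharding(mesh, spec))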

DataParallelStrategy ¤

DataParallelStrategy(axis_name: str, mesh_axis: int)

Bases: ShardingStrategy

Data parallel sharding strategy.

Shards the batch dimension across devices while replicating model parameters and computation.

get_partition_spec ¤

get_partition_spec(
    tensor_shape: tuple[str, ...],
) -> PartitionSpec

Get partition spec for data parallel sharding.

Only shards the batch dimension, leaves others replicated.

apply_sharding ¤

apply_sharding(array: Array, mesh: Mesh) -> Array

Apply data parallel sharding to array.

FSDPStrategy ¤

FSDPStrategy(
    axis_name: str,
    mesh_axis: int,
    min_weight_size: int = 1024,
)

Bases: ShardingStrategy

Fully Sharded Data Parallel strategy.

Shards model parameters across devices to reduce memory usage while maintaining training efficiency.

Parameters:

    axis_name (str): Name of the mesh axis. Required.
    mesh_axis (int): Index of the mesh axis. Required.
    min_weight_size (int): Minimum first dimension size to enable sharding. Default: 1024.

min_weight_size instance-attribute ¤

min_weight_size = min_weight_size

should_shard_weight ¤

should_shard_weight(weight: Array) -> bool

Determine if a weight should be sharded based on its size.

Parameters:

    weight (Array): Weight array to check. Required.

Returns:

    bool: True if the weight should be sharded, False otherwise.
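
A small sketch based on the documented min_weight_size semantics (the threshold applies to the weight's first dimension):

import jax.numpy as jnp
from artifex.generative_models.scaling.sharding import FSDPStrategy

strategy = FSDPStrategy(axis_name="fsdp", mesh_axis=0, min_weight_size=1024)
print(strategy.should_shard_weight(jnp.zeros((4096, 512))))  # expected True: first dim >= 1024
print(strategy.should_shard_weight(jnp.zeros((256, 512))))   # expected False: first dim < 1024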

get_partition_spec ¤

get_partition_spec(
    tensor_shape: tuple[str, ...],
) -> PartitionSpec

Get partition spec for FSDP sharding.

Shards along the first dimension of weight tensors.

get_gradient_partition_spec ¤

get_gradient_partition_spec(
    tensor_shape: tuple[str, ...],
) -> PartitionSpec

Get partition spec for gradient sharding (same as weights).

apply_sharding ¤

apply_sharding(array: Array, mesh: Mesh) -> Array

Apply FSDP sharding to array.

TensorParallelStrategy ¤

TensorParallelStrategy(
    axis_name: str,
    mesh_axis: int,
    shard_dimension: str | None = None,
)

Bases: ShardingStrategy

Tensor parallel sharding strategy.

Shards model computation across devices by splitting tensors along specific dimensions (typically features).

Parameters:

    axis_name (str): Name of the mesh axis. Required.
    mesh_axis (int): Index of the mesh axis. Required.
    shard_dimension (str | None): Preferred dimension to shard ('in_features' or 'out_features'). Default: None.

shard_dimension instance-attribute ¤

shard_dimension = shard_dimension

get_partition_spec ¤

get_partition_spec(
    tensor_shape: tuple[str, ...],
) -> PartitionSpec

Get partition spec for tensor parallel sharding.

get_linear_weight_spec ¤

get_linear_weight_spec() -> PartitionSpec

Get partition spec for linear layer weights.

get_attention_qkv_spec ¤

get_attention_qkv_spec() -> PartitionSpec

Get partition spec for attention QKV projections.

get_attention_output_spec ¤

get_attention_output_spec() -> PartitionSpec

Get partition spec for attention output projection.

apply_sharding ¤

apply_sharding(array: Array, mesh: Mesh) -> Array

Apply tensor parallel sharding to array.

PipelineParallelStrategy ¤

PipelineParallelStrategy(
    axis_name: str, mesh_axis: int, num_stages: int
)

Bases: ShardingStrategy

Pipeline parallel sharding strategy.

Distributes model layers across devices to enable pipeline parallelism for very large models that don't fit on single devices.

Parameters:

    axis_name (str): Name of the mesh axis. Required.
    mesh_axis (int): Index of the mesh axis. Required.
    num_stages (int): Number of pipeline stages. Required.

num_stages instance-attribute ¤

num_stages = num_stages

assign_layers_to_stages ¤

assign_layers_to_stages(num_layers: int) -> list[int]

Assign layers to pipeline stages.

Parameters:

    num_layers (int): Total number of layers in the model. Required.

Returns:

    list[int]: List of layer counts per stage.

get_partition_spec ¤

get_partition_spec(
    tensor_shape: tuple[str, ...],
) -> PartitionSpec

Get partition spec for pipeline parallel sharding.

Pipeline parallelism doesn't shard individual tensors, but rather assigns entire layers to different devices.

get_forward_communication_pattern ¤

get_forward_communication_pattern() -> list[
    tuple[int, int]
]

Get communication pattern for forward pass.

Returns:

    list[tuple[int, int]]: List of (source_stage, dest_stage) pairs.

get_backward_communication_pattern ¤

get_backward_communication_pattern() -> list[
    tuple[int, int]
]

Get communication pattern for backward pass.

Returns:

    list[tuple[int, int]]: List of (source_stage, dest_stage) pairs.

apply_sharding ¤

apply_sharding(array: Array, mesh: Mesh) -> Array

Apply pipeline parallel sharding to array.

Pipeline parallelism handles layer assignment rather than tensor sharding.

MultiDimensionalStrategy ¤

MultiDimensionalStrategy(
    strategies: dict[str, ShardingStrategy],
    config: ParallelismConfig,
)

Multi-dimensional parallelism strategy combining multiple approaches.

Combines different sharding strategies (data, tensor, FSDP, pipeline) to achieve optimal performance for large-scale training.

Parameters:

    strategies (dict[str, ShardingStrategy]): Dictionary mapping strategy names to strategy instances. Required.
    config (ParallelismConfig): Sharding configuration for the multi-dimensional strategy. Required.

strategies instance-attribute ¤

strategies = strategies

config instance-attribute ¤

config = config

get_combined_partition_spec ¤

get_combined_partition_spec(
    tensor_name: str, tensor_shape: tuple[str, ...]
) -> PartitionSpec

Get combined partition spec from all strategies.

Parameters:

    tensor_name (str): Name/type of the tensor. Required.
    tensor_shape (tuple[str, ...]): Shape dimension names of the tensor. Required.

Returns:

    PartitionSpec: Combined PartitionSpec.

resolve_sharding_conflicts ¤

resolve_sharding_conflicts(
    tensor_name: str,
    proposed_specs: dict[str, PartitionSpec],
) -> PartitionSpec

Resolve conflicts between multiple proposed partition specs.

Parameters:

    tensor_name (str): Name of the tensor. Required.
    proposed_specs (dict[str, PartitionSpec]): Dictionary of strategy names to proposed specs. Required.

Returns:

    PartitionSpec: Resolved PartitionSpec.
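
A hypothetical call, reusing the multi_strategy instance from the Multi-Dimensional Parallelism example above; how conflicts are resolved is defined by the library, not by this sketch:

from jax.sharding import PartitionSpec

resolved = multi_strategy.resolve_sharding_conflicts(
    tensor_name="mlp.kernel",
    proposed_specs={
        "fsdp": PartitionSpec("data", None),
        "tensor": PartitionSpec(None, "model"),
    },
)
print(resolved)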

create_partition_spec ¤

create_partition_spec(
    param_shape: tuple[int, ...], param_name: str
) -> PartitionSpec

Create PartitionSpec for pipeline parallel sharding.