Scaling & Distributed Training¤
Comprehensive tools for scaling generative model training across multiple devices and accelerators.
Overview¤
Artifex provides robust infrastructure for scaling model training from single-GPU experiments to multi-node distributed setups. The scaling module offers:
- Device Mesh Management: Create and optimize device meshes for different workloads
- Sharding Strategies: Data, tensor, FSDP, and pipeline parallelism
- Multi-Dimensional Parallelism: Combine strategies for optimal performance
- Configuration: Flexible configuration for complex setups
Quick Start¤
Basic Data Parallel Training¤
import jax
from artifex.generative_models.scaling import mesh_utils, sharding
# Get available devices
devices = jax.devices()
print(f"Available devices: {len(devices)}")
# Create sharding config for data parallelism
config = sharding.ShardingConfig.from_device_count(len(devices))
# Create parallelism config with mesh topology
parallel_config = sharding.ParallelismConfig.from_sharding_config(config)
# Create device mesh
mesh_manager = mesh_utils.DeviceMeshManager(
mesh_shape=parallel_config.mesh_shape,
axis_names=parallel_config.mesh_axis_names,
)
mesh = mesh_manager.create_mesh_from_config(parallel_config)
print(f"Mesh shape: {parallel_config.mesh_shape}")
print(f"Axis names: {parallel_config.mesh_axis_names}")
Tensor Parallel Setup¤
from artifex.generative_models.scaling.sharding import (
ShardingConfig,
ParallelismConfig,
TensorParallelStrategy,
)
# Configure tensor parallelism for large models
config = ShardingConfig(
data_parallel_size=2,
tensor_parallel_size=4, # 8 GPUs total
)
# Create tensor parallel strategy
tensor_strategy = TensorParallelStrategy(
axis_name="model",
mesh_axis=1,
shard_dimension="out_features",
)
# Get partition specs for attention layers
qkv_spec = tensor_strategy.get_attention_qkv_spec()
output_spec = tensor_strategy.get_attention_output_spec()
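To see how such specs are consumed, here is a rough sketch using plain JAX rather than Artifex's training loop. It assumes at least eight local devices and that the returned specs are jax.sharding.PartitionSpec objects referencing a "model" axis:
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
# Hypothetical 2 x 4 mesh matching the config above (uses the first 8 devices)
device_grid = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(device_grid, axis_names=("data", "model"))
# Stand-in QKV projection weight: (hidden, 3 * hidden)
qkv_weight = jnp.zeros((512, 1536))
@jax.jit
def constrain_qkv(w):
    # Hint the compiler to lay the weight out per the tensor-parallel spec
    return jax.lax.with_sharding_constraint(w, NamedSharding(mesh, qkv_spec))
qkv_weight = constrain_qkv(qkv_weight)
print(qkv_weight.sharding)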
Device Mesh Management¤
The DeviceMeshManager provides utilities for creating and optimizing device meshes.
Creating a Device Mesh¤
from artifex.generative_models.scaling.mesh_utils import (
DeviceMeshManager,
create_device_mesh,
)
# Simple mesh creation
mesh = create_device_mesh(
mesh_shape=(4, 2), # 4 data parallel x 2 tensor parallel
axis_names=("data", "model"),
)
# Using DeviceMeshManager for more control
manager = DeviceMeshManager(
mesh_shape=(4, 2),
axis_names=("data", "model"),
)
# Get optimal mesh shape for device count
optimal_shape = manager.get_optimal_mesh_shape(
device_count=8,
dimensions=2,
)
print(f"Optimal shape for 8 devices: {optimal_shape}")
Optimizing for Transformers¤
# Optimize mesh for transformer workloads
optimal_shape = manager.optimize_for_transformer(
device_count=8,
model_size="7B",
sequence_length=2048,
)
print(f"Optimal shape for 7B model: {optimal_shape}")
# For larger models, more tensor parallelism
large_shape = manager.optimize_for_transformer(
device_count=32,
model_size="70B",
sequence_length=4096,
)
print(f"Optimal shape for 70B model: {large_shape}")
Validation¤
# Validate mesh configuration before use
is_valid = manager.validate_mesh_config(
mesh_shape=(4, 2),
device_count=8,
)
print(f"Configuration valid: {is_valid}")
Sharding Strategies¤
Artifex provides multiple sharding strategies for different parallelism approaches.
Data Parallel Strategy¤
Shards data across devices while replicating model parameters.
from artifex.generative_models.scaling.sharding import DataParallelStrategy
strategy = DataParallelStrategy(axis_name="data", mesh_axis=0)
# Get partition spec for a batch of data
# Shape: (batch, sequence, hidden)
spec = strategy.get_partition_spec(("batch", "sequence", "hidden"))
# Result: PartitionSpec("data", None, None)
# Apply sharding to data
sharded_data = strategy.apply_sharding(batch_data, mesh)
FSDP Strategy¤
Fully Sharded Data Parallel for memory-efficient training.
from artifex.generative_models.scaling.sharding import FSDPStrategy
strategy = FSDPStrategy(
axis_name="fsdp",
mesh_axis=0,
min_weight_size=1024, # Only shard weights >= 1024 in first dim
)
# Check if a weight should be sharded
should_shard = strategy.should_shard_weight(large_weight_matrix)
# Apply FSDP sharding
sharded_weights = strategy.apply_sharding(weights, mesh)
Tensor Parallel Strategy¤
Shards model computation across devices.
from artifex.generative_models.scaling.sharding import TensorParallelStrategy
strategy = TensorParallelStrategy(
axis_name="model",
mesh_axis=1,
shard_dimension="out_features",
)
# Get specs for attention layers
qkv_spec = strategy.get_attention_qkv_spec() # Shard output
output_spec = strategy.get_attention_output_spec() # Shard input
# Get specs for linear layers
linear_spec = strategy.get_linear_weight_spec()
Pipeline Parallel Strategy¤
Distributes model layers across devices.
from artifex.generative_models.scaling.sharding import PipelineParallelStrategy
strategy = PipelineParallelStrategy(
axis_name="pipeline",
mesh_axis=2,
num_stages=4,
)
# Assign 24 transformer layers to 4 pipeline stages
layer_assignments = strategy.assign_layers_to_stages(num_layers=24)
# Result: [6, 6, 6, 6] - 6 layers per stage
# Get communication patterns
forward_pattern = strategy.get_forward_communication_pattern()
backward_pattern = strategy.get_backward_communication_pattern()
Multi-Dimensional Parallelism¤
Combine multiple strategies for optimal large-scale training.
from artifex.generative_models.scaling.sharding import (
MultiDimensionalStrategy,
DataParallelStrategy,
TensorParallelStrategy,
FSDPStrategy,
ParallelismConfig,
ShardingConfig,
)
# Create individual strategies
data_strategy = DataParallelStrategy(axis_name="data", mesh_axis=0)
tensor_strategy = TensorParallelStrategy(
axis_name="model",
mesh_axis=1,
shard_dimension="out_features",
)
fsdp_strategy = FSDPStrategy(axis_name="data", mesh_axis=0)
# Combine into multi-dimensional strategy
config = ShardingConfig(
data_parallel_size=4,
tensor_parallel_size=2,
fsdp_enabled=True,
)
parallel_config = ParallelismConfig.from_sharding_config(config)
multi_strategy = MultiDimensionalStrategy(
strategies={
"data": data_strategy,
"tensor": tensor_strategy,
"fsdp": fsdp_strategy,
},
config=parallel_config,
)
# Get combined partition spec for a tensor
combined_spec = multi_strategy.get_combined_partition_spec(
tensor_name="attention.query",
tensor_shape=("batch", "sequence", "hidden"),
)
Configuration¤
ShardingConfig¤
Defines parallelism dimensions.
from artifex.generative_models.scaling.sharding import ShardingConfig
# Manual configuration
config = ShardingConfig(
data_parallel_size=4,
tensor_parallel_size=2,
pipeline_parallel_size=1,
fsdp_enabled=True,
fsdp_min_weight_size=1024,
)
# Auto-configure from device count
auto_config = ShardingConfig.from_device_count(device_count=8)
# Get total device requirement
total_devices = config.get_total_device_count() # 4 * 2 * 1 = 8
ParallelismConfig¤
Complete parallelism configuration with mesh topology.
from artifex.generative_models.scaling.sharding import ParallelismConfig
# Create from sharding config
parallel_config = ParallelismConfig.from_sharding_config(config)
# Access mesh configuration
print(f"Mesh shape: {parallel_config.mesh_shape}")
print(f"Axis names: {parallel_config.mesh_axis_names}")
# Validate configuration
is_valid = parallel_config.is_valid()
Best Practices¤
Choosing Parallelism Strategy¤
| Model Size | Devices | Recommended Strategy |
|---|---|---|
| < 1B params | 1-8 | Data Parallel |
| 1B - 10B | 8-32 | Data + Tensor Parallel |
| 10B - 100B | 32-128 | Data + Tensor + FSDP |
| > 100B | 128+ | All strategies + Pipeline |
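As a rough illustration, the first three rows of the table might translate into configurations like the following. The device counts and split sizes are illustrative, not prescriptive:
from artifex.generative_models.scaling.sharding import ShardingConfig
# < 1B params on 8 devices: plain data parallelism
small_config = ShardingConfig(data_parallel_size=8)
# ~7B params on 16 devices: data + tensor parallelism
medium_config = ShardingConfig(
    data_parallel_size=4,
    tensor_parallel_size=4,
)
# Tens of billions of params on 64 devices: add FSDP on the data axis
large_config = ShardingConfig(
    data_parallel_size=16,
    tensor_parallel_size=4,
    fsdp_enabled=True,
)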
Memory Optimization¤
# For memory-constrained setups, enable FSDP
config = ShardingConfig(
data_parallel_size=4,
fsdp_enabled=True,
fsdp_min_weight_size=512, # Shard smaller weights
)
# For very large models, combine with tensor parallel
config = ShardingConfig(
data_parallel_size=2,
tensor_parallel_size=4,
fsdp_enabled=True,
)
Performance Tips¤
- Balance dimensions: Avoid extreme ratios in mesh shape
- Match workload: Use transformer-optimized shapes for transformers
- Validate configs: Always validate before creating meshes (see the sketch below this list)
- Monitor memory: Enable FSDP for memory-constrained scenarios
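For the validation tip in particular, here is a minimal sketch that auto-configures, validates the resulting mesh shape, and only then builds the mesh, using the APIs shown earlier on this page:
import jax
from artifex.generative_models.scaling import mesh_utils, sharding
device_count = len(jax.devices())
config = sharding.ShardingConfig.from_device_count(device_count)
parallel_config = sharding.ParallelismConfig.from_sharding_config(config)
manager = mesh_utils.DeviceMeshManager(
    mesh_shape=parallel_config.mesh_shape,
    axis_names=parallel_config.mesh_axis_names,
)
# Validate before committing to mesh creation
if manager.validate_mesh_config(
    mesh_shape=parallel_config.mesh_shape,
    device_count=device_count,
):
    mesh = manager.create_mesh_from_config(parallel_config)
else:
    raise ValueError(f"Invalid mesh shape {parallel_config.mesh_shape}")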
API Reference¤
Mesh Utilities¤
artifex.generative_models.scaling.mesh_utils¤
Device mesh management utilities for scalable distributed training.
This module provides comprehensive device mesh management including:
- Device mesh creation and optimization
- Topology optimization for different workloads
- Validation and configuration utilities
- Hardware-aware mesh shape calculation
All implementations prioritize performance and follow JAX/Flax NNX patterns.
DeviceMeshManager¤
Device mesh management for distributed training optimization.
Provides utilities for creating and optimizing device meshes based on available hardware and workload characteristics.
create_mesh¤
create_mesh_from_config¤
create_mesh_from_config(config: ParallelismConfig) -> Mesh
Create mesh from parallelism configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | ParallelismConfig | Parallelism configuration | required |
Returns:
| Type | Description |
|---|---|
| Mesh | JAX device mesh |
get_optimal_mesh_shape¤
optimize_for_transformer¤
optimize_for_transformer(
device_count: int, model_size: str, sequence_length: int
) -> tuple[int, ...]
Optimize mesh shape for transformer workloads.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| device_count | int | Number of available devices | required |
| model_size | str | Model size (e.g., '7B', '13B', '70B') | required |
| sequence_length | int | Input sequence length | required |
Returns:
| Type | Description |
|---|---|
| tuple[int, ...] | Optimized mesh shape for transformer workloads |
validate_mesh_config¤
create_device_mesh¤
get_optimal_mesh_shape¤
get_optimal_mesh_shape(
device_count: int, parallelism_config: ParallelismConfig
) -> tuple[int, ...]
Get optimal mesh shape for given configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| device_count | int | Number of available devices | required |
| parallelism_config | ParallelismConfig | Parallelism configuration | required |
Returns:
| Type | Description |
|---|---|
| tuple[int, ...] | Optimal mesh shape |
Sharding¤
artifex.generative_models.scaling.sharding¤
Sharding strategies and parallelism configuration for scalable training.
This module provides comprehensive sharding infrastructure including:
- Abstract base class for sharding strategies
- Concrete implementations for different parallelism types
- Multi-dimensional parallelism support
- Configuration management for complex sharding setups
All implementations prioritize performance and follow JAX/Flax NNX patterns.
ShardingConfig (dataclass)¤
ShardingConfig(
data_parallel_size: int = 1,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
fsdp_enabled: bool = False,
fsdp_min_weight_size: int = 1024,
)
Configuration for multi-dimensional parallelism setup.
Defines the parallelism dimensions and FSDP settings for a model.
get_total_device_count¤
get_total_device_count() -> int
Calculate total devices needed for this configuration.
from_device_count (classmethod)¤
from_device_count(device_count: int) -> ShardingConfig
Create optimal sharding config for given device count.
Uses heuristics to balance different parallelism dimensions.
ParallelismConfig (dataclass)¤
ParallelismConfig(
mesh_shape: tuple[int, ...],
mesh_axis_names: tuple[str, ...],
sharding_config: ShardingConfig,
)
Complete parallelism configuration including mesh topology.
Combines sharding configuration with device mesh setup.
from_sharding_config (classmethod)¤
from_sharding_config(
config: ShardingConfig,
) -> ParallelismConfig
Create parallelism config from sharding configuration.
ShardingStrategy¤
Bases: ABC
Abstract base class for sharding strategies.
Defines the interface that all sharding strategies must implement for consistent handling of different parallelism types.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis_name | str | Name of the mesh axis for this strategy | required |
| mesh_axis | int | Index of the mesh axis | required |
get_partition_spec (abstractmethod)¤
get_partition_spec(
tensor_shape: tuple[str, ...],
) -> PartitionSpec
Get partition specification for a tensor with given shape names.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor_shape | tuple[str, ...] | Tuple of dimension names for the tensor | required |
Returns:
| Type | Description |
|---|---|
| PartitionSpec | PartitionSpec defining how to shard the tensor |
apply_sharding (abstractmethod)¤
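Custom layouts can be added by subclassing ShardingStrategy and implementing both abstract methods above. The sketch below is hypothetical: the inherited constructor, the stored axis_name attribute, and the apply_sharding signature are assumptions based on the usage examples on this page, not taken from the source.
import jax
from jax.sharding import NamedSharding, PartitionSpec
from artifex.generative_models.scaling.sharding import ShardingStrategy

class SequenceParallelStrategy(ShardingStrategy):
    """Hypothetical strategy that shards the sequence dimension."""

    def get_partition_spec(self, tensor_shape: tuple[str, ...]) -> PartitionSpec:
        # Shard only the dimension named "sequence"; replicate all others
        return PartitionSpec(
            *(self.axis_name if dim == "sequence" else None for dim in tensor_shape)
        )

    def apply_sharding(self, tensor, mesh):
        # Assumed signature: place the tensor according to its partition spec
        spec = self.get_partition_spec(("batch", "sequence", "hidden"))
        return jax.device_put(tensor, NamedSharding(mesh, spec))

# Assumes the base class accepts axis_name and mesh_axis, as documented above
strategy = SequenceParallelStrategy(axis_name="data", mesh_axis=0)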
DataParallelStrategy¤
Bases: ShardingStrategy
Data parallel sharding strategy.
Shards the batch dimension across devices while replicating model parameters and computation.
get_partition_spec¤
get_partition_spec(
tensor_shape: tuple[str, ...],
) -> PartitionSpec
Get partition spec for data parallel sharding.
Only shards the batch dimension, leaves others replicated.
FSDPStrategy¤
Bases: ShardingStrategy
Fully Sharded Data Parallel strategy.
Shards model parameters across devices to reduce memory usage while maintaining training efficiency.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis_name | str | Name of the mesh axis | required |
| mesh_axis | int | Index of the mesh axis | required |
| min_weight_size | int | Minimum first dimension size to enable sharding | 1024 |
should_shard_weight¤
get_partition_spec¤
get_partition_spec(
tensor_shape: tuple[str, ...],
) -> PartitionSpec
Get partition spec for FSDP sharding.
Shards along the first dimension of weight tensors.
get_gradient_partition_spec¤
get_gradient_partition_spec(
tensor_shape: tuple[str, ...],
) -> PartitionSpec
Get partition spec for gradient sharding (same as weights).
TensorParallelStrategy¤
Bases: ShardingStrategy
Tensor parallel sharding strategy.
Shards model computation across devices by splitting tensors along specific dimensions (typically features).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis_name | str | Name of the mesh axis | required |
| mesh_axis | int | Index of the mesh axis | required |
| shard_dimension | str \| None | Preferred dimension to shard ('in_features' or 'out_features') | None |
get_partition_spec¤
get_partition_spec(
tensor_shape: tuple[str, ...],
) -> PartitionSpec
Get partition spec for tensor parallel sharding.
get_linear_weight_spec¤
get_linear_weight_spec() -> PartitionSpec
Get partition spec for linear layer weights.
get_attention_qkv_spec¤
get_attention_qkv_spec() -> PartitionSpec
Get partition spec for attention QKV projections.
get_attention_output_spec¤
get_attention_output_spec() -> PartitionSpec
Get partition spec for attention output projection.
PipelineParallelStrategy¤
Bases: ShardingStrategy
Pipeline parallel sharding strategy.
Distributes model layers across devices to enable pipeline parallelism for very large models that don't fit on single devices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis_name | str | Name of the mesh axis | required |
| mesh_axis | int | Index of the mesh axis | required |
| num_stages | int | Number of pipeline stages | required |
assign_layers_to_stages¤
get_partition_spec¤
get_partition_spec(
tensor_shape: tuple[str, ...],
) -> PartitionSpec
Get partition spec for pipeline parallel sharding.
Pipeline parallelism doesn't shard individual tensors, but rather assigns entire layers to different devices.
get_forward_communication_pattern¤
get_backward_communication_pattern¤
MultiDimensionalStrategy¤
MultiDimensionalStrategy(
strategies: dict[str, ShardingStrategy],
config: ParallelismConfig,
)
Multi-dimensional parallelism strategy combining multiple approaches.
Combines different sharding strategies (data, tensor, FSDP, pipeline) to achieve optimal performance for large-scale training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| strategies | dict[str, ShardingStrategy] | Dictionary mapping strategy names to strategy instances | required |
| config | ParallelismConfig | Sharding configuration for the multi-dimensional strategy | required |
get_combined_partition_spec¤
get_combined_partition_spec(
tensor_name: str, tensor_shape: tuple[str, ...]
) -> PartitionSpec
Get combined partition spec from all strategies.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor_name | str | Name/type of the tensor | required |
| tensor_shape | tuple[str, ...] | Shape dimension names of the tensor | required |
Returns:
| Type | Description |
|---|---|
| PartitionSpec | Combined PartitionSpec |
resolve_sharding_conflicts¤
resolve_sharding_conflicts(
tensor_name: str,
proposed_specs: dict[str, PartitionSpec],
) -> PartitionSpec
Resolve conflicts between multiple proposed partition specs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor_name | str | Name of the tensor | required |
| proposed_specs | dict[str, PartitionSpec] | Dictionary of strategy names to proposed specs | required |
Returns:
| Type | Description |
|---|---|
| PartitionSpec | Resolved PartitionSpec |
create_partition_spec¤
create_partition_spec(
param_shape: tuple[int, ...], param_name: str
) -> PartitionSpec
Create PartitionSpec for pipeline parallel sharding.
Related Documentation¤
- Distributed Training Guide - User guide for distributed training
- Model Parallelism - Model parallelism techniques
- Training Guide - Core training concepts
- Device Management - Device manager API