Device Mesh Management¤
Module: artifex.generative_models.training.distributed.mesh
Source: src/artifex/generative_models/training/distributed/mesh.py
Overview¤
The DeviceMeshManager class provides utilities for creating and managing JAX device meshes for distributed training. It supports data parallelism, model parallelism, and hybrid strategies that combine the two.
Classes¤
DeviceMeshManager¤
Manager for creating and configuring JAX device meshes.
class DeviceMeshManager:
"""Manager for creating and configuring JAX device meshes.
This class provides methods for creating device meshes with various
configurations for data parallelism, model parallelism, and hybrid
parallelism strategies.
"""
Constructor¤
def __init__(self, devices: Sequence[Any] | None = None) -> None:
"""Initialize DeviceMeshManager.
Args:
devices: Optional list of devices to use. If None, uses all
available devices from jax.devices().
"""
Methods¤
create_device_mesh¤
Create a device mesh with the specified shape.
def create_device_mesh(
self,
mesh_shape: dict[str, int] | list[tuple[str, int]],
devices: Sequence[Any] | None = None,
) -> Mesh:
"""Create a device mesh with the specified shape.
Args:
mesh_shape: Shape specification as either:
- dict mapping axis names to sizes, e.g., {"data": 2, "model": 1}
- list of (axis_name, size) tuples, e.g., [("data", 2), ("model", 1)]
devices: Optional list of devices to use.
Returns:
A JAX Mesh with the specified configuration.
Raises:
ValueError: If mesh requires more devices than available.
"""
create_data_parallel_mesh¤
Create a mesh for data parallelism.
def create_data_parallel_mesh(
self,
num_devices: int | None = None,
axis_name: str = "data",
) -> Mesh:
"""Create a mesh for data parallelism.
Args:
num_devices: Number of devices to use. If None, uses all available.
axis_name: Name of the data parallel axis.
Returns:
A JAX Mesh configured for data parallelism.
"""
create_model_parallel_mesh¤
Create a mesh for model parallelism.
def create_model_parallel_mesh(
self,
num_devices: int | None = None,
axis_name: str = "model",
) -> Mesh:
"""Create a mesh for model parallelism.
Args:
num_devices: Number of devices to use. If None, uses all available.
axis_name: Name of the model parallel axis.
Returns:
A JAX Mesh configured for model parallelism.
"""
create_hybrid_mesh¤
Create a mesh for hybrid data and model parallelism.
def create_hybrid_mesh(
self,
data_parallel_size: int = 1,
model_parallel_size: int = 1,
data_axis: str = "data",
model_axis: str = "model",
) -> Mesh:
"""Create a mesh for hybrid data and model parallelism.
Args:
data_parallel_size: Number of devices for data parallelism.
model_parallel_size: Number of devices for model parallelism.
data_axis: Name of the data parallel axis.
model_axis: Name of the model parallel axis.
Returns:
A JAX Mesh configured for hybrid parallelism.
"""
get_mesh_info¤
Get information about a device mesh.
def get_mesh_info(self, mesh: Mesh) -> dict[str, Any]:
"""Get information about a device mesh.
Args:
mesh: The mesh to get information about.
Returns:
Dictionary containing mesh information with keys:
- total_devices: Total number of devices in the mesh
- axes: Dict mapping axis names to their sizes
"""
Properties¤
- num_devices: int - Number of available devices
- devices: list[Any] - List of available devices
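Both properties can be read before any mesh is built, for example to decide on a mesh shape (a sketch; the printed values depend on the host):
manager = DeviceMeshManager()
print(manager.num_devices)  # number of devices visible to this process
print(manager.devices)      # the device objects the manager will place into meshes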
Usage Examples¤
Basic Usage¤
from artifex.generative_models.training.distributed import DeviceMeshManager
# Create manager (uses all available devices)
manager = DeviceMeshManager()
print(f"Available devices: {manager.num_devices}")
# Create a data-parallel mesh using all devices
mesh = manager.create_data_parallel_mesh()
Data Parallelism¤
# Create mesh for data parallelism with 4 devices
manager = DeviceMeshManager()
mesh = manager.create_data_parallel_mesh(num_devices=4, axis_name="batch")
# Use the mesh with JAX sharding
from jax.sharding import NamedSharding, PartitionSpec
# Shard data along batch dimension
data_sharding = NamedSharding(mesh, PartitionSpec("batch"))
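Continuing the example, a hedged sketch of placing a batch onto the mesh: jax.device_put with the sharding above splits the leading (batch) dimension across the four devices.
import jax
import jax.numpy as jnp
# Global batch of 8 examples; each of the 4 devices holds a (2, 128) shard
batch = jnp.ones((8, 128))
sharded_batch = jax.device_put(batch, data_sharding)
print(sharded_batch.sharding)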
Model Parallelism¤
# Create mesh for model parallelism
manager = DeviceMeshManager()
mesh = manager.create_model_parallel_mesh(num_devices=2, axis_name="model")
# Shard model parameters along the feature (second) dimension
from jax.sharding import NamedSharding, PartitionSpec
param_sharding = NamedSharding(mesh, PartitionSpec(None, "model"))
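As a hedged continuation, the sharding can be applied to a concrete weight matrix; with two model-parallel devices the output features are split in half.
import jax
import jax.numpy as jnp
# (512, 1024) weights; each device holds a (512, 512) shard
weights = jnp.ones((512, 1024))
sharded_weights = jax.device_put(weights, param_sharding)
print(sharded_weights.sharding)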
Hybrid Parallelism¤
# Create 2D mesh: 2 devices for data, 2 for model (total 4 devices)
manager = DeviceMeshManager()
mesh = manager.create_hybrid_mesh(
data_parallel_size=2,
model_parallel_size=2,
data_axis="data",
model_axis="model",
)
# Get mesh info
info = manager.get_mesh_info(mesh)
print(f"Total devices: {info['total_devices']}")
print(f"Axes: {info['axes']}") # {'data': 2, 'model': 2}
Using Dict or List Specification¤
manager = DeviceMeshManager()
# Using dict specification
mesh_dict = manager.create_device_mesh({"data": 2, "model": 1})
# Using list of tuples specification
mesh_list = manager.create_device_mesh([("data", 2), ("model", 1)])
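Both forms describe the same 2 x 1 layout; as a quick sanity check (a sketch reusing get_mesh_info from above), their reported axes match.
print(manager.get_mesh_info(mesh_dict)["axes"])  # {'data': 2, 'model': 1}
print(manager.get_mesh_info(mesh_list)["axes"])  # {'data': 2, 'model': 1}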
Module Statistics¤
- Classes: 1 (DeviceMeshManager)
- Methods: 5 public methods
- Properties: 2 (num_devices, devices)