Skip to content

Device Placement¤

Module: artifex.generative_models.training.distributed.device_placement

Source: src/artifex/generative_models/training/distributed/device_placement.py

Overview¤

The DevicePlacement module provides utilities for explicit device placement of JAX arrays and PyTrees, enabling efficient data distribution across accelerators. It includes hardware-aware batch size recommendations based on JAX performance guidelines.

Enums¤

HardwareType¤

Enumeration of supported hardware types for batch size recommendations.

class HardwareType(Enum):
    """Enumeration of supported hardware types."""
    TPU_V5E = "tpu_v5e"
    TPU_V5P = "tpu_v5p"
    TPU_V4 = "tpu_v4"
    H100 = "h100"
    A100 = "a100"
    V100 = "v100"
    CPU = "cpu"
    UNKNOWN = "unknown"

Classes¤

BatchSizeRecommendation¤

Hardware-specific batch size recommendations dataclass.

@dataclass(frozen=True)
class BatchSizeRecommendation:
    """Hardware-specific batch size recommendations.

    Attributes:
        min_batch_size: Minimum batch size for reasonable efficiency.
        optimal_batch_size: Optimal batch size for peak throughput.
        critical_batch_size: Critical batch size for reaching roofline (per JAX guide).
        max_memory_batch_size: Maximum batch size before OOM (estimate).
        notes: Additional notes about the recommendation.
    """
    min_batch_size: int
    optimal_batch_size: int
    critical_batch_size: int
    max_memory_batch_size: int | None = None
    notes: str = ""

Hardware-Specific Values¤

Hardware Min Batch Optimal Critical Notes
TPU v5e 64 256 240 Critical batch size for reaching roofline
TPU v5p 128 512 480 Higher throughput, needs larger batches
TPU v4 64 256 192 Similar to v5e, slightly lower critical
H100 64 320 298 Critical batch size for roofline
A100 32 256 240 For 80GB variant
V100 16 128 96 Memory-limited on 16GB variant
CPU 1 32 16 Memory-bandwidth bound

DevicePlacement¤

Utility class for explicit device placement of JAX arrays.

class DevicePlacement:
    """Utility class for explicit device placement of JAX arrays.

    This class provides methods for placing arrays on specific devices,
    distributing batches across devices using sharding, and providing
    hardware-aware batch size recommendations.
    """

Constructor¤

def __init__(self, default_device: Any | None = None) -> None:
    """Initialize DevicePlacement.

    Args:
        default_device: Default device to use when none is specified.
            If None, uses jax.devices()[0].
    """

Methods¤

place_on_device¤

Place data on a specific device.

def place_on_device(
    self,
    data: Any,
    device: Any | None = None,
) -> Any:
    """Place data on a specific device.

    Args:
        data: PyTree of JAX arrays to place on device.
        device: Target device. If None, uses the default device.

    Returns:
        PyTree with arrays placed on the specified device.
    """
distribute_batch¤

Distribute data across devices using sharding.

def distribute_batch(
    self,
    data: Any,
    sharding: Sharding,
) -> Any:
    """Distribute data across devices using the specified sharding.

    Args:
        data: PyTree of JAX arrays to distribute.
        sharding: JAX Sharding specification.

    Returns:
        PyTree with arrays distributed according to the sharding.
    """
replicate_across_devices¤

Replicate data across all devices.

def replicate_across_devices(
    self,
    data: Any,
    devices: list[Any] | None = None,
) -> Any:
    """Replicate data across all specified devices.

    Args:
        data: PyTree of JAX arrays to replicate.
        devices: List of devices to replicate to. If None, uses all devices.

    Returns:
        PyTree with arrays replicated across devices.
    """
shard_batch_dim¤

Shard data along the batch dimension.

def shard_batch_dim(
    self,
    data: Any,
    mesh: Mesh,
    batch_axis: int = 0,
    mesh_axis: str = "data",
) -> Any:
    """Shard data along the batch dimension.

    This is the most common sharding pattern for data-parallel training,
    where each device processes a slice of the batch.

    Args:
        data: PyTree of JAX arrays to shard.
        mesh: Device mesh to shard across.
        batch_axis: The axis index representing the batch dimension.
        mesh_axis: The mesh axis name to shard along.

    Returns:
        PyTree with arrays sharded along the batch dimension.
    """
prefetch_to_device¤

Create a prefetching wrapper for async data placement.

def prefetch_to_device(
    self,
    data_iterator: Iterator[Any],
    device: Any | None = None,
    buffer_size: int = 2,
) -> Iterator[Any]:
    """Create a prefetching wrapper that places data on device asynchronously.

    This enables overlapping data transfer with computation for improved
    throughput in training loops.

    Args:
        data_iterator: Iterator yielding PyTrees of data.
        device: Target device for prefetching.
        buffer_size: Number of batches to prefetch.

    Returns:
        Iterator that yields device-placed data.
    """
get_batch_size_recommendation¤

Get hardware-specific batch size recommendations.

def get_batch_size_recommendation(
    self,
    hardware_type: HardwareType | None = None,
) -> BatchSizeRecommendation:
    """Get batch size recommendation for the current hardware.

    Args:
        hardware_type: Override hardware type. If None, uses detected type.

    Returns:
        BatchSizeRecommendation with hardware-specific values.
    """
validate_batch_size¤

Validate batch size against hardware recommendations.

def validate_batch_size(
    self,
    batch_size: int,
    warn_suboptimal: bool = True,
) -> tuple[bool, str]:
    """Validate batch size against hardware recommendations.

    Args:
        batch_size: The batch size to validate.
        warn_suboptimal: Whether to warn for suboptimal (but valid) sizes.

    Returns:
        Tuple of (is_valid, message).
    """
get_device_info¤

Get information about available devices.

def get_device_info(self) -> dict[str, Any]:
    """Get information about available devices.

    Returns:
        Dictionary containing device information including:
        - num_devices: Number of available devices
        - hardware_type: Detected hardware type
        - platforms: List of unique platforms
        - device_kinds: List of device kinds
        - devices: Detailed list of device info
    """

Properties¤

  • hardware_type: HardwareType - The detected hardware type
  • num_devices: int - Number of available devices

Convenience Functions¤

place_on_device¤

def place_on_device(data: Any, device: Any | None = None) -> Any:
    """Convenience function for placing data on a device.

    Args:
        data: PyTree of JAX arrays.
        device: Target device. If None, uses first available device.

    Returns:
        PyTree with arrays on the specified device.
    """

distribute_batch¤

def distribute_batch(data: Any, sharding: Sharding) -> Any:
    """Convenience function for distributing data across devices.

    Args:
        data: PyTree of JAX arrays.
        sharding: JAX Sharding specification.

    Returns:
        PyTree with arrays distributed according to sharding.
    """

get_batch_size_recommendation¤

def get_batch_size_recommendation(
    hardware_type: HardwareType | None = None,
) -> BatchSizeRecommendation:
    """Get batch size recommendation for current or specified hardware.

    Args:
        hardware_type: Hardware type to get recommendation for.

    Returns:
        BatchSizeRecommendation with hardware-specific values.
    """

Usage Examples¤

Basic Device Placement¤

from artifex.generative_models.training.distributed import (
    DevicePlacement,
    place_on_device,
)
import jax.numpy as jnp

# Create placement utility
placement = DevicePlacement()
print(f"Detected hardware: {placement.hardware_type}")
print(f"Available devices: {placement.num_devices}")

# Place data on default device
data = jnp.ones((32, 784))
placed_data = placement.place_on_device(data)

# Or use convenience function
placed_data = place_on_device(data)

Batch Size Validation¤

from artifex.generative_models.training.distributed import (
    DevicePlacement,
    HardwareType,
    get_batch_size_recommendation,
)

# Get recommendation for current hardware
placement = DevicePlacement()
rec = placement.get_batch_size_recommendation()
print(f"Optimal batch size: {rec.optimal_batch_size}")
print(f"Critical batch size: {rec.critical_batch_size}")

# Get recommendation for specific hardware
h100_rec = get_batch_size_recommendation(HardwareType.H100)
print(f"H100 critical batch: {h100_rec.critical_batch_size}")  # 298

# Validate batch size
is_valid, message = placement.validate_batch_size(256)
print(f"Valid: {is_valid}, Message: {message}")

Distributing Batches with Sharding¤

from artifex.generative_models.training.distributed import (
    DevicePlacement,
    distribute_batch,
)
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import jax
import numpy as np

# Create device mesh
devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("data",))

# Create sharding for batch dimension
data_sharding = NamedSharding(mesh, PartitionSpec("data", None))

# Distribute data
placement = DevicePlacement()
batch = {"images": jnp.ones((8, 28, 28, 3)), "labels": jnp.zeros((8,))}
distributed = placement.distribute_batch(batch, data_sharding)

# Or use convenience function
distributed = distribute_batch(batch, data_sharding)

Sharding Along Batch Dimension¤

from artifex.generative_models.training.distributed import DevicePlacement
from jax.sharding import Mesh
import jax
import numpy as np

placement = DevicePlacement()

# Create mesh
devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("data",))

# Shard batch along first dimension
batch = {
    "images": jnp.ones((16, 224, 224, 3)),
    "labels": jnp.zeros((16,), dtype=jnp.int32),
}
sharded_batch = placement.shard_batch_dim(batch, mesh)

Prefetching Data to Device¤

from artifex.generative_models.training.distributed import DevicePlacement

placement = DevicePlacement()

# Create a data iterator
def data_generator():
    for i in range(100):
        yield {"batch": jnp.ones((32, 784)) * i}

# Prefetch data to GPU with buffer of 2 batches
prefetched = placement.prefetch_to_device(
    data_generator(),
    buffer_size=2,
)

# Training loop with prefetched data
for batch in prefetched:
    # Data is already on GPU when we receive it
    process_batch(batch)

Replicating Model Weights¤

from artifex.generative_models.training.distributed import DevicePlacement

placement = DevicePlacement()

# Model weights to replicate
weights = {
    "layer1": jnp.ones((784, 256)),
    "layer2": jnp.ones((256, 10)),
}

# Replicate across all devices
replicated_weights = placement.replicate_across_devices(weights)

# Or replicate to specific devices
gpu_devices = jax.devices("gpu")[:2]
replicated_weights = placement.replicate_across_devices(weights, devices=gpu_devices)

Getting Device Information¤

from artifex.generative_models.training.distributed import DevicePlacement

placement = DevicePlacement()
info = placement.get_device_info()

print(f"Number of devices: {info['num_devices']}")
print(f"Hardware type: {info['hardware_type']}")
print(f"Platforms: {info['platforms']}")

for device in info['devices']:
    print(f"  Device {device['id']}: {device['platform']} ({device['device_kind']})")

Module Statistics¤

  • Classes: 2 (DevicePlacement, BatchSizeRecommendation)
  • Enums: 1 (HardwareType)
  • Convenience Functions: 3
  • Instance Methods: 7