
Mesh Utils

Module: generative_models.scaling.mesh_utils

Source: generative_models/scaling/mesh_utils.py

Overview

Device mesh management utilities for scalable distributed training.

This module provides device mesh management utilities, including:

  • Device mesh creation and optimization
  • Topology optimization for different workloads
  • Validation and configuration utilities
  • Hardware-aware mesh shape calculation

All implementations prioritize performance and follow JAX/Flax NNX patterns.

Classes

DeviceMeshManager

class DeviceMeshManager

Functions

__init__

def __init__()

create_device_mesh

def create_device_mesh()
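This page lists create_device_mesh without its parameters. As a minimal sketch, assuming the function takes a mesh shape and a tuple of axis names (both assumptions, not documented here), a named device mesh can be built in JAX by reshaping jax.devices() into a grid and wrapping it in jax.sharding.Mesh. The helper name build_device_mesh is hypothetical. JAX also ships jax.experimental.mesh_utils.create_device_mesh, which orders devices to match the physical TPU topology; whether this module wraps it is not stated on this page.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def build_device_mesh(mesh_shape: tuple, axis_names: tuple) -> Mesh:
    """Hypothetical sketch: arrange all visible devices into a named mesh."""
    devices = jax.devices()
    if len(mesh_shape) != len(axis_names):
        raise ValueError("mesh_shape and axis_names must have the same length")
    n_required = int(np.prod(mesh_shape))
    if n_required != len(devices):
        raise ValueError(
            f"mesh_shape {mesh_shape} needs {n_required} devices, "
            f"but {len(devices)} are available"
        )
    # Reshape the flat device list into an N-d grid; each grid axis becomes a named mesh axis.
    device_grid = np.array(devices).reshape(mesh_shape)
    return Mesh(device_grid, axis_names)


# Put every device on the "data" axis and leave the "model" axis at size 1.
mesh = build_device_mesh((jax.device_count(), 1), ("data", "model"))
print(mesh.shape)
```

The plain reshape ignores interconnect layout, which is acceptable on CPU or GPU hosts; on TPU slices a topology-aware ordering usually matters more.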

create_mesh

def create_mesh()

create_mesh_from_config

def create_mesh_from_config()
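The configuration format consumed by create_mesh_from_config is not documented on this page. A common pattern, shown here only as a sketch, lets one axis size be -1 and infers it from the device count before the mesh is built. The MeshConfig dataclass and resolve_mesh_shape helper below are hypothetical and are not this module's API.

```python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class MeshConfig:
    """Hypothetical config: at most one entry of `shape` may be -1 (inferred)."""

    shape: tuple
    axis_names: tuple


def resolve_mesh_shape(config: MeshConfig, num_devices: int) -> tuple:
    """Return a concrete mesh shape whose product equals num_devices."""
    if num_devices < 1:
        raise ValueError("num_devices must be positive")
    if len(config.shape) != len(config.axis_names):
        raise ValueError("shape and axis_names must have the same length")
    unknown = [i for i, size in enumerate(config.shape) if size == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis size may be -1")
    known = [size for size in config.shape if size != -1]
    if any(size < 1 for size in known):
        raise ValueError("explicit axis sizes must be positive")
    known_product = prod(known)
    if not unknown:
        # Fully specified shape: it must use every device exactly once.
        if known_product != num_devices:
            raise ValueError(f"shape {config.shape} uses {known_product} devices, not {num_devices}")
        return tuple(config.shape)
    # Exactly one axis is inferred from the remaining device count.
    if num_devices % known_product != 0:
        raise ValueError(f"{num_devices} devices cannot be split into axes of size {known}")
    shape = list(config.shape)
    shape[unknown[0]] = num_devices // known_product
    return tuple(shape)


config = MeshConfig(shape=(-1, 2), axis_names=("data", "model"))
print(resolve_mesh_shape(config, num_devices=8))  # (4, 2)
```

Inferring a single axis keeps one config file valid across clusters of different sizes, while still rejecting shapes that would leave devices unused.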

get_optimal_mesh_shape

def get_optimal_mesh_shape()
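The criterion get_optimal_mesh_shape uses to define "optimal" is not given here. One simple heuristic, sketched below under that caveat, factors the device count into the requested number of axes so that the largest axis is as small as possible (the most balanced shape). The function name balanced_mesh_shape is hypothetical; real hardware-aware selection would also weigh interconnect topology.

```python
def balanced_mesh_shape(num_devices: int, num_axes: int) -> tuple:
    """Hypothetical heuristic: factor num_devices into num_axes sizes whose
    largest factor is as small as possible, sorted largest first."""
    if num_devices < 1 or num_axes < 1:
        raise ValueError("num_devices and num_axes must be positive")
    best = None

    def search(remaining: int, axes_left: int, prefix: list) -> None:
        nonlocal best
        if axes_left == 1:
            candidate = prefix + [remaining]
            if best is None or max(candidate) < max(best):
                best = candidate
            return
        # Try every divisor of the remaining count as the size of the current axis.
        for size in range(1, remaining + 1):
            if remaining % size == 0:
                search(remaining // size, axes_left - 1, prefix + [size])

    search(num_devices, num_axes, [])
    return tuple(sorted(best, reverse=True))


print(balanced_mesh_shape(8, 2))   # (4, 2)
print(balanced_mesh_shape(8, 3))   # (2, 2, 2)
print(balanced_mesh_shape(12, 2))  # (4, 3)
```

The exhaustive search is cheap for realistic device counts because only divisors recurse. Sorting largest first is an arbitrary convention that places the bigger axis (often data parallelism) first.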

get_optimal_mesh_shape

def get_optimal_mesh_shape()

optimize_for_transformer

def optimize_for_transformer()
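How optimize_for_transformer tunes the mesh is not described on this page. A typical transformer-oriented heuristic, shown here as an assumption rather than this module's behavior, picks the largest tensor-parallel ("model") axis that divides both the device count and the number of attention heads, so heads split evenly across devices. The cap max_model_parallel (hypothetical) limits tensor-parallel traffic to a high-bandwidth group such as one host. The function name transformer_mesh_shape is hypothetical.

```python
def transformer_mesh_shape(num_devices: int, num_heads: int, max_model_parallel: int = 8) -> tuple:
    """Hypothetical heuristic: return (data, model) sizes where the model axis is
    the largest value dividing both num_devices and num_heads, up to the cap."""
    if num_devices < 1 or num_heads < 1 or max_model_parallel < 1:
        raise ValueError("all arguments must be positive")
    model = 1
    for candidate in range(1, min(num_devices, max_model_parallel) + 1):
        # Attention heads must split evenly across the model axis.
        if num_devices % candidate == 0 and num_heads % candidate == 0:
            model = candidate
    # Remaining devices form the data-parallel axis.
    return num_devices // model, model


print(transformer_mesh_shape(16, num_heads=12))  # (4, 4)
print(transformer_mesh_shape(8, num_heads=32))   # (1, 8)
```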

validate_mesh_config

def validate_mesh_config()
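The checks performed by validate_mesh_config are not listed here. As a sketch of what such validation usually covers, the hypothetical check_mesh_config below verifies that sizes and names line up, that names are unique, and that the mesh uses exactly the available devices. It returns every problem found instead of raising on the first one; that reporting style is an assumption, not this module's documented behavior.

```python
from math import prod


def check_mesh_config(shape: tuple, axis_names: tuple, num_devices: int) -> list:
    """Hypothetical validator: return a list of problems (empty means valid)."""
    errors = []
    if len(shape) != len(axis_names):
        errors.append(f"{len(shape)} axis sizes but {len(axis_names)} axis names")
    if len(set(axis_names)) != len(axis_names):
        errors.append("axis names must be unique")
    if any(not isinstance(size, int) or size < 1 for size in shape):
        errors.append("every axis size must be a positive integer")
    elif prod(shape) != num_devices:
        # Only compare totals once every size is known to be a positive int.
        errors.append(f"mesh has {prod(shape)} slots but {num_devices} devices are available")
    return errors


print(check_mesh_config((4, 2), ("data", "model"), num_devices=8))  # []
print(check_mesh_config((4, 4), ("data", "data"), num_devices=8))
```

Collecting all errors lets a user fix a misconfigured mesh in one pass rather than rerunning after each individual failure.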

Module Statistics

  • Classes: 1
  • Functions: 8
  • Imports: 4