Sharding

Module: generative_models.scaling.sharding

Source: generative_models/scaling/sharding.py

Overview

Sharding strategies and parallelism configuration for scalable training.

This module provides sharding infrastructure, including:

  • Abstract base class for sharding strategies
  • Concrete implementations for different parallelism types
  • Multi-dimensional parallelism support
  • Configuration management for complex sharding setups

All implementations prioritize performance and follow JAX/Flax NNX patterns.

Classes

DataParallelStrategy

class DataParallelStrategy
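The reference lists no signatures, so the following is only a hypothetical sketch of the rule data parallelism usually follows. It is not `DataParallelStrategy`'s actual code. Partition specs are modeled as plain tuples with one entry per array dimension (a mesh-axis name, or `None` for replicated), mirroring the layout of `jax.sharding.PartitionSpec`. The helpers `data_parallel_spec` and `per_device_batch` and the axis name `"data"` are illustrative assumptions.

```python
from typing import Optional, Tuple

# Hypothetical spec model: one entry per dimension, a mesh-axis name or None.
Spec = Tuple[Optional[str], ...]


def data_parallel_spec(ndim: int, is_batch: bool, axis: str = "data") -> Spec:
    """Hypothetical data-parallel rule: shard batch arrays on dim 0, replicate params."""
    if ndim == 0:
        return ()  # scalars carry no dimensions to shard
    if is_batch:
        # Split the leading (batch) dimension across the data-parallel axis.
        return (axis,) + (None,) * (ndim - 1)
    # Parameters are fully replicated on every device.
    return (None,) * ndim


def per_device_batch(global_batch: int, num_devices: int) -> int:
    """Each device processes an equal slice of the global batch."""
    if global_batch % num_devices:
        raise ValueError("global batch must divide evenly across devices")
    return global_batch // num_devices
```

For example, a `(64, 128, 512)` activation batch on 8 devices would get the spec `("data", None, None)`, with 8 examples per device, while every weight stays replicated.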

FSDPStrategy

class FSDPStrategy
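Fully sharded data parallelism (FSDP) splits each parameter across devices instead of replicating it. Below is a hypothetical sketch of how one dimension might be chosen. The "largest evenly divisible dimension" heuristic, the axis name `"fsdp"`, and both helper names are assumptions, not this module's documented behavior.

```python
from typing import Dict, Optional, Sequence, Tuple

Spec = Tuple[Optional[str], ...]


def fsdp_param_spec(shape: Sequence[int], fsdp_size: int, axis: str = "fsdp") -> Spec:
    """Hypothetical FSDP rule: shard one parameter dimension over the FSDP axis.

    Chooses the largest dimension divisible by ``fsdp_size`` (ties go to the
    earliest dimension); if no dimension divides evenly, the parameter is
    left replicated.
    """
    spec = [None] * len(shape)
    candidates = [i for i, d in enumerate(shape) if d % fsdp_size == 0]
    if candidates:
        best = max(candidates, key=lambda i: shape[i])
        spec[best] = axis
    return tuple(spec)


def shard_shape(shape: Sequence[int], spec: Spec, axis_sizes: Dict[str, int]) -> Tuple[int, ...]:
    """Per-device shape implied by a spec and a mapping of mesh-axis sizes."""
    return tuple(d // axis_sizes[a] if a is not None else d for d, a in zip(shape, spec))
```

Whichever dimension is chosen, each device holds `1/fsdp_size` of the parameter. The choice of dimension only affects the shard's shape and how the later all-gather is laid out.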

MultiDimensionalStrategy

class MultiDimensionalStrategy

ParallelismConfig

class ParallelismConfig
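The listing also includes `get_total_device_count` and `is_valid`, but their real signatures are not shown. Below is a hypothetical stand-in for a parallelism configuration: each parallelism type becomes one axis of the device mesh, so the required device count is the product of the axis sizes. The class name `ParallelismSketch` and its fields are assumptions.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelismSketch:
    """Hypothetical parallelism config; field names are illustrative assumptions."""

    data: int = 1
    fsdp: int = 1
    tensor: int = 1
    pipeline: int = 1

    def sizes(self) -> tuple:
        return (self.data, self.fsdp, self.tensor, self.pipeline)

    def total_devices(self) -> int:
        # One mesh axis per parallelism type: devices needed = product of sizes.
        return math.prod(self.sizes())

    def is_valid(self, available_devices: int) -> bool:
        # Every axis must be at least 1 and the mesh must use exactly the devices present.
        return all(s >= 1 for s in self.sizes()) and self.total_devices() == available_devices
```

For instance, `ParallelismSketch(data=2, fsdp=2, tensor=2)` describes a 2×2×2 mesh that is valid only on exactly 8 devices.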

PipelineParallelStrategy

class PipelineParallelStrategy

ShardingConfig

class ShardingConfig

ShardingStrategy

class ShardingStrategy
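The overview describes `ShardingStrategy` as an abstract base class, and the listing shows `get_partition_spec` and `apply_sharding` defined several times. Below is a hypothetical sketch of that kind of interface using the standard `abc` module. The names `StrategySketch`, `partition_spec`, `plan`, and `ReplicateAll` are assumptions, not the module's API.

```python
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence, Tuple

Spec = Tuple[Optional[str], ...]


class StrategySketch(ABC):
    """Hypothetical sharding-strategy interface (method names are assumptions)."""

    @abstractmethod
    def partition_spec(self, name: str, shape: Sequence[int]) -> Spec:
        """Return how the array called ``name`` should be laid out on the mesh."""

    def plan(self, params: Dict[str, Sequence[int]]) -> Dict[str, Spec]:
        # Shared logic lives in the base class; subclasses only decide per-array specs.
        return {name: self.partition_spec(name, shape) for name, shape in params.items()}


class ReplicateAll(StrategySketch):
    """Trivial concrete strategy: every array is fully replicated."""

    def partition_spec(self, name: str, shape: Sequence[int]) -> Spec:
        return (None,) * len(shape)
```

Instantiating `StrategySketch` directly raises `TypeError`, which forces each concrete strategy to supply its own layout rule.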

TensorParallelStrategy

class TensorParallelStrategy
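The widely used Megatron-LM style of tensor parallelism splits the first weight of an MLP by columns and the second by rows. Each device can then apply the elementwise activation to its own slice and produce a partial output, and a single sum (an all-reduce on real hardware) recovers the exact result. Whether `TensorParallelStrategy` follows this scheme is an assumption. The pure-Python sketch below only checks the arithmetic, and all helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]


def split_cols(m: Matrix, n: int) -> List[Matrix]:
    if len(m[0]) % n:
        raise ValueError("columns must divide evenly")
    w = len(m[0]) // n
    return [[row[s * w:(s + 1) * w] for row in m] for s in range(n)]


def split_rows(m: Matrix, n: int) -> List[Matrix]:
    if len(m) % n:
        raise ValueError("rows must divide evenly")
    h = len(m) // n
    return [m[s * h:(s + 1) * h] for s in range(n)]


def tensor_parallel_mlp(x: Matrix, a: Matrix, b: Matrix, n: int) -> Matrix:
    """relu(x @ a) @ b computed as if on ``n`` tensor-parallel devices."""
    a_shards = split_cols(a, n)  # column-parallel: each device owns some hidden units
    b_shards = split_rows(b, n)  # row-parallel: the matching rows of the second weight
    # Each "device" works independently; relu is elementwise, so slicing is safe.
    partials = [matmul(relu(matmul(x, a_s)), b_s) for a_s, b_s in zip(a_shards, b_shards)]
    rows, cols = len(partials[0]), len(partials[0][0])
    # Summing the partial outputs stands in for the all-reduce.
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]
```

Only one collective per MLP block is needed in the forward pass, which is why this column-then-row pairing is a common default.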

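Functions

Entries are listed by name only. `__init__`, `apply_sharding`, and `get_partition_spec` each appear five times, presumably once for each class that defines them.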

__init__

def __init__()

__init__

def __init__()

__init__

def __init__()

__init__

def __init__()

__init__

def __init__()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

apply_sharding

def apply_sharding()

assign_layers_to_stages

def assign_layers_to_stages()
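A pipeline-parallel model needs each layer mapped to one stage. Below is a hypothetical sketch of one simple policy: contiguous blocks whose sizes differ by at most one, with earlier stages taking the remainder. The real `assign_layers_to_stages` may balance by cost instead of layer count, and the name `assign_layers` is an assumption.

```python
from typing import List


def assign_layers(num_layers: int, num_stages: int) -> List[List[int]]:
    """Hypothetical contiguous split of layer indices across pipeline stages."""
    if num_stages < 1 or num_layers < num_stages:
        raise ValueError("need at least one layer per stage")
    base, extra = divmod(num_layers, num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        # The first ``extra`` stages each take one additional layer.
        size = base + (1 if s < extra else 0)
        stages.append(list(range(start, start + size)))
        start += size
    return stages
```

For example, `assign_layers(10, 4)` gives stages of 3, 3, 2 and 2 layers. Keeping layers contiguous means activations only cross stage boundaries between neighbors.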

create_partition_spec

def create_partition_spec()
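Below is a hypothetical sketch of building a spec from a mapping of array dimension to mesh axis, with validation. Each mesh axis may partition at most one dimension (JAX also rejects specs that reuse an axis). The even-division requirement is a simplification made by this sketch. The name `make_spec` and its arguments are assumptions.

```python
from typing import Dict, Optional, Sequence, Tuple

Spec = Tuple[Optional[str], ...]


def make_spec(shape: Sequence[int], dim_to_axis: Dict[int, str], mesh: Dict[str, int]) -> Spec:
    """Hypothetical spec builder: map chosen dimensions to mesh axes, with checks."""
    spec = [None] * len(shape)
    used = set()
    for dim, axis in dim_to_axis.items():
        if axis not in mesh:
            raise KeyError(f"unknown mesh axis {axis!r}")
        if axis in used:
            # A mesh axis can partition at most one dimension of an array.
            raise ValueError(f"mesh axis {axis!r} used twice")
        if shape[dim] % mesh[axis]:
            raise ValueError(f"dim {dim} (size {shape[dim]}) not divisible by {axis!r}")
        spec[dim] = axis
        used.add(axis)
    return tuple(spec)
```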

from_device_count

def from_device_count()
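A constructor like `from_device_count` has to factor the available devices into mesh axes. Below is one hypothetical heuristic, not the module's documented rule. Tensor parallelism communicates in every layer, so it is usually kept small (typically within one fast interconnect domain). The sketch gives it the largest power of two that divides the device count and does not exceed a cap, and gives the rest to data parallelism. The name `split_devices` and the cap of 8 are assumptions.

```python
from typing import Dict


def split_devices(num_devices: int, max_tensor: int = 8) -> Dict[str, int]:
    """Hypothetical split: tensor size is the largest power of two that divides
    ``num_devices`` and does not exceed ``max_tensor``; the remainder is data-parallel."""
    if num_devices < 1:
        raise ValueError("num_devices must be positive")
    tensor = 1
    while tensor * 2 <= max_tensor and num_devices % (tensor * 2) == 0:
        tensor *= 2
    return {"data": num_devices // tensor, "tensor": tensor}
```

Because `tensor` always divides `num_devices`, the product of the two axis sizes equals the device count.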

from_sharding_config

def from_sharding_config()

get_attention_output_spec

def get_attention_output_spec()

get_attention_qkv_spec

def get_attention_qkv_spec()
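Attention heads are independent, so a common tensor-parallel layout gives each device a subset of heads. Under that layout the Q/K/V projections need no communication, and only the output projection's partial sums must be all-reduced (the same column-then-row pattern as the MLP). Below is a hypothetical sketch of the two specs. The weight layouts in the docstring, the axis name `"tensor"`, and the helper name are assumptions about what `get_attention_qkv_spec` and `get_attention_output_spec` return.

```python
from typing import Dict, Optional, Tuple, Union

Spec = Tuple[Optional[str], ...]


def attention_specs(num_heads: int, tensor_size: int,
                    axis: str = "tensor") -> Dict[str, Union[Spec, int]]:
    """Hypothetical head-parallel layout for one attention block.

    Assumed weight layouts:
      q/k/v projection:  (d_model, num_heads, head_dim)
      output projection: (num_heads, head_dim, d_model)
    """
    if num_heads % tensor_size:
        raise ValueError("num_heads must be divisible by the tensor-parallel size")
    return {
        "qkv": (None, axis, None),  # each device produces only its own heads
        "out": (axis, None, None),  # ...and consumes only those heads' outputs
        "heads_per_device": num_heads // tensor_size,
    }
```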

get_backward_communication_pattern

def get_backward_communication_pattern()

get_combined_partition_spec

def get_combined_partition_spec()
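When several strategies are active at once (for example FSDP on one weight dimension and tensor parallelism on another), their per-dimension specs have to be merged. Below is a hypothetical sketch of an elementwise merge. JAX itself also allows a tuple of axis names on a single dimension, but this sketch treats two different axes on one dimension as a conflict for simplicity. The name `combine_specs` is an assumption.

```python
from typing import Optional, Tuple

Spec = Tuple[Optional[str], ...]


def combine_specs(a: Spec, b: Spec) -> Spec:
    """Hypothetical merge of two per-dimension specs from different strategies."""
    if len(a) != len(b):
        raise ValueError("specs must have the same rank")
    merged = []
    for x, y in zip(a, b):
        if x is not None and y is not None and x != y:
            raise ValueError(f"dimension claimed by both {x!r} and {y!r}")
        # Take whichever strategy sharded this dimension (or None if neither did).
        merged.append(x if x is not None else y)
    return tuple(merged)
```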

get_forward_communication_pattern

def get_forward_communication_pattern()
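The listing does not say which class defines the forward and backward communication patterns. If they describe pipeline point-to-point transfers, the shape is simple: activations move from each stage to the next in the forward pass, and gradients move in the opposite direction, starting from the last stage, in the backward pass. Below is a hypothetical sketch that returns `(sender, receiver)` pairs. The function names and return format are assumptions.

```python
from typing import List, Tuple


def forward_pattern(num_stages: int) -> List[Tuple[int, int]]:
    """Hypothetical pipeline forward pass: activations flow stage s -> s + 1."""
    return [(s, s + 1) for s in range(num_stages - 1)]


def backward_pattern(num_stages: int) -> List[Tuple[int, int]]:
    """Backward pass: activation gradients flow s -> s - 1, last stage first."""
    return [(s, s - 1) for s in range(num_stages - 1, 0, -1)]
```

With 4 stages the forward transfers are `(0, 1), (1, 2), (2, 3)`, and the backward transfers are the same edges reversed and flipped.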

get_gradient_partition_spec

def get_gradient_partition_spec()

get_linear_weight_spec

def get_linear_weight_spec()
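For a linear layer under tensor parallelism, the spec depends on whether the layer is column-parallel or row-parallel. Below is a hypothetical sketch. It assumes an `(in_features, out_features)` kernel layout (the convention Flax uses for dense kernels), and the `kind` strings and axis name are illustrative. A column-parallel bias is sharded along with the outputs, while a row-parallel bias is replicated and added once after the all-reduce.

```python
from typing import Dict, Optional, Tuple

Spec = Tuple[Optional[str], ...]


def linear_weight_spec(kind: str, axis: str = "tensor") -> Dict[str, Spec]:
    """Hypothetical specs for an (in_features, out_features) kernel and its bias."""
    if kind == "column":
        # Split output features; each device's activations stay sharded afterwards.
        return {"kernel": (None, axis), "bias": (axis,)}
    if kind == "row":
        # Split input features; partial outputs need an all-reduce, bias is replicated.
        return {"kernel": (axis, None), "bias": (None,)}
    if kind == "replicated":
        return {"kernel": (None, None), "bias": (None,)}
    raise ValueError(f"unknown kind {kind!r}")
```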

get_partition_spec

def get_partition_spec()

get_partition_spec

def get_partition_spec()

get_partition_spec

def get_partition_spec()

get_partition_spec

def get_partition_spec()

get_partition_spec

def get_partition_spec()

get_sharding_constraints

def get_sharding_constraints()

get_total_device_count

def get_total_device_count()

is_valid

def is_valid()

resolve_sharding_conflicts

def resolve_sharding_conflicts()
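A common conflict occurs when two strategies place the same mesh axis on different dimensions of one array, which JAX rejects. Below is a hypothetical resolution policy: keep the first occurrence and replicate the later dimensions. The real `resolve_sharding_conflicts` may use a different priority, and `resolve_conflicts` is an illustrative name.

```python
from typing import Optional, Tuple

Spec = Tuple[Optional[str], ...]


def resolve_conflicts(spec: Spec) -> Spec:
    """Hypothetical policy: a mesh axis used on several dimensions keeps only its
    first occurrence; later dimensions fall back to replication."""
    seen = set()
    out = []
    for entry in spec:
        if entry is not None and entry in seen:
            out.append(None)  # duplicate axis: replicate this dimension instead
        else:
            out.append(entry)
            if entry is not None:
                seen.add(entry)
    return tuple(out)
```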

should_shard_weight

def should_shard_weight()
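Sharding tiny tensors such as biases and normalization scales costs more in communication than it saves in memory, so a predicate like `should_shard_weight` typically applies a size threshold. Below is a hypothetical sketch. The 2**16-element threshold, the divisibility check, and the name `should_shard` are assumptions.

```python
import math
from typing import Sequence


def should_shard(shape: Sequence[int], axis_size: int, min_elements: int = 2 ** 16) -> bool:
    """Hypothetical rule: shard only weights that are large enough to matter and
    have at least one dimension divisible by the sharding axis size."""
    if axis_size <= 1:
        return False  # a single-device axis has nothing to split across
    if math.prod(shape) < min_elements:
        return False  # small tensors: communication outweighs memory savings
    return any(d % axis_size == 0 for d in shape)
```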

Module Statistics

  • Classes: 8
  • Functions: 31
  • Imports: 8