Sharding
Module: generative_models.scaling.sharding
Source: generative_models/scaling/sharding.py
Overview
Sharding strategies and parallelism configuration for scalable training.
This module provides sharding infrastructure, including:
- An abstract base class for sharding strategies
- Concrete strategies for data, fully sharded data (FSDP), tensor, and pipeline parallelism
- Support for combining several parallelism types at once (multi-dimensional parallelism)
- Configuration management for complex sharding setups
All implementations prioritize performance and follow JAX/Flax NNX patterns.
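The pattern of an abstract strategy class with concrete subclasses can be illustrated with a minimal sketch. The class names ShardingStrategySketch and DataParallelSketch, the get_partition_spec(ndim, is_param) signature, and the "data" axis name are hypothetical and are not this module's actual API; only jax.sharding.PartitionSpec is a real JAX class.

```python
from abc import ABC, abstractmethod

from jax.sharding import PartitionSpec as P


class ShardingStrategySketch(ABC):
    """Hypothetical base class: maps an array's rank to a PartitionSpec."""

    @abstractmethod
    def get_partition_spec(self, ndim: int, is_param: bool = False) -> P:
        """Return a PartitionSpec for an array with `ndim` dimensions."""


class DataParallelSketch(ShardingStrategySketch):
    """Hypothetical data-parallel strategy over a single mesh axis."""

    def __init__(self, axis_name: str = "data"):
        self.axis_name = axis_name

    def get_partition_spec(self, ndim: int, is_param: bool = False) -> P:
        if is_param:
            # Parameters are fully replicated on every device.
            return P()
        if ndim < 1:
            raise ValueError("batched arrays need at least one dimension")
        # Split the leading (batch) dimension; replicate the remaining ones.
        return P(self.axis_name, *([None] * (ndim - 1)))
```

Data parallelism splits the batch and replicates the parameters, which is why the parameter spec above is empty.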
Classes
DataParallelStrategy
FSDPStrategy
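One common FSDP rule is to shard each large weight along a single dimension and replicate small ones. The sketch below assumes that rule; the function name fsdp_partition_spec, the "fsdp" axis name, and the min_size threshold are hypothetical and are not taken from FSDPStrategy itself.

```python
import math

from jax.sharding import PartitionSpec as P


def fsdp_partition_spec(
    shape: tuple[int, ...],
    axis_size: int,
    axis_name: str = "fsdp",
    min_size: int = 1024,
) -> P:
    """Hypothetical FSDP rule: shard one dimension of a large weight."""
    # Small weights are cheaper to replicate than to gather on every step.
    if math.prod(shape) < min_size:
        return P()
    # Only dimensions that divide evenly by the axis size can be sharded.
    candidates = [i for i, d in enumerate(shape) if d % axis_size == 0]
    if not candidates:
        return P()
    # Prefer the largest such dimension.
    dim = max(candidates, key=lambda i: shape[i])
    spec = [None] * len(shape)
    spec[dim] = axis_name
    return P(*spec)
```

For example, a (4096, 1024) kernel on an 8-way axis is split along its first dimension, while a 16-element bias stays replicated.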
MultiDimensionalStrategy
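Combining strategies means merging the specs each one proposes for the same array. This sketch assumes a simple first-wins merge and that every spec entry is a single axis name or None; the function combine_specs is hypothetical. It relies on one real JAX constraint: a mesh axis may appear at most once in a PartitionSpec, so a repeated axis is dropped.

```python
from jax.sharding import PartitionSpec as P


def combine_specs(*specs: P) -> P:
    """Hypothetical merge of per-strategy specs for one array.

    Each dimension takes the first non-None axis proposed for it. An axis
    already used by an earlier dimension is skipped, because JAX forbids
    sharding two dimensions of one array over the same mesh axis.
    """
    ndim = max(len(s) for s in specs)
    used: set[str] = set()
    merged = []
    for dim in range(ndim):
        choice = None
        for spec in specs:
            axis = spec[dim] if dim < len(spec) else None
            if axis is not None and axis not in used:
                choice = axis
                break
        if choice is not None:
            used.add(choice)
        merged.append(choice)
    return P(*merged)
```

Merging a data-parallel activation spec P("data", None, None) with a tensor-parallel spec P(None, None, "model") yields P("data", None, "model").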
ParallelismConfig
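A parallelism configuration usually reduces to device-count arithmetic: the product of the parallelism degrees must equal the number of devices. The sketch below assumes three degrees (data, tensor, pipeline); the class ParallelismConfigSketch and its field and method names are hypothetical, chosen only to mirror the listed from_device_count, get_total_device_count, and is_valid functions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelismConfigSketch:
    """Hypothetical degrees of parallelism; field names are illustrative."""

    data: int = 1
    tensor: int = 1
    pipeline: int = 1

    def total_devices(self) -> int:
        # Each device holds exactly one (data, tensor, pipeline) coordinate.
        return self.data * self.tensor * self.pipeline

    def is_valid(self, device_count: int) -> bool:
        degrees = (self.data, self.tensor, self.pipeline)
        return all(d >= 1 for d in degrees) and self.total_devices() == device_count

    @classmethod
    def from_device_count(cls, device_count: int, tensor: int = 1, pipeline: int = 1):
        # Fix the model-parallel degrees and give the remaining factor to data parallelism.
        if tensor < 1 or pipeline < 1:
            raise ValueError("parallelism degrees must be positive")
        model_parallel = tensor * pipeline
        if device_count % model_parallel != 0:
            raise ValueError(
                f"{device_count} devices cannot be split into groups of {model_parallel}"
            )
        return cls(data=device_count // model_parallel, tensor=tensor, pipeline=pipeline)
```

With 16 devices, 2-way tensor parallelism, and 2 pipeline stages, the remaining factor of 4 becomes the data-parallel degree.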
PipelineParallelStrategy
ShardingConfig
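A sharding configuration ultimately has to become a device mesh. The sketch assumes the configuration can be reduced to an ordered mapping from axis name to axis size; the helper build_mesh and the axis names are hypothetical, while jax.devices and jax.sharding.Mesh are real JAX APIs.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def build_mesh(axis_sizes: dict[str, int]) -> Mesh:
    """Hypothetical helper: turn {axis_name: size} into a jax.sharding.Mesh."""
    names = tuple(axis_sizes)
    shape = tuple(axis_sizes.values())
    needed = int(np.prod(shape))
    devices = jax.devices()
    if len(devices) < needed:
        raise ValueError(f"mesh needs {needed} devices, found {len(devices)}")
    # Lay the first `needed` devices out in a grid with one grid axis per mesh axis.
    grid = np.array(devices[:needed]).reshape(shape)
    return Mesh(grid, names)
```

On a single CPU device, build_mesh({"data": 1, "model": 1}) produces a 1 x 1 mesh; on a larger host the same call with bigger sizes spreads the axes over real devices.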
ShardingStrategy
TensorParallelStrategy
Functions
__init__
__init__
__init__
__init__
__init__
apply_sharding
apply_sharding
apply_sharding
apply_sharding
apply_sharding
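Applying a sharding to an existing array in JAX typically comes down to jax.device_put with a NamedSharding. The sketch below assumes that approach; the function place_with_spec is hypothetical and is not one of the listed apply_sharding methods. It uses a one-device mesh so it runs on a plain CPU install.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def place_with_spec(x: jax.Array, mesh: Mesh, spec: P) -> jax.Array:
    """Hypothetical apply-sharding step: lay x out on the mesh according to spec."""
    return jax.device_put(x, NamedSharding(mesh, spec))


# One-device mesh with a single "data" axis.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
batch = jnp.arange(8.0).reshape(4, 2)
sharded = place_with_spec(batch, mesh, P("data", None))
```

With a 4-device "data" axis, the same call would place one row of the batch on each device without changing the array's values.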
assign_layers_to_stages
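Assigning layers to pipeline stages is an integer-partitioning problem. This sketch assumes an even split in which earlier stages absorb any remainder; the function assign_layers_sketch is hypothetical, and the real assign_layers_to_stages may balance stages differently (for example by parameter count or compute cost).

```python
def assign_layers_sketch(num_layers: int, num_stages: int) -> list[range]:
    """Hypothetical even split of layer indices into pipeline stages.

    Earlier stages receive one extra layer when the split is uneven.
    """
    if num_stages < 1 or num_layers < num_stages:
        raise ValueError("need at least one layer per stage")
    base, extra = divmod(num_layers, num_stages)
    stages, start = [], 0
    for stage in range(num_stages):
        size = base + (1 if stage < extra else 0)
        stages.append(range(start, start + size))
        start += size
    return stages
```

Ten layers over four stages gives stage sizes 3, 3, 2, 2.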
create_partition_spec
from_device_count
from_sharding_config
get_attention_output_spec
get_attention_qkv_spec
get_backward_communication_pattern
get_combined_partition_spec
get_forward_communication_pattern
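In a simple linear pipeline, communication only happens between adjacent stages: activations move forward and gradients move backward. The sketch assumes that linear topology; the function pipeline_send_pairs is hypothetical and only illustrates what forward and backward communication patterns could contain.

```python
def pipeline_send_pairs(num_stages: int, backward: bool = False) -> list[tuple[int, int]]:
    """Hypothetical (source, destination) pairs between adjacent pipeline stages.

    Forward pass: activations flow from stage i to stage i + 1.
    Backward pass: gradients flow from stage i + 1 back to stage i,
    starting at the last stage.
    """
    forward = [(i, i + 1) for i in range(num_stages - 1)]
    if backward:
        return [(dst, src) for src, dst in reversed(forward)]
    return forward
```

The (source, destination) pair format matches the perm argument of jax.lax.ppermute, which is one way such a pattern could be executed across a mapped axis.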
get_gradient_partition_spec
get_linear_weight_spec
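Tensor-parallel linear layers are commonly split in one of two ways, often called column-parallel and row-parallel. The sketch assumes the Flax (in_features, out_features) kernel layout and a mesh axis named "model"; the function linear_weight_spec and its kind argument are hypothetical.

```python
from jax.sharding import PartitionSpec as P


def linear_weight_spec(kind: str, axis_name: str = "model") -> P:
    """Hypothetical tensor-parallel spec for an (in_features, out_features) kernel.

    "column": split the output features. Each device computes a slice of
    the outputs, which a following row-parallel layer can consume directly.
    "row": split the input features. Each device computes a partial sum
    that must be all-reduced across the axis.
    """
    if kind == "column":
        return P(None, axis_name)
    if kind == "row":
        return P(axis_name, None)
    raise ValueError(f"unknown kind: {kind!r}")
```

Pairing a column-parallel layer with a row-parallel one (as in an MLP block) needs only a single all-reduce after the second layer.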
get_partition_spec
get_partition_spec
get_partition_spec
get_partition_spec
get_partition_spec
get_sharding_constraints
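Inside a jitted function, sharding of intermediate values is usually expressed with jax.lax.with_sharding_constraint, which tells the compiler how a value should be laid out at that point. The sketch assumes constraints take that form; the function scaled_hidden and the "data" axis name are hypothetical. It again uses a one-device mesh so it runs on CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-device mesh; the axis name is illustrative.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
batch_sharding = NamedSharding(mesh, P("data", None))


@jax.jit
def scaled_hidden(x):
    h = x * 2.0
    # Ask the compiler to keep the batch dimension split along "data" here.
    return jax.lax.with_sharding_constraint(h, batch_sharding)


out = scaled_hidden(jnp.ones((4, 3)))
```

A constraint changes only the layout the compiler chooses, not the computed values.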
get_total_device_count
is_valid
resolve_sharding_conflicts
should_shard_weight
Module Statistics¤
- Classes: 8
- Functions: 31
- Imports: 8