etha.comm.utils#
Helper functions for communication operations.
Pure utility functions with no side effects.
Functions#
enumerate_partial_subgroup_ranks | Sub-group rank lists, one per slice through mesh_dim_idx.
get_slice_from_multi_index | Convert multi-dimensional index to linear index and return corresponding slice tuple.
get_slicer_tuples | Pre-compute all slice tuples for a tensor partitioned by source_num_slicers.
Module Contents#
- etha.comm.utils.enumerate_partial_subgroup_ranks(mesh_tensor: torch.Tensor, mesh_dim_idx: int) -> list[list[int]]#
Sub-group rank lists, one per slice through mesh_dim_idx.
Each sub-group is the set of ranks that share all coordinates except along mesh_dim_idx, i.e., the ranks that contribute to one all-reduce when collapsing a Partial dim to Replicate.
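The grouping described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: the name `subgroup_ranks_sketch` is hypothetical, it takes a mesh shape instead of a `torch.Tensor`, and it assumes ranks are numbered in row-major order over the mesh. The order in which the real function returns its sub-groups is also an assumption here.

```python
from itertools import product


def subgroup_ranks_sketch(mesh_shape, mesh_dim_idx):
    """Hypothetical stand-in for enumerate_partial_subgroup_ranks.

    Assumes ranks are laid out row-major over mesh_shape, so the rank at
    coordinate c is sum(c[d] * strides[d]).
    """
    ndim = len(mesh_shape)
    # Row-major strides: the last mesh dimension varies fastest.
    strides = [1] * ndim
    for d in range(ndim - 2, -1, -1):
        strides[d] = strides[d + 1] * mesh_shape[d + 1]

    other_dims = [d for d in range(ndim) if d != mesh_dim_idx]
    groups = []
    # Fix every coordinate except the one along mesh_dim_idx ...
    for fixed in product(*(range(mesh_shape[d]) for d in other_dims)):
        base = sum(i * strides[d] for d, i in zip(other_dims, fixed))
        # ... then walk along mesh_dim_idx to collect one sub-group.
        groups.append(
            [base + k * strides[mesh_dim_idx] for k in range(mesh_shape[mesh_dim_idx])]
        )
    return groups


# A 2 x 3 mesh holding ranks [[0, 1, 2], [3, 4, 5]].
groups_dim0 = subgroup_ranks_sketch((2, 3), 0)
groups_dim1 = subgroup_ranks_sketch((2, 3), 1)
```

For the 2 x 3 mesh, slicing through dimension 0 yields three sub-groups of two ranks each, and slicing through dimension 1 yields two sub-groups of three ranks each. Each sub-group is the set of ranks that would take part in one all-reduce.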
- etha.comm.utils.get_slice_from_multi_index(source_idx: tuple, source_num_slicers: list[int], slicer_tuples: list[tuple[slice, Ellipsis]]) -> tuple[slice, Ellipsis]#
Convert multi-dimensional index to linear index and return corresponding slice tuple.
- Parameters:
source_idx – Multi-dimensional index (e.g., (0, 1))
source_num_slicers – Number of slices per dimension
slicer_tuples – Pre-computed slice tuples from get_slicer_tuples()
- Returns:
Slice tuple for the given index
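A minimal sketch of the index conversion, assuming the linear index is computed in row-major order (last dimension varying fastest) and that `slicer_tuples` is stored in that same order. The function name `slice_from_multi_index_sketch` and the hand-built `slicer_tuples` list are illustrative only; the library may order its chunks differently.

```python
def slice_from_multi_index_sketch(source_idx, source_num_slicers, slicer_tuples):
    """Hypothetical stand-in for get_slice_from_multi_index (row-major assumed)."""
    if len(source_idx) != len(source_num_slicers):
        raise ValueError("source_idx and source_num_slicers must have the same length")
    linear = 0
    for idx, num in zip(source_idx, source_num_slicers):
        if not 0 <= idx < num:
            raise IndexError(f"index {idx} out of range for {num} slices")
        # Horner-style accumulation of the row-major linear index.
        linear = linear * num + idx
    return slicer_tuples[linear]


# Slice tuples for a 4 x 6 tensor split into 2 x 3 chunks, built by hand here.
slicer_tuples = [
    (slice(r, r + 2), slice(c, c + 2)) for r in (0, 2) for c in (0, 2, 4)
]
picked = slice_from_multi_index_sketch((0, 1), [2, 3], slicer_tuples)
last = slice_from_multi_index_sketch((1, 2), [2, 3], slicer_tuples)
```

With two slices along the first dimension and three along the second, the index `(0, 1)` maps to linear index `0 * 3 + 1 = 1`, which selects rows 0 to 2 and columns 2 to 4.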
- etha.comm.utils.get_slicer_tuples(tensor_shape: torch.Size, source_num_slicers: list[int]) -> list[tuple[slice, Ellipsis]]#
Pre-compute all slice tuples for a tensor partitioned by source_num_slicers.
- Parameters:
tensor_shape – Shape of the tensor to slice
source_num_slicers – Number of slices per dimension
- Returns:
List of slice tuples, one for each chunk
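One way the pre-computation could look is sketched below. It assumes one entry in `source_num_slicers` per tensor dimension, even divisibility of every dimension, and row-major ordering of the resulting chunks. None of these are confirmed by the documentation above, and `slicer_tuples_sketch` is a hypothetical name. A plain tuple stands in for `torch.Size`.

```python
from itertools import product


def slicer_tuples_sketch(tensor_shape, source_num_slicers):
    """Hypothetical stand-in for get_slicer_tuples (even splits assumed)."""
    if len(tensor_shape) != len(source_num_slicers):
        raise ValueError("expected one slice count per tensor dimension")
    per_dim = []
    for size, num in zip(tensor_shape, source_num_slicers):
        if size % num != 0:
            raise ValueError(f"dimension of size {size} is not divisible by {num}")
        step = size // num
        # Equal-width, contiguous slices along this dimension.
        per_dim.append([slice(i * step, (i + 1) * step) for i in range(num)])
    # product() varies the last dimension fastest, giving row-major chunk order.
    return list(product(*per_dim))


chunks = slicer_tuples_sketch((4, 6), [2, 3])
```

A 4 x 6 shape split 2 x 3 produces six chunks of shape 2 x 2. Each tuple can be used directly as an index, for example `tensor[chunks[0]]`.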