etha.comm.utils

Helper functions for communication operations.

Pure utility functions with no side effects.

Functions

enumerate_partial_subgroup_ranks(→ list[list[int]])

Sub-group rank lists, one per 1-D slice of the mesh along mesh_dim_idx.

get_slice_from_multi_index(→ tuple[slice, ...])

Convert a multi-dimensional index to a linear index and return the corresponding slice tuple.

get_slicer_tuples(→ list[tuple[slice, ...]])

Pre-compute all slice tuples for a tensor partitioned according to source_num_slicers.

Module Contents

etha.comm.utils.enumerate_partial_subgroup_ranks(mesh_tensor: torch.Tensor, mesh_dim_idx: int) → list[list[int]]

Sub-group rank lists, one per 1-D slice of the mesh along mesh_dim_idx.

Each sub-group is the set of ranks that share every mesh coordinate except the one along mesh_dim_idx; in other words, it is the set of ranks that participate in one all-reduce when a Partial placement on that mesh dimension is collapsed to Replicate.
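A minimal sketch of how such sub-groups can be enumerated, assuming mesh_tensor is an integer tensor of global ranks laid out in the mesh's shape. The function name enumerate_partial_subgroup_ranks_sketch and the 2 x 4 example mesh are hypothetical, and the real implementation may order the groups differently. Moving mesh_dim_idx to the last axis and flattening the rest yields one row per combination of the other coordinates:

```python
import torch


def enumerate_partial_subgroup_ranks_sketch(
    mesh_tensor: torch.Tensor, mesh_dim_idx: int
) -> list[list[int]]:
    """Hypothetical sketch: one rank list per 1-D slice along mesh_dim_idx."""
    # Put the target mesh dimension last, so ranks in the same row differ
    # only in their coordinate along mesh_dim_idx.
    moved = mesh_tensor.movedim(mesh_dim_idx, -1)
    # Flatten all other dimensions: one row per sub-group.
    return moved.reshape(-1, mesh_tensor.size(mesh_dim_idx)).tolist()


# Hypothetical 2 x 4 mesh holding global ranks 0..7.
mesh = torch.arange(8).reshape(2, 4)
print(enumerate_partial_subgroup_ranks_sketch(mesh, 0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(enumerate_partial_subgroup_ranks_sketch(mesh, 1))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With mesh_dim_idx=0, each group pairs the two ranks that sit in the same column, which is exactly the set that would all-reduce a Partial placement on mesh dimension 0.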

etha.comm.utils.get_slice_from_multi_index(source_idx: tuple, source_num_slicers: list[int], slicer_tuples: list[tuple[slice, ...]]) → tuple[slice, ...]

Convert a multi-dimensional index to a linear index and return the corresponding slice tuple.

Parameters:
  • source_idx – Multi-dimensional index (e.g., (0, 1))

  • source_num_slicers – Number of slices per dimension

  • slicer_tuples – Pre-computed slice tuples from get_slicer_tuples()

Returns:

Slice tuple for the given index
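A minimal sketch of the index-to-slice lookup, assuming slicer_tuples is ordered row-major (the last dimension varies fastest) and that source_idx has one entry per partitioned dimension. The row-major ordering is an assumption: it has to match whatever order get_slicer_tuples() actually produces. The function name and the hand-built slice list are hypothetical:

```python
from collections.abc import Sequence


def get_slice_from_multi_index_sketch(
    source_idx: tuple[int, ...],
    source_num_slicers: Sequence[int],
    slicer_tuples: list[tuple[slice, ...]],
) -> tuple[slice, ...]:
    """Hypothetical sketch: row-major multi-index -> linear index -> slice tuple."""
    if len(source_idx) != len(source_num_slicers):
        raise ValueError("source_idx and source_num_slicers must have the same length")
    linear_idx = 0
    for idx, n in zip(source_idx, source_num_slicers):
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} out of range for {n} slices")
        # Horner-style accumulation: the last dimension varies fastest.
        linear_idx = linear_idx * n + idx
    return slicer_tuples[linear_idx]


# Hand-built slice tuples for a 4 x 6 tensor split 2 x 3, listed row-major.
rows = [slice(0, 2), slice(2, 4)]
cols = [slice(0, 2), slice(2, 4), slice(4, 6)]
slicer_tuples = [(r, c) for r in rows for c in cols]

# (1, 2) -> linear index 1 * 3 + 2 = 5
print(get_slice_from_multi_index_sketch((1, 2), [2, 3], slicer_tuples))
# (slice(2, 4, None), slice(4, 6, None))
```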

etha.comm.utils.get_slicer_tuples(tensor_shape: torch.Size, source_num_slicers: list[int]) → list[tuple[slice, ...]]

Pre-compute all slice tuples for a tensor partitioned according to source_num_slicers.

Parameters:
  • tensor_shape – Shape of the tensor to slice

  • source_num_slicers – Number of slices per dimension

Returns:

List of slice tuples, one for each chunk
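A minimal sketch of the pre-computation, under several assumptions that the documentation above does not state: each partitioned dimension divides evenly by its slice count, source_num_slicers covers the leading dimensions of the tensor (any remaining dimensions are left whole), and chunks are enumerated row-major. The function name is hypothetical:

```python
import itertools
from collections.abc import Sequence


def get_slicer_tuples_sketch(
    tensor_shape: Sequence[int], source_num_slicers: list[int]
) -> list[tuple[slice, ...]]:
    """Hypothetical sketch: one slice tuple per chunk, enumerated row-major."""
    if len(source_num_slicers) > len(tensor_shape):
        raise ValueError("more slicer counts than tensor dimensions")
    per_dim: list[list[slice]] = []
    for dim_size, n in zip(tensor_shape, source_num_slicers):
        # Assumption: each partitioned dimension splits into n equal chunks.
        if n <= 0 or dim_size % n != 0:
            raise ValueError(f"dimension of size {dim_size} cannot be split into {n} equal chunks")
        chunk = dim_size // n
        per_dim.append([slice(i * chunk, (i + 1) * chunk) for i in range(n)])
    # itertools.product varies the last dimension fastest (row-major order).
    # Dimensions beyond len(source_num_slicers) get no slice, so they stay whole.
    return list(itertools.product(*per_dim))


tuples = get_slicer_tuples_sketch((4, 6), [2, 3])
print(len(tuples))  # 6
print(tuples[5])    # (slice(2, 4, None), slice(4, 6, None))
```

Because torch.Size is a subclass of tuple, a tensor's .shape can be passed as tensor_shape directly; indexing the tensor with tuples[5] then selects the chunk at multi-index (1, 2).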