etha.comm.utils
===============

.. py:module:: etha.comm.utils

.. autoapi-nested-parse::

   Helper functions for communication operations.

   Pure utility functions with no side effects.



Functions
---------

.. autoapisummary::

   etha.comm.utils.enumerate_partial_subgroup_ranks
   etha.comm.utils.get_slice_from_multi_index
   etha.comm.utils.get_slicer_tuples


Module Contents
---------------

.. py:function:: enumerate_partial_subgroup_ranks(mesh_tensor: torch.Tensor, mesh_dim_idx: int) -> list[list[int]]

   Return the sub-group rank lists, one per 1-D slice of ``mesh_tensor`` along ``mesh_dim_idx``.

   Each sub-group is the set of ranks that share all coordinates except the one
   along ``mesh_dim_idx``, that is, the ranks that contribute to a single
   all-reduce when a ``Partial`` placement on that mesh dimension is collapsed
   to ``Replicate``.
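
   A minimal sketch of the grouping, assuming ``mesh_tensor`` is an integer
   tensor of global ranks (for example, the ``mesh`` attribute of a PyTorch
   ``DeviceMesh``). Moving ``mesh_dim_idx`` to the last axis and flattening the
   remaining axes yields one row per sub-group. The order in which sub-groups
   are returned here (row-major over the other mesh dimensions) is an
   assumption and may differ from the actual implementation.

   .. code-block:: python

      import torch


      def enumerate_partial_subgroup_ranks(
          mesh_tensor: torch.Tensor, mesh_dim_idx: int
      ) -> list[list[int]]:
          # Sketch only: put the reduced mesh dimension last so that each row
          # of the flattened view holds ranks differing only along that dimension.
          group_size = mesh_tensor.size(mesh_dim_idx)
          moved = torch.movedim(mesh_tensor, mesh_dim_idx, -1)
          # One row per combination of the remaining coordinates.
          return moved.reshape(-1, group_size).tolist()


      # A 2 x 4 mesh of ranks 0..7.
      mesh = torch.arange(8).reshape(2, 4)
      print(enumerate_partial_subgroup_ranks(mesh, 0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
      print(enumerate_partial_subgroup_ranks(mesh, 1))  # [[0, 1, 2, 3], [4, 5, 6, 7]]

   Reducing over mesh dimension 0 pairs each rank with the rank in the same
   column, so a 2 x 4 mesh produces four all-reduce groups of two ranks each.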


.. py:function:: get_slice_from_multi_index(source_idx: tuple, source_num_slicers: list[int], slicer_tuples: list[tuple[slice, ...]]) -> tuple[slice, ...]

   Convert a multi-dimensional chunk index to a linear index and return the corresponding slice tuple.

   :param source_idx: Multi-dimensional chunk index, e.g. ``(0, 1)``
   :param source_num_slicers: Number of chunks along each dimension
   :param slicer_tuples: Pre-computed slice tuples from :func:`get_slicer_tuples`

   :returns: The slice tuple for the chunk at ``source_idx``
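
   The index conversion can be sketched as follows, assuming chunks are
   numbered in row-major (C) order, so the last dimension varies fastest. That
   ordering matches what ``itertools.product`` yields but is not confirmed by
   the source, and the length and bounds checks are illustrative additions.

   .. code-block:: python

      def get_slice_from_multi_index(
          source_idx: tuple,
          source_num_slicers: list[int],
          slicer_tuples: list[tuple[slice, ...]],
      ) -> tuple[slice, ...]:
          if len(source_idx) != len(source_num_slicers):
              raise ValueError("source_idx and source_num_slicers differ in length")
          # Row-major linearization; for three dims: (i0 * n1 + i1) * n2 + i2.
          linear = 0
          for i, n in zip(source_idx, source_num_slicers):
              if not 0 <= i < n:
                  raise IndexError(f"index {i} out of range for {n} chunks")
              linear = linear * n + i
          return slicer_tuples[linear]


      # Hand-built slice tuples for a (4, 6) tensor split 2 x 2, in row-major order.
      tuples = [
          (slice(0, 2), slice(0, 3)), (slice(0, 2), slice(3, 6)),
          (slice(2, 4), slice(0, 3)), (slice(2, 4), slice(3, 6)),
      ]
      print(get_slice_from_multi_index((1, 0), [2, 2], tuples))
      # (slice(2, 4, None), slice(0, 3, None))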


.. py:function:: get_slicer_tuples(tensor_shape: torch.Size, source_num_slicers: list[int]) -> list[tuple[slice, ...]]

   Pre-compute the slice tuples for every chunk of a tensor partitioned according to ``source_num_slicers``.

   :param tensor_shape: Shape of the tensor to slice
   :param source_num_slicers: Number of chunks along each dimension

   :returns: List of slice tuples, one for each chunk
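
   One way to build these tuples is sketched below. It assumes that
   ``source_num_slicers`` covers the leading dimensions of the tensor (trailing
   dimensions are left whole) and that a dimension of size ``s`` split into
   ``n`` chunks gives the first ``s % n`` chunks one extra element, as
   ``torch.tensor_split`` does. The real function may size uneven chunks
   differently (``torch.chunk``, for instance, uses ceiling division).
   ``_chunk_slices`` is a hypothetical helper. A ``torch.Size`` can be passed
   directly because it is a ``tuple`` subclass.

   .. code-block:: python

      import itertools
      from collections.abc import Sequence


      def _chunk_slices(size: int, n: int) -> list[slice]:
          # Hypothetical helper: split range(size) into n contiguous, near-equal parts.
          if n < 1:
              raise ValueError("number of chunks must be at least 1")
          base, extra = divmod(size, n)
          slices, start = [], 0
          for k in range(n):
              stop = start + base + (1 if k < extra else 0)
              slices.append(slice(start, stop))
              start = stop
          return slices


      def get_slicer_tuples(
          tensor_shape: Sequence[int], source_num_slicers: list[int]
      ) -> list[tuple[slice, ...]]:
          if len(source_num_slicers) > len(tensor_shape):
              raise ValueError("more slicer counts than tensor dimensions")
          # zip stops at the shorter input, so unlisted trailing dims stay whole.
          per_dim = [_chunk_slices(s, n) for s, n in zip(tensor_shape, source_num_slicers)]
          # itertools.product varies the last dimension fastest (row-major order).
          return [tuple(combo) for combo in itertools.product(*per_dim)]


      tuples = get_slicer_tuples((4, 6), [2, 3])
      print(len(tuples))  # 6
      print(tuples[4])    # (slice(2, 4, None), slice(2, 4, None))

   Producing the tuples in row-major order keeps them consistent with a
   row-major lookup in :func:`get_slice_from_multi_index`, where index
   ``(1, 1)`` with counts ``[2, 3]`` maps to position ``1 * 3 + 1 = 4``.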


