etha.comm.get_m2m_map#
P2P map for Etha.
Attributes#
Functions#
| get_m2m_map | Get P2P communication map for tensor redistribution. |
| get_shard_shape | Calculate shard shape from device mesh and placements. |
Module Contents#
- etha.comm.get_m2m_map.get_m2m_map(source_mesh: torch.distributed._tensor.DeviceMesh, source_placements: tuple[torch.distributed.tensor.placement_types.Placement, Ellipsis], target_mesh: torch.distributed._tensor.DeviceMesh, target_placements: tuple[torch.distributed.tensor.placement_types.Placement, Ellipsis], group: torch.distributed.ProcessGroup, device: str = 'cpu') → tuple[dict[int, dict[tuple, list[tuple[int, tuple]]]], list[int], list[int], list[tuple[int, str]]]#
Get P2P communication map for tensor redistribution.
Source Partial is supported by substituting Partial→Replicate for the trace, then inserting SHADOW entries for the dropped peers via _expand_partial_shadows. Target Partial is rejected because the decomposition of a logical tensor into Partial contributions is not uniquely defined across an independent process-group boundary.
- Returns:
(m2m_map, source_num_slicers, target_num_slicers, source_partial_reductions). The last is a list of (mesh_dim_idx, reduce_op_str) per Partial dim, and is empty when the source has no Partial.
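The Partial handling described above can be sketched in plain Python. This is an illustration only, not Etha's implementation: `Shard`, `Replicate`, and `Partial` are hypothetical stand-ins for the torch placement types, and `substitute_partial` and `check_target_placements` are made-up helper names. The SHADOW expansion done by `_expand_partial_shadows` is not shown, since the source does not describe it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for torch's Shard placement."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Hypothetical stand-in for torch's Replicate placement."""


@dataclass(frozen=True)
class Partial:
    """Hypothetical stand-in for torch's Partial placement."""
    reduce_op: str = "sum"


def substitute_partial(source_placements):
    """Swap each Partial for Replicate and record what was swapped.

    Returns the placements to trace with, plus a list of
    (mesh_dim_idx, reduce_op_str) pairs, one per Partial dim.
    """
    traced = []
    reductions = []
    for mesh_dim_idx, placement in enumerate(source_placements):
        if isinstance(placement, Partial):
            traced.append(Replicate())
            reductions.append((mesh_dim_idx, placement.reduce_op))
        else:
            traced.append(placement)
    return tuple(traced), reductions


def check_target_placements(target_placements):
    """Reject Partial on the target side, as the text above describes."""
    if any(isinstance(p, Partial) for p in target_placements):
        raise ValueError("target Partial is not supported")


traced, reductions = substitute_partial((Shard(0), Partial("sum")))
```

In this sketch, `reductions` plays the role of `source_partial_reductions`: it is empty whenever the source placements contain no Partial.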
- etha.comm.get_m2m_map.get_shard_shape(device_mesh: tuple[int, Ellipsis], placements: tuple[torch.distributed.tensor.placement_types.Placement, Ellipsis], tensor_ndim: int) → list[int]#
Calculate shard shape from device mesh and placements.
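The source does not spell out what "shard shape" means here. Because the function takes `tensor_ndim` rather than a tensor shape, one plausible reading is the number of pieces each tensor dimension is cut into. The sketch below implements that reading under this assumption; `shard_counts` and the `Shard`/`Replicate` classes are hypothetical stand-ins, not Etha or torch APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    """Hypothetical stand-in for torch's Shard placement."""
    dim: int


@dataclass(frozen=True)
class Replicate:
    """Hypothetical stand-in for torch's Replicate placement."""


def shard_counts(mesh_shape, placements, tensor_ndim):
    """Return, per tensor dimension, how many pieces the placements cut it into."""
    if len(mesh_shape) != len(placements):
        raise ValueError("one placement is required per mesh dimension")
    counts = [1] * tensor_ndim
    for mesh_size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            # Normalize negative dims; sharding the same tensor dim on
            # several mesh dims multiplies the number of pieces.
            counts[placement.dim % tensor_ndim] *= mesh_size
    return counts


row_sharded = shard_counts((2, 4), (Shard(0), Replicate()), tensor_ndim=2)
doubly_sharded = shard_counts((2, 4), (Shard(0), Shard(0)), tensor_ndim=2)
```

On a 2×4 mesh, sharding dim 0 on the first mesh dimension only gives `[2, 1]`, while sharding dim 0 on both mesh dimensions gives `[8, 1]`.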
- etha.comm.get_m2m_map.logger#