etha.comm.get_m2m_map#

P2P map for Etha.

Attributes#

Functions#

get_m2m_map(→ tuple[dict[int, dict[tuple, ...)

Get P2P communication map for tensor redistribution.

get_shard_shape(→ list[int])

Calculate shard shape from device mesh and placements.

Module Contents#

etha.comm.get_m2m_map.get_m2m_map(source_mesh: torch.distributed._tensor.DeviceMesh, source_placements: tuple[torch.distributed.tensor.placement_types.Placement, ...], target_mesh: torch.distributed._tensor.DeviceMesh, target_placements: tuple[torch.distributed.tensor.placement_types.Placement, ...], group: torch.distributed.ProcessGroup, device: str = 'cpu') → tuple[dict[int, dict[tuple, list[tuple[int, tuple]]]], list[int], list[int], list[tuple[int, str]]]#

Get P2P communication map for tensor redistribution.

A Partial placement on the source is supported: each Partial is replaced by Replicate for the trace, and SHADOW entries are then inserted for the dropped peers via _expand_partial_shadows. A Partial placement on the target is rejected, because the decomposition of a logical tensor into Partial contributions is not uniquely defined across an independent process-group boundary.

Returns:

(m2m_map, source_num_slicers, target_num_slicers, source_partial_reductions). The last element, source_partial_reductions, is a list of (mesh_dim_idx, reduce_op_str) pairs, one per Partial mesh dimension; it is empty when the source has no Partial placement.
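
The source-Partial handling and the shape of source_partial_reductions can be illustrated with a minimal sketch. It is not the library's implementation: the placement classes are stand-ins so the code runs without torch, the helper names strip_source_partial and reject_target_partial are hypothetical, and the SHADOW expansion done by _expand_partial_shadows is not modeled.

```python
from dataclasses import dataclass


# Stand-in placement types so the sketch runs without torch. The real
# function takes torch.distributed.tensor placement objects instead.
@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Partial:
    reduce_op: str = "sum"


def strip_source_partial(placements):
    """Hypothetical helper: swap each Partial for Replicate and record it.

    Returns the substituted placements plus a list of
    (mesh_dim_idx, reduce_op_str) pairs, one per Partial mesh dimension.
    """
    traced = []
    reductions = []
    for mesh_dim_idx, placement in enumerate(placements):
        if isinstance(placement, Partial):
            traced.append(Replicate())
            reductions.append((mesh_dim_idx, placement.reduce_op))
        else:
            traced.append(placement)
    return tuple(traced), reductions


def reject_target_partial(placements):
    """Hypothetical check mirroring the documented rule for target placements."""
    if any(isinstance(p, Partial) for p in placements):
        raise ValueError("Partial is not supported as a target placement")


# Source sharded on mesh dim 0 and Partial (sum) on mesh dim 1.
source = (Shard(0), Partial("sum"))
traced, reductions = strip_source_partial(source)
print(traced)
print(reductions)
```

In this sketch the recorded pairs tell the caller which mesh dimensions still need a reduction after the transfer, which is presumably why the real function returns them alongside the map.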

etha.comm.get_m2m_map.get_shard_shape(device_mesh: tuple[int, ...], placements: tuple[torch.distributed.tensor.placement_types.Placement, ...], tensor_ndim: int) → list[int]#

Calculate shard shape from device mesh and placements.
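
The source does not spell out what the returned list contains. One plausible reading, given the signature, is a per-tensor-dimension count of shards, where each Shard placement multiplies the count for its tensor dimension by the size of the corresponding mesh dimension. The sketch below works under that assumption only; the function name shard_counts and the placement stand-ins are hypothetical, and the real function may behave differently.

```python
from dataclasses import dataclass


# Stand-in placement types so the sketch runs without torch.
@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


def shard_counts(mesh_shape, placements, tensor_ndim):
    """Assumed semantics: number of shards along each tensor dimension."""
    counts = [1] * tensor_ndim
    # placements[i] describes how the tensor is laid out over mesh dim i.
    for mesh_dim_size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            counts[placement.dim] *= mesh_dim_size
    return counts


print(shard_counts((2, 4), (Shard(0), Shard(1)), 2))      # [2, 4]
print(shard_counts((2, 4), (Shard(0), Shard(0)), 3))      # [8, 1, 1]
print(shard_counts((2, 4), (Replicate(), Shard(1)), 2))   # [1, 4]
```

Replicate placements leave every count unchanged, so a fully replicated tensor yields a list of ones under this reading.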

etha.comm.get_m2m_map.logger#