etha.comm.get_chunks
Get chunks from m2m map.
Functions
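- calculate_chunk_shape(num_slicers, tensor_shape)
- map_to_chunk_ops(m2m_map, rank, source_num_slicers, target_num_slicers, ...)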
Module Contents
- etha.comm.get_chunks.calculate_chunk_shape(num_slicers: list[int], tensor_shape: tuple[int, ...] | None) -> tuple[int, ...]
- etha.comm.get_chunks.map_to_chunk_ops(m2m_map: dict[int, dict[tuple, list[tuple[int, tuple]]]], rank: int, source_num_slicers: list[int], target_num_slicers: list[int], source_tensor: torch.Tensor | None = None, target_tensor: torch.Tensor | None = None, transfer_dtype: torch.dtype | None = None, source_partial_groups: list[tuple[torch.distributed.ProcessGroup, str]] | None = None) -> list[etha.comm.ir.Chunk]
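
The signatures above do not say how calculate_chunk_shape derives its result. One plausible reading, offered only as an assumption, is that num_slicers holds one slice count per tensor dimension and each dimension is divided evenly by its count. The sketch below illustrates that reading with a hypothetical stand-in named sketch_chunk_shape. It is not the etha implementation, and its handling of tensor_shape=None and of uneven splits (raising ValueError) is a choice made for this example.

```python
from __future__ import annotations


def sketch_chunk_shape(
    num_slicers: list[int], tensor_shape: tuple[int, ...] | None
) -> tuple[int, ...]:
    """Hypothetical stand-in for illustration; NOT the etha implementation."""
    if tensor_shape is None:
        # Assumption: this sketch needs a concrete shape to divide.
        raise ValueError("tensor_shape is required in this sketch")
    if len(num_slicers) != len(tensor_shape):
        # Assumption: one slicer count per tensor dimension.
        raise ValueError("expected one slicer count per tensor dimension")

    chunk = []
    for dim, n in zip(tensor_shape, num_slicers):
        # Assumption: every dimension splits evenly across its slicers.
        if n <= 0 or dim % n != 0:
            raise ValueError(f"dimension {dim} is not evenly divisible by {n}")
        chunk.append(dim // n)
    return tuple(chunk)


# An (8, 16) tensor sliced 2 ways along dim 0 and 4 ways along dim 1.
print(sketch_chunk_shape([2, 4], (8, 16)))  # (4, 4)
```

Under this reading, a slicer count of 1 leaves a dimension unsplit, so sketch_chunk_shape([1, 1], shape) returns shape unchanged.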