etha.tensor_bus.batch_state

Batch state management for TensorBus.

A batch represents one call to register_tensors(). Multiple batches can exist for the same pair, each with independent tensors and handlers.
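The one-call-one-batch relationship can be sketched as follows. This is a minimal illustration, not the TensorBus implementation. `MiniBatch`, `MiniBus`, and the use of `uuid4` hex strings as batch ids are hypothetical stand-ins. Plain strings take the place of `torch.Tensor` objects so that the sketch needs only the standard library.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class MiniBatch:
    """Hypothetical stand-in for BatchState, reduced to a few fields."""
    batch_id: str
    pair_names: list[str]
    pair_tensors: dict[str, list[object]] = field(default_factory=dict)


class MiniBus:
    """Hypothetical registry: every register_tensors() call creates a new batch."""

    def __init__(self) -> None:
        self.batches: dict[str, MiniBatch] = {}

    def register_tensors(self, pair_tensors: dict[str, list[object]]) -> str:
        # Assumption: batch ids are opaque, unique strings.
        batch_id = uuid.uuid4().hex
        self.batches[batch_id] = MiniBatch(
            batch_id=batch_id,
            pair_names=list(pair_tensors),
            # Copy the lists so each batch owns independent tensor references.
            pair_tensors={name: list(ts) for name, ts in pair_tensors.items()},
        )
        return batch_id


# Two registrations for the same pair produce two independent batches.
bus = MiniBus()
first = bus.register_tensors({"pair_a": ["w1", "w2"]})
second = bus.register_tensors({"pair_a": ["w3"]})
```

Because each call returns its own `batch_id`, the two batches for `pair_a` can later be addressed (and torn down) separately.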

Classes

BatchState

State for a single batch of registered tensors.

Module Contents

class etha.tensor_bus.batch_state.BatchState

Bases: msgspec.Struct

State for a single batch of registered tensors.

A batch is created for each register_tensors() call and contains all tensor data and execution plans for that specific registration.

Design note: chunks and buckets are flattened across all pairs in the batch, so execute_bucket_pipeline() can process every pair of the batch in a single pass instead of iterating pair by pair.
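The flattening idea can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual execution plan. `Chunk` and `Bucket` here are hypothetical stand-ins for `etha.comm.ir.Chunk` and `etha.comm.ir.Bucket`, the greedy byte-size packing is an assumed bucketing rule, and `run_single_pass` is a hypothetical placeholder rather than the real `execute_bucket_pipeline()` signature. What the sketch shows is that once per-pair chunk lists are concatenated, one loop over buckets covers every pair, and a single bucket may mix chunks from different pairs.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Chunk:
    """Hypothetical stand-in for etha.comm.ir.Chunk: one named piece of a pair."""
    pair: str
    name: str
    nbytes: int


@dataclass
class Bucket:
    """Hypothetical stand-in for etha.comm.ir.Bucket: chunks transferred together."""
    chunks: list[Chunk] = field(default_factory=list)

    @property
    def nbytes(self) -> int:
        return sum(c.nbytes for c in self.chunks)


def flatten_chunks(per_pair: dict[str, list[Chunk]], pair_names: list[str]) -> list[Chunk]:
    # Concatenate every pair's chunks into one list, in pair_names order.
    return [c for name in pair_names for c in per_pair[name]]


def pack_buckets(chunks: list[Chunk], bucket_size: int) -> list[Bucket]:
    # Assumed greedy rule: close the current bucket when the next chunk would
    # overflow it. A chunk larger than bucket_size ends up alone in its bucket.
    buckets: list[Bucket] = []
    current = Bucket()
    for c in chunks:
        if current.chunks and current.nbytes + c.nbytes > bucket_size:
            buckets.append(current)
            current = Bucket()
        current.chunks.append(c)
    if current.chunks:
        buckets.append(current)
    return buckets


def run_single_pass(buckets: list[Bucket]) -> list[list[str]]:
    # Hypothetical executor: one loop over all buckets, regardless of pair.
    return [[f"{c.pair}/{c.name}" for c in b.chunks] for b in buckets]


per_pair = {
    "a": [Chunk("a", "x", 60), Chunk("a", "y", 30)],
    "b": [Chunk("b", "z", 10), Chunk("b", "w", 50)],
}
buckets = pack_buckets(flatten_chunks(per_pair, ["a", "b"]), bucket_size=100)
print(run_single_pass(buckets))  # → [['a/x', 'a/y', 'b/z'], ['b/w']]
```

Note how the first bucket combines chunks from pairs `a` and `b`. Bucketing each pair separately would have produced three buckets here instead of two.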

batch_group: torch.distributed.ProcessGroup | None = None
batch_id: str
bucket_size: int | None = None
local_group: torch.distributed.ProcessGroup | None = None
local_leader: int | None = None
pair_names: list[str]
pair_target_dtypes: dict[str, list[torch.dtype]]
pair_tensors: dict[str, list[torch.Tensor]]
recv_buckets: list[etha.comm.ir.Bucket] | None = None
recv_chunks: list[etha.comm.ir.Chunk]
send_buckets: list[etha.comm.ir.Bucket] | None = None
send_chunks: list[etha.comm.ir.Chunk]
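The three per-pair fields (`pair_names`, `pair_tensors`, `pair_target_dtypes`) appear to be keyed by the same pair names. The following is a minimal consistency check, written under two assumptions that the reference above does not confirm: the dict keys match `pair_names` exactly, and each pair's target-dtype list is aligned index by index with its tensor list. `validate_batch_fields` is a hypothetical helper, not part of `etha`. Strings stand in for `torch.Tensor` and `torch.dtype` so that the sketch needs only the standard library.

```python
def validate_batch_fields(
    pair_names: list[str],
    pair_tensors: dict[str, list[object]],
    pair_target_dtypes: dict[str, list[object]],
) -> list[str]:
    """Hypothetical check; returns a list of problems (empty means consistent)."""
    problems: list[str] = []
    names = set(pair_names)
    if len(names) != len(pair_names):
        problems.append("duplicate pair name")
    # Assumption: both dicts are keyed by exactly the names in pair_names.
    for label, mapping in (("pair_tensors", pair_tensors),
                           ("pair_target_dtypes", pair_target_dtypes)):
        if set(mapping) != names:
            problems.append(f"{label} keys do not match pair_names")
    # Assumption: one target dtype per tensor, aligned by position.
    for name in pair_names:
        tensors = pair_tensors.get(name, [])
        dtypes = pair_target_dtypes.get(name, [])
        if len(tensors) != len(dtypes):
            problems.append(f"{name}: {len(tensors)} tensors but {len(dtypes)} target dtypes")
    return problems


ok = validate_batch_fields(["a"], {"a": ["t0", "t1"]}, {"a": ["bf16", "bf16"]})
bad = validate_batch_fields(["a"], {"a": ["t0", "t1"]}, {"a": ["bf16"]})
```

Returning a list of messages rather than raising on the first problem makes it easy to report every mismatch in a malformed batch at once.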