diff --git a/README.md b/README.md
index a360040..59aa8a0 100644
--- a/README.md
+++ b/README.md
@@ -14,8 +14,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
 The core weight update logic is in `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
 
-- **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
-- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
+- **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
+- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
 
 ### Optimized Weight Broadcast
 In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -36,16 +36,22 @@ It then executes the transfer, where it controls the inference engine through a
 Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine will fallback to serial execution.
 
+### Optimized P2P Bucket Assignment
+In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
+To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
+The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
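+A toy example of the resulting assignment (the NIC names and topologies below are made up for illustration; `_assign_receiver_ranks` is the new helper in `checkpoint_engine/ps.py` that implements this policy):
+
+```python
+from checkpoint_engine.ps import _assign_receiver_ranks
+
+# Two sender NICs and two receiver NICs, each backing two ranks.
+buckets = [(0, "b0"), (1, "b1"), (2, "b2"), (3, "b3")]  # (owner_rank, payload)
+local_topo = {"mlx5_0": {0, 1}, "mlx5_1": {2, 3}}   # receiver GPU-to-NIC topology
+remote_topo = {"mlx5_0": {0, 1}, "mlx5_1": {2, 3}}  # sender GPU-to-NIC topology
+
+# Each round pairs a distinct sender NIC with a distinct receiver NIC, so no
+# NIC is oversubscribed within a round:
+#   [(0, 0, "b0"), (2, 2, "b2"), (0, 1, "b1"), (2, 3, "b3")]
+print(_assign_receiver_ranks(buckets, local_topo, remote_topo))
+```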
+See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25) for more details.
+
 ## Benchmark
 
 | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
 | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
-| GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
-| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
-| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
-| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
-| DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86 GiB) |
-| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57 GiB) |
+| GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
+| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
+| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
+| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
+| DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
+| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
 
 All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
@@ -53,6 +59,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
 * Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
 * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
 * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
+* We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
 
 ## Installation
 
@@ -68,7 +75,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
 
 ```bash
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
+If the `NCCL_IB_HCA` environment variable is set, checkpoint-engine will use it to auto-select network devices for different ranks. Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If it is not set, checkpoint-engine will read all RDMA devices and try to divide them evenly among the ranks.
 
 ## Getting Started
 
@@ -141,11 +148,11 @@ Run a simple correctness test for checkpoint_engine
 
 ```bash
 torchrun --nproc-per-node 8 tests/test_update.py
 ```
 
+Other unit tests can be run with pytest.
 
 ## Limitations and Future Work
 
 - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
 - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
-- The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
 ## Acknowledgments
diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index 1493a69..b2263fa 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
+    from typing import TypeVar
+
     from typing_extensions import TypedDict
 
     class FileMeta(TypedDict):
@@ -34,6 +36,8 @@
         type: type
         tp_concat_dim: int
 
+    T = TypeVar("T")
+
 
 def _dt_validate(value: Any) -> torch.dtype:
     if isinstance(value, str):
@@ -117,6 +121,7 @@ class MemoryBuffer(BaseModel):
 class MemoryBufferMetaList(BaseModel):
     p2p_store_addr: str | None
     memory_buffer_metas_list: list[MemoryBufferMetas]
+    rdma_device: str
 
 
 class DataToGather(MemoryBufferMetaList):
@@ -552,8 +557,12 @@ def request_inference_to_update(
 
 
 def _gen_h2d_buckets(
-    global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
-) -> list[tuple[int, H2DBucket]]:
+    global_metas: dict[int, MemoryBufferMetaList],
+    bucket_size: int,
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+    ranks: list[int] | None = None,
+) -> list[tuple[int, int, H2DBucket]]:
     buckets: list[tuple[int, H2DBucket]] = []
     for owner_rank, items in global_metas.items():
@@ -576,7 +585,73 @@ def _gen_h2d_buckets(
     assert buckets[-1][1].size > 0, (
         f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
     )
-    return buckets
+    ranks_set = set(ranks) if ranks else set()
+    actual_local_topo = (
+        {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
+    )
+    # if ranks is empty, assign owner_rank as the receiver_rank; this is used for the colocated architecture
+    if not ranks:
+        return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
+    else:
+        return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
+
+
+def _assign_receiver_ranks(
+    buckets: list[tuple[int, "T"]],
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+) -> list[tuple[int, int, "T"]]:
+    """
+    (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
+
+    Assign a receiver rank to each bucket. The GPU-to-RDMA-device topology is
+    taken into account to make full use of the available bandwidth. (The
+    empty-ranks case never reaches this function: _gen_h2d_buckets assigns each
+    owner_rank as its own receiver_rank.)
+    """
+    if not buckets:
+        logger.warning("bucket list is empty, no need to assign receiver ranks")
+        return []
+    rank_to_rdma_device = {
+        rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
+    }
+
+    # group buckets by owner RDMA devices
+    buckets_by_rdma_device = defaultdict(list)
+    for owner_rank, bucket in buckets:
+        owner_rdma_device = rank_to_rdma_device[owner_rank]
+        buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
+
+    buckets_matrix = list(buckets_by_rdma_device.values())
+    assert buckets_matrix, "buckets_matrix should not be empty"
+
+    # Select receiver ranks. We use the minimum rank in each local RDMA device group as the receiver rank.
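+    # Cap the number of receivers at the number of distinct sender NICs, so that
+    # within each assignment round every receiver reads from a different sender NIC.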
+    num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
+    receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
+
+    flattened_buckets = [
+        buckets_matrix[row][col]
+        for col in range(
+            max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
+        )
+        for row in range(len(buckets_matrix))
+        if col < len(buckets_matrix[row])
+    ]
+
+    buckets_with_receiver = []
+    assigned_cnt = 0
+    while assigned_cnt < len(flattened_buckets):
+        occupied_devices = set()
+        for receiver_rank in receiver_list:
+            if assigned_cnt >= len(flattened_buckets):
+                break
+            owner_rank, bucket = flattened_buckets[assigned_cnt]
+            rdma_device = rank_to_rdma_device[owner_rank]
+            if rdma_device in occupied_devices:
+                break
+            buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
+            occupied_devices.add(rdma_device)
+            assigned_cnt += 1
+
+    return buckets_with_receiver
 
 
 def _get_master_port(master_port: int | None = None) -> int:
@@ -587,6 +662,20 @@
     return master_port
 
 
+def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
+    """
+    Map the real ranks (receiver ranks) to the bcast ranks (0 to len(ranks) - 1)
+    that are generated in self.init_process_group_for_ranks.
+    """
+    bcast_rank_map: dict[int, int] = {}
+    if not ranks:
+        bcast_rank_map = {r: r for r in range(world_size)}
+    else:
+        for i, r in enumerate(ranks):
+            bcast_rank_map[r] = i
+    return bcast_rank_map
+
+
 class P2PStore:
     def __init__(self):
         from mooncake.engine import TransferEngine
@@ -594,14 +683,14 @@
         self.rank = int(os.getenv("RANK"))
         gpu_count = torch.cuda.device_count()
         local_rank = self.rank % gpu_count
-        device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
+        self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
         self.ip = _get_ip()
         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
         retry_count = 8
         for i in range(retry_count):
             self.engine = TransferEngine()
-            ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
+            ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
             if ret == 0:
                 break
             # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -615,7 +704,7 @@
         self.port = self.engine.get_rpc_port()
         self.named_tensors: dict[str, torch.Tensor] = {}
         logger.info(
-            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
+            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
        )
 
     @property
@@ -677,6 +766,8 @@ def __init__(
         self._auto_pg = auto_pg
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
+        self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
+        self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
         self._mem_fraction = mem_fraction or 0.9
 
         assert self._rank is not None and self._rank >= 0, self._rank
@@ -705,6 +796,7 @@
         device_index = self._local_rank
         torch.cuda.set_device(device_index)
         self._device_uuid = _get_physical_gpu_id(device_index)
+        self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
@@ -715,6 +807,13 @@ def get_metas(self) -> dict[int, MemoryBufferMetaList]:
 
     def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
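+        # Rebuild the sender-side topology from the incoming metas: each key is
+        # "<rdma_device>@<sender_ip>", each value the set of owner ranks behind that NIC.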
         self._current_global_parameter_metas = metas
+        self._remote_rdma_devices = defaultdict(set)
+        for i, meta in self._current_global_parameter_metas.items():
+            assert meta.rdma_device is not None, "meta.rdma_device should not be None"
+            assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
+            self._remote_rdma_devices[
+                meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
+            ].add(i)
 
     def register_checkpoint(
         self,
@@ -788,11 +887,11 @@
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=_get_ip(),
             device_uuid=self._device_uuid,
+            rdma_device=self._rdma_device or "",
         )
         dist.all_gather_object(metas_lst, metas)
 
-        self._current_global_parameter_metas = {}
         num_parameters = 0
         all_hosts: list[str] = []
         global_device_uuids: list[str] = []
@@ -803,17 +902,24 @@
             if not self._global_device_uuids:
                 global_device_uuids.append(metas_buckets.device_uuid)
             if metas_buckets.memory_buffer_metas_list:
-                # _current_global_parameter_metas value should be MemoryBufferMetaList, but metas_buckets is DataToGather
-                # so we need to convert it to MemoryBufferMetaList
                 self._current_global_parameter_metas[i] = MemoryBufferMetaList(
                     memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
                     p2p_store_addr=metas_buckets.p2p_store_addr,
+                    rdma_device=metas_buckets.rdma_device,
                 )
                 num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
+            self._local_rdma_devices[
+                metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
+                if metas_buckets.p2p_store_addr
+                else metas_buckets.host_ip
+            ].add(i)
         if not self._all_hosts:
             self._all_hosts = all_hosts
         if not self._global_device_uuids:
             self._global_device_uuids = global_device_uuids
+        # By default, we assume the sender and receiver nodes share the same GPU-to-RDMA-device topology.
+        # Calling load_metas overwrites the sender's topology (_remote_rdma_devices).
+        self._remote_rdma_devices = self._local_rdma_devices.copy()
         logger.info(
             f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
         )
@@ -868,6 +974,7 @@ def update(
                 If set, will use p2p to update to the ranks, this is flexible to update
                 to a group of ranks, which is useful in disaggregated architecture.
+                Note: req_func must not be None; it is started in a background thread
+                with the ZMQ socket paths and drives the inference engine's update requests.
         """
+        assert req_func is not None, "req_func is required"
         try:
             # if both ranks is None or [], it will use fully broadcast to update to all ranks
             if not ranks:
@@ -875,17 +982,15 @@
                 self.init_process_group()
                 self._update_per_bucket(checkpoint_name, req_func)
             else:
-                if not self._auto_pg and self._rank not in ranks:
-                    return
                 if self._auto_pg:
                     if dist.is_initialized():
                         dist.destroy_process_group()
                         # HACK: wait 2s to ensure destroy is finished
                         time.sleep(2)
-                    if self._rank not in ranks:
-                        return
                     self.init_process_group_for_ranks(ranks)
-                self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
+                if self._rank not in ranks:
+                    return
+                self._update_per_bucket(checkpoint_name, req_func, ranks)
 
             if self._auto_pg:
                 dist.destroy_process_group()
@@ -1030,71 +1135,6 @@ def init_process_group_for_ranks(
             backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
         )
 
-    def _update_per_bucket_p2p(
-        self,
-        checkpoint_name: str,
-        req_func: Callable[[list[tuple[str, str]]], None],
-        ranks: list[int],
-    ):
-        assert self._p2p_store is not None, "p2p store is not initialized"
-        assert ranks, "ranks should be set"
-        if len(self._current_global_parameter_metas) == 0:
-            raise ValueError("parameter metas is empty")
-        assert dist.is_initialized(), (
-            "process group is not initialized when update model per bucket p2p"
-        )
-
-        need_update = self._rank in ranks
-        logger.info(
-            f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
-            f"gpu_count {self._gpu_count}, world_size {self._world_size}"
-        )
-
-        if not need_update:
-            return
-
-        # first execute a barrier to avoid subsequent cuda oom
-        dist.barrier()
-
-        bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
-        ipc_buffer_name = "__ipc_buffer___"
-        self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
-        logger.info(
-            f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
-        )
-        handle = reduce_tensor(buffer)
-
-        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
-        socket, socket_paths = self._bind_zmq_socket()
-        req_thread = threading.Thread(
-            target=req_func,
-            args=(socket_paths,),
-        )
-        req_thread.start()
-        socket.send_pyobj(handle)
-        for gidx, (owner_rank, bucket) in enumerate(buckets):
-            self._logger_rank0(
-                f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
-            )
-            _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
-            if dist.get_rank() == 0:
-                self._copy_to_buffer(checkpoint_name, bucket, _buffer, owner_rank)
-            # broadcast the collected data to all ranks
-            dist.broadcast(_buffer, src=0)
-            socket.recv()
-            dist.barrier()
-            socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
-
-        socket.recv()
-        socket.send_pyobj(None)
-        socket.recv()
-        req_thread.join()
-        dist.barrier()
-        socket.close()
-        self._p2p_store.unregister_named_tensors([ipc_buffer_name])
-        torch.cuda.empty_cache()
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1124,38 +1164,63 @@
     def _update_per_bucket(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks: list[int] | None = None,
     ):
-        if len(self._current_global_parameter_metas) == 0:
-            raise ValueError("parameter metas is empty")
-
+        assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+        # if ranks is None or [], use full broadcast to update all ranks
+        if not ranks:
+            logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
+        # if ranks is set, use p2p to update only those ranks
+        else:
+            assert self._p2p_store is not None, "p2p store is not initialized"
+            assert ranks, "ranks should be set"
 
-        logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
+            need_update = self._rank in ranks
+            logger.info(
+                f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
+                f"gpu_count {self._gpu_count}, world_size {self._world_size}"
+            )
+
+            if not need_update:
+                return
+            # first execute a barrier to avoid subsequent cuda oom
+            dist.barrier()
 
         bucket_size, disable_h2d_buffer = self._detect_bucket_size()
-        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
+        buckets = _gen_h2d_buckets(
+            self._current_global_parameter_metas,
+            bucket_size,
+            self._local_rdma_devices,
+            self._remote_rdma_devices,
+            ranks,
+        )
         h2d_buffer: torch.Tensor | None = (
             None
             if disable_h2d_buffer
             else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
        )
-
-        owner_rank_buckets: list[H2DBucket] = []
-        for owner_rank, bucket in buckets:
-            if owner_rank != self._rank:
+        # the p2p store needs to register h2d_buffer so that other ranks can read it
+        if ranks:
+            h2d_buffer_name = "__h2d_buffer__"
+            if h2d_buffer is not None and self._p2p_store is not None:
+                self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
+        receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
+        for receiver_rank, owner_rank, bucket in buckets:
+            if receiver_rank != self._rank:
                 continue
-            owner_rank_buckets.append(bucket)
+            receiver_rank_buckets.append((owner_rank, bucket))
 
         buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
         handle = reduce_tensor(buffer)
-        buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
+        buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
         max_len = 0
-        for owner_rank, bucket in buckets:
-            buckets_by_owner_rank[owner_rank].append(bucket)
-            if len(buckets_by_owner_rank[owner_rank]) > max_len:
-                max_len = len(buckets_by_owner_rank[owner_rank])
+        for receiver_rank, _, bucket in buckets:
+            # queue each bucket under its receiver; max_len (the longest queue)
+            # bounds the number of broadcast rounds below
+            buckets_by_receiver_rank[receiver_rank].append(bucket)
+            if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
+                max_len = len(buckets_by_receiver_rank[receiver_rank])
 
         socket, socket_paths = self._bind_zmq_socket()
         req_thread = threading.Thread(
@@ -1166,11 +1231,16 @@
         socket.send_pyobj(handle)
 
         gidx = 0
+        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         for i in range(max_len):
-            if i < len(owner_rank_buckets) and not disable_h2d_buffer:
-                self._copy_to_buffer(checkpoint_name, owner_rank_buckets[i], h2d_buffer)
-
-            for owner_rank, _buckets in buckets_by_owner_rank.items():
+            if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+                self._copy_to_buffer(
+                    checkpoint_name,
+                    receiver_rank_buckets[i][1],
+                    h2d_buffer,
+                    receiver_rank_buckets[i][0] if ranks else None,
+                )
+            for receiver_rank, _buckets in buckets_by_receiver_rank.items():
                 if i >= len(_buckets):
                     continue
                 bucket = _buckets[i]
@@ -1179,18 +1249,19 @@
                     torch.cuda.memory_reserved() / 1024 / 1024,
                )
                 self._logger_rank0(
-                    f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
+                    f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
                     f"Current CUDA allocated {alloc:.2f} MB, "
                     f"reserved {reserved:.2f} MB."
                 )
                 start = gidx % 2 * bucket_size
                 buffer_b: torch.Tensor = buffer[start : start + bucket.size]
-                if owner_rank == self._rank:
+                if receiver_rank == self._rank:
                     if disable_h2d_buffer:
                         self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                     else:
                         buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                dist.broadcast(buffer_b, src=owner_rank)
+                brank = bcast_rank_map[receiver_rank]
+                dist.broadcast(buffer_b, src=brank)
                 socket.recv()
                 dist.barrier()
                 socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
@@ -1202,6 +1273,9 @@
         req_thread.join()
         dist.barrier()
         socket.close()
+        if ranks and h2d_buffer is not None:
+            self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+
         torch.cuda.empty_cache()
diff --git a/tests/test_assign_receiver_ranks.py b/tests/test_assign_receiver_ranks.py
new file mode 100644
index 0000000..b637794
--- /dev/null
+++ b/tests/test_assign_receiver_ranks.py
@@ -0,0 +1,68 @@
+import pytest
+
+from checkpoint_engine.ps import _assign_receiver_ranks
+
+
+@pytest.mark.parametrize(
+    "buckets,local_topo,remote_topo,expected_results",
+    [
+        (
+            [(i % 8, f"bucket{i}") for i in range(80)],
+            {f"rdma{i}": {i} for i in range(8)},
+            {f"rdma{i}": {i} for i in range(8)},
+            [(i % 8, i % 8, f"bucket{i}") for i in range(80)],
+        ),
+        (
+            [(i % 8, f"bucket{i}") for i in range(80)],
+            {f"rdma{i}": {i} for i in range(8)},
+            {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
+            [((i // 2 % 4), i % 8, f"bucket{i}") for i in range(80)],
+        ),
+        (
+            [(i % 8, f"bucket{i}") for i in range(80)],
+            {f"rdma{i}": {i, i + 1, i + 2, i + 3} for i in range(0, 8, 4)},
+            {f"rdma{i}": {i} for i in range(8)},
+            [((i % 2) * 4, i % 8, f"bucket{i}") for i in range(80)],
+        ),
+        (
+            [(i % 8, f"bucket{i}") for i in range(13)],
+            {f"rdma{i}": {i} for i in range(8)},
+            {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
+            [((i // 2 % 4), i % 8, f"bucket{i}") for i in range(13)],
+        ),
+        (
+            [(i % 8, f"bucket{i}") for i in range(13)],
+            {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
+            {f"rdma{i}": {i} for i in range(8)},
+            [((i % 4) * 2, i % 8, f"bucket{i}") for i in range(13)],
+        ),
+        (
+            [(i % 8, f"bucket{i}") for i in range(13)],
+            {f"rdma{i}": {i} for i in range(3)},
+            {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
+            [
+                (0, 0, "bucket0"),
+                (1, 1, "bucket1"),
+                (1, 2, "bucket2"),
+                (2, 3, "bucket3"),
+                (2, 4, "bucket4"),
+                (0, 5, "bucket5"),
+                (0, 6, "bucket6"),
+                (1, 7, "bucket7"),
+                (2, 0, "bucket8"),
+                (2, 1, "bucket9"),
+                (0, 2, "bucket10"),
+                (0, 3, "bucket11"),
+                (1, 4, "bucket12"),
+            ],
+        ),
+    ],
+)
+def test_basic_functionality(
+    buckets: list[tuple[int, str]],
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+    expected_results: list[tuple[int, int, str]],
+):
+    assert len(expected_results) == len(buckets)
+    assert set(expected_results) == set(_assign_receiver_ranks(buckets, local_topo, remote_topo))
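+
+
+def test_empty_buckets():
+    # Editor's suggested addition: the empty-bucket early return is not covered
+    # above; _assign_receiver_ranks logs a warning and returns an empty list.
+    assert _assign_receiver_ranks([], {"rdma0": {0}}, {"rdma0": {0}}) == []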