|
1 | 1 | import os |
2 | 2 | import threading |
3 | | -import time |
4 | 3 | from collections import defaultdict |
5 | 4 | from collections.abc import Callable |
6 | 5 | from datetime import timedelta |
@@ -162,80 +161,6 @@ def _get_master_port(master_port: int | None = None) -> int: |
162 | 161 | return master_port |
163 | 162 |
|
164 | 163 |
|
class P2PStore:
    """Peer-to-peer tensor store backed by a Mooncake TransferEngine.

    Registers local tensors with the transfer engine so remote ranks can
    read them directly over RDMA (or the Ascend direct transport on NPU),
    and performs synchronous batched reads from remote ranks.
    """

    def __init__(self, device_manager: DeviceManager):
        # Imported lazily so mooncake is only required when a P2PStore is
        # actually constructed.
        from mooncake.engine import TransferEngine

        rank_env = os.getenv("RANK")
        if rank_env is None:
            # Fail with a clear message instead of the opaque
            # `TypeError: int() argument must be ...` that int(None) raises.
            raise RuntimeError("P2PStore requires the RANK environment variable to be set")
        self.rank = int(rank_env)
        gpu_count = device_manager.device_module.device_count()
        local_rank = self.rank % gpu_count
        device_type = device_manager.device_type
        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
            # On NPU, an empty device string lets the engine choose its own
            # transport unless the user pinned RDMA devices explicitly.
            self.device = ""
        else:
            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
        self.ip = get_ip()

        # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
        retry_count = 8
        for i in range(retry_count):
            self.engine = TransferEngine()
            ret = self.engine.initialize(
                self.ip,
                "P2PHANDSHAKE",
                "ascend_direct" if device_type == "npu" else "rdma",
                self.device,
            )
            if ret == 0:
                break
            # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
            sleep_ms = random.randint(500, 2000)
            logger.warning(
                f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
            )
            time.sleep(sleep_ms / 1000)
        else:
            # for/else: all retries exhausted without a successful initialize.
            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
        self.port = self.engine.get_rpc_port()
        # Tensors currently registered with the engine, keyed by name.
        self.named_tensors: dict[str, torch.Tensor] = {}
        logger.info(
            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
        )

    @property
    def addr(self) -> str:
        """Return the "ip:port" address of this store's RPC endpoint."""
        return f"{self.ip}:{self.port}"

    def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
        """Register tensors with the engine so remote peers can read them.

        Raises:
            RuntimeError: if the engine reports a non-zero status.
        """
        buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
        capacities = [tensor.nbytes for tensor in named_tensors.values()]
        self.named_tensors.update(named_tensors)
        for i, name in enumerate(named_tensors.keys()):
            logger.info(
                f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
            )
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently ignore registration failures.
        ret = self.engine.batch_register_memory(buffer_addresses, capacities)
        if ret != 0:
            raise RuntimeError(f"[rank{self.rank}] fail to register memory, ret {ret}")

    def unregister_named_tensors(self, names: list[str]) -> int:
        """Unregister previously registered tensors and forget them locally.

        Returns:
            The number of tensors unregistered.

        Raises:
            KeyError: if any name was never registered.
            RuntimeError: if the engine reports a non-zero status.
        """
        buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
        # Explicit check instead of `assert` (stripped under `python -O`).
        ret = self.engine.batch_unregister_memory(buffer_addresses)
        if ret != 0:
            raise RuntimeError(f"[rank{self.rank}] fail to unregister memory, ret {ret}")
        num_unregistered = 0
        for i, name in enumerate(names):
            del self.named_tensors[name]
            logger.info(
                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
            )
            num_unregistered += 1
        return num_unregistered

    def batch_transfer_sync_read(
        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
    ):
        """Synchronously read remote memory regions into local buffers.

        Raises:
            RuntimeError: if the engine reports a non-zero status.
        """
        # Explicit check instead of `assert` (stripped under `python -O`),
        # so a failed transfer can never be silently ignored.
        ret = self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens)
        if ret != 0:
            raise RuntimeError(
                f"[rank{self.rank}] fail to batch transfer sync read from {target_hostname}, ret {ret}"
            )
238 | | - |
239 | 164 | class ParameterServer: |
240 | 165 | shared_memory_pool_name = "__shared_memory_pool__" |
241 | 166 |
|
|
0 commit comments