Skip to content

Commit a9f3642

Browse files
committed
misc
1 parent da111d1 commit a9f3642

File tree

1 file changed

+0
-75
lines changed

1 file changed

+0
-75
lines changed

checkpoint_engine/ps.py

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import threading
3-
import time
43
from collections import defaultdict
54
from collections.abc import Callable
65
from datetime import timedelta
@@ -162,80 +161,6 @@ def _get_master_port(master_port: int | None = None) -> int:
162161
return master_port
163162

164163

165-
class P2PStore:
    """Peer-to-peer tensor store built on the Mooncake ``TransferEngine``.

    Registers local ``torch.Tensor`` buffers with an RDMA-capable transfer
    engine so that remote peers can read them directly, and exposes a
    synchronous batch-read primitive for pulling remote buffers into local
    ones.

    NOTE(review): requires the ``RANK`` environment variable to be set by the
    launcher — ``int(os.getenv("RANK"))`` raises ``TypeError`` if it is
    missing. Confirm the launcher always exports it.
    """

    def __init__(self, device_manager: DeviceManager):
        # Imported lazily so that mooncake is only required when a P2PStore
        # is actually constructed.
        from mooncake.engine import TransferEngine

        # Global rank of this process (from the launcher's env).
        self.rank = int(os.getenv("RANK"))
        gpu_count = device_manager.device_module.device_count()
        # Local device index on this host, derived from the global rank.
        local_rank = self.rank % gpu_count
        device_type = device_manager.device_type
        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
            # On NPU with no explicit RDMA device list, let the engine pick
            # (empty string = no device pinned).
            self.device = ""
        else:
            # Pin this process to one RDMA NIC based on its local rank.
            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
        self.ip = get_ip()

        # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
        retry_count = 8
        for i in range(retry_count):
            self.engine = TransferEngine()
            # "P2PHANDSHAKE" selects handshake-based peer discovery; the
            # transport protocol depends on the accelerator type.
            ret = self.engine.initialize(
                self.ip,
                "P2PHANDSHAKE",
                "ascend_direct" if device_type == "npu" else "rdma",
                self.device,
            )
            if ret == 0:
                break
            # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
            sleep_ms = random.randint(500, 2000)
            logger.warning(
                f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
            )
            time.sleep(sleep_ms / 1000)
        else:
            # for/else: reached only when every retry failed to initialize.
            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
        # RPC port chosen by the engine; together with self.ip it forms addr.
        self.port = self.engine.get_rpc_port()
        # Tensors currently registered with the engine, keyed by name.
        self.named_tensors: dict[str, torch.Tensor] = {}
        logger.info(
            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
        )

    @property
    def addr(self) -> str:
        """Return this store's ``"ip:port"`` endpoint string."""
        return f"{self.ip}:{self.port}"

    def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
        """Register tensors' device memory with the transfer engine.

        Each tensor's raw pointer and byte size are handed to the engine so
        peers can read it remotely; the tensors are also retained in
        ``self.named_tensors`` (callers must keep them alive while
        registered).
        """
        buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
        capacities = [tensor.nbytes for tensor in named_tensors.values()]
        self.named_tensors.update(named_tensors)
        for i, name in enumerate(named_tensors.keys()):
            logger.info(
                f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
            )
        # NOTE(review): asserts are stripped under ``python -O``; a nonzero
        # return would then go unnoticed — confirm this is acceptable.
        assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0

    def unregister_named_tensors(self, names: list[str]) -> int:
        """Unregister previously registered tensors by name.

        Raises ``KeyError`` if a name was never registered. Returns the
        number of tensors unregistered (always ``len(names)`` on success).
        """
        buffer_addresses = [self.names_tensors[name].data_ptr() for name in names] if False else [self.named_tensors[name].data_ptr() for name in names]
        assert self.engine.batch_unregister_memory(buffer_addresses) == 0
        num_unregistered = 0
        for i, name in enumerate(names):
            del self.named_tensors[name]
            logger.info(
                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
            )
            num_unregistered += 1
        return num_unregistered

    def batch_transfer_sync_read(
        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
    ):
        """Synchronously read remote buffers into local ones.

        ``buf_ptrs[i]`` receives ``lens[i]`` bytes from ``remote_ptrs[i]`` on
        the peer at ``target_hostname``; blocks until the whole batch
        completes (engine returns 0) or fails the assert.
        """
        assert (
            self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
        )
237-
238-
239164
class ParameterServer:
240165
shared_memory_pool_name = "__shared_memory_pool__"
241166

0 commit comments

Comments (0)