Skip to content

Commit 96483de

Browse files
author
cuixiaojin
committed
Make P2P store compatible with the PS_P2P_STORE_RDMA_DEVICES environment variable
1 parent d564f1a commit 96483de

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

checkpoint_engine/ps.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -679,19 +679,23 @@ def __init__(self, device_manager: DeviceManager):
679679
self.rank = int(os.getenv("RANK"))
680680
gpu_count = device_manager.device_module.device_count()
681681
local_rank = self.rank % gpu_count
682-
if device_manager.device_type == "npu":
682+
device_type = device_manager.device_type
683+
if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
683684
self.device = ""
684-
protocol = "ascend_direct"
685685
else:
686686
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
687-
protocol = "rdma"
688687
self.ip = get_ip()
689688

690689
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
691690
retry_count = 8
692691
for i in range(retry_count):
693692
self.engine = TransferEngine()
694-
ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", protocol, self.device)
693+
ret = self.engine.initialize(
694+
self.ip,
695+
"P2PHANDSHAKE",
696+
"ascend_direct" if device_type == "npu" else "rdma",
697+
self.device
698+
)
695699
if ret == 0:
696700
break
697701
# sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time

0 commit comments

Comments
 (0)