Skip to content

Commit c9d3d42

Browse files
author
cuixiaojin
committed
[modify] address code view feedback
1 parent bed9862 commit c9d3d42

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

checkpoint_engine/ps.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,39 @@
2424
from torch.multiprocessing.reductions import reduce_tensor
2525

2626

27-
def is_torch_npu_available() -> bool:
28-
try:
29-
if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
30-
return torch.npu.is_available()
31-
else:
32-
return False
33-
except ImportError:
34-
return False
35-
3627
class DeviceManager:
3728
def __init__(self):
3829
self.device_type = self._detect_device_type()
3930
self._setup_device_module()
4031

32+
def _is_torch_npu_available(self) -> bool:
33+
try:
34+
if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
35+
return torch.npu.is_available()
36+
else:
37+
return False
38+
except ImportError:
39+
return False
40+
4141
def _detect_device_type(self) -> str:
42-
if is_torch_npu_available():
42+
if self._is_torch_npu_available():
4343
return "npu"
4444
elif torch.cuda.is_available():
4545
return "cuda"
46+
else:
47+
raise TypeError("The current device type is not supported")
4648

4749
def _setup_device_module(self):
4850
if self.device_type == "npu":
4951
import torch_npu
5052
self.device_module = torch_npu.npu
5153
elif self.device_type == "cuda":
5254
self.device_module = torch.cuda
55+
else:
56+
raise TypeError("The current device type is not supported")
5357

54-
def get_backend(self) -> str:
58+
@property
59+
def backend(self) -> str:
5560
if self.device_type == "npu":
5661
return "hccl"
5762
elif self.device_type == "cuda":
@@ -283,10 +288,10 @@ def _concat_tp_weights(
283288
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
284289

285290

286-
def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
291+
def _get_physical_gpu_id(device_manager: DeviceManager, rank_id: int, device_index: int | None = None) -> str:
287292
try:
288293
if device_manager.device_type == "npu":
289-
return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{device_index}"
294+
return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{rank_id}"
290295
else:
291296
return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
292297
except AssertionError as e:
@@ -625,7 +630,7 @@ def _get_master_port(master_port: int | None = None) -> int:
625630

626631

627632
class P2PStore:
628-
def __init__(self, device_manager: DeviceManager):
633+
def __init__(self, device_manager : DeviceManager):
629634
from mooncake.engine import TransferEngine
630635

631636
self.rank = int(os.getenv("RANK"))
@@ -742,7 +747,7 @@ def __init__(
742747

743748
device_index = self._local_rank
744749
self.device_manager.device_module.set_device(device_index)
745-
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
750+
self._device_uuid = _get_physical_gpu_id(self.device_manager, self._rank, device_index)
746751

747752
def _logger_rank0(self, msg: str):
748753
if self._local_rank == 0:
@@ -880,7 +885,7 @@ def init_process_group(
880885
is_master=self._rank == 0,
881886
)
882887
dist.init_process_group(
883-
backend=self.device_manager.get_backend(),
888+
backend=self.device_manager.backend,
884889
world_size=self._world_size,
885890
rank=self._rank,
886891
timeout=timeout,

checkpoint_engine/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66
import zmq
7-
from checkpoint_engine.ps import DeviceManager
7+
from .ps import DeviceManager
88

99

1010
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
@@ -106,7 +106,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
106106
device_uuid = current_platform.get_device_uuid(self.device.index)
107107
elif current_platform.device_type == "npu":
108108
device_uuid = (
109-
f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.device.index}"
109+
f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.rank}"
110110
)
111111
update_weights_from_ipc(
112112
self._zmq_ctx,

0 commit comments

Comments
 (0)