Skip to content

Commit c7c2d8d

Browse files
author
cuixiaojin
committed
[modify] generate UUID from `npu-smi info` output
1 parent c9d3d42 commit c7c2d8d

File tree

3 files changed

+87
-49
lines changed

3 files changed

+87
-49
lines changed

checkpoint_engine/device_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
import re
3+
import socket
4+
import subprocess
5+
6+
import torch
7+
8+
9+
def npu_generate_uuid() -> str:
    """Return a host-unique identifier for the NPU chip running this process.

    Scans up to ``npu_num`` devices with ``npu-smi info -t proc-mem`` until the
    current PID appears in a device's process list, then combines the host IP
    with a flat chip index (``npu_id * chip_count + chip_id``).

    Returns:
        A string of the form ``"<server_ip>-<global_chip_index>"``.

    Raises:
        ValueError: if the current process is not found on any NPU device,
            if ``npu-smi`` exits with an error, or if its output cannot be
            parsed.
    """
    str_pid = str(os.getpid())
    # NOTE(review): assumes at most 8 NPUs per server -- confirm for other
    # hardware configurations.
    npu_num = 8
    try:
        for npu_id in range(npu_num):
            cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
            str_result = str(result.stdout)
            # NOTE(review): plain substring search can match a longer PID
            # (e.g. "123" inside "1234") -- consider a word-boundary regex.
            if str_pid not in str_result:
                continue
            # In A3 server, one NPU has two chips.
            match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
            # Take the Chip ID from the text *after* the PID so it belongs to
            # this process's entry rather than an earlier device/process row.
            search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
            match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
            if match_chip_count is None or match_chip_id is None:
                # Guard against unexpected npu-smi output instead of crashing
                # with AttributeError on `.group()`.
                raise ValueError("Failed to parse npu-smi output for Chip Count/Chip ID")
            chip_count = int(match_chip_count.group(1))
            chip_id = int(match_chip_id.group(1))
            server_ip = socket.gethostbyname(socket.gethostname())
            return f"{server_ip}-{npu_id * chip_count + chip_id}"
        # BUG FIX: the original constructed ValueError without raising it, so
        # the function silently returned None despite the `-> str` annotation.
        raise ValueError("The current process is not running on the npu device")
    except subprocess.CalledProcessError as e:
        # Same fix here: raise instead of discarding the exception object,
        # chaining the npu-smi failure as the cause.
        raise ValueError("The current process is not running on the npu device") from e
29+
30+
31+
class DeviceManager:
    """Detect the available accelerator (Ascend NPU or CUDA GPU) and expose
    the matching torch device module and distributed backend name.

    Raises TypeError at construction time when neither accelerator is usable.
    """

    def __init__(self):
        # Detection happens once, at construction time.
        self.device_type = self._detect_device_type()
        self._setup_device_module()

    def _is_torch_npu_available(self) -> bool:
        """Return True when torch exposes a working ``npu`` accelerator."""
        try:
            npu_module = getattr(torch, "npu", None)
            if npu_module is None:
                return False
            checker = getattr(npu_module, "is_available", None)
            return checker() if callable(checker) else False
        except ImportError:
            return False

    def _detect_device_type(self) -> str:
        """Return ``"npu"`` or ``"cuda"``; raise TypeError when neither is present."""
        if self._is_torch_npu_available():
            return "npu"
        if torch.cuda.is_available():
            return "cuda"
        raise TypeError("The current device type is not supported")

    def _setup_device_module(self):
        """Bind ``self.device_module`` to the torch module for the detected device."""
        if self.device_type == "npu":
            import torch_npu

            self.device_module = torch_npu.npu
        elif self.device_type == "cuda":
            self.device_module = torch.cuda
        else:
            raise TypeError("The current device type is not supported")

    @property
    def backend(self) -> str:
        """Distributed backend name: ``"hccl"`` for NPU, ``"nccl"`` for CUDA."""
        return {"npu": "hccl", "cuda": "nccl"}.get(self.device_type)

checkpoint_engine/ps.py

Lines changed: 15 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,44 +23,7 @@
2323
from safetensors.torch import safe_open
2424
from torch.multiprocessing.reductions import reduce_tensor
2525

26-
27-
class DeviceManager:
28-
def __init__(self):
29-
self.device_type = self._detect_device_type()
30-
self._setup_device_module()
31-
32-
def _is_torch_npu_available(self) -> bool:
33-
try:
34-
if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
35-
return torch.npu.is_available()
36-
else:
37-
return False
38-
except ImportError:
39-
return False
40-
41-
def _detect_device_type(self) -> str:
42-
if self._is_torch_npu_available():
43-
return "npu"
44-
elif torch.cuda.is_available():
45-
return "cuda"
46-
else:
47-
raise TypeError("The current device type is not supported")
48-
49-
def _setup_device_module(self):
50-
if self.device_type == "npu":
51-
import torch_npu
52-
self.device_module = torch_npu.npu
53-
elif self.device_type == "cuda":
54-
self.device_module = torch.cuda
55-
else:
56-
raise TypeError("The current device type is not supported")
57-
58-
@property
59-
def backend(self) -> str:
60-
if self.device_type == "npu":
61-
return "hccl"
62-
elif self.device_type == "cuda":
63-
return "nccl"
26+
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
6427

6528

6629
if TYPE_CHECKING:
@@ -288,10 +251,11 @@ def _concat_tp_weights(
288251
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
289252

290253

291-
def _get_physical_gpu_id(device_manager: DeviceManager, rank_id: int, device_index: int | None = None) -> str:
254+
def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
292255
try:
293256
if device_manager.device_type == "npu":
294-
return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{rank_id}"
257+
serial_number = npu_generate_uuid()
258+
return f"NPU-{serial_number}"
295259
else:
296260
return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
297261
except AssertionError as e:
@@ -630,7 +594,7 @@ def _get_master_port(master_port: int | None = None) -> int:
630594

631595

632596
class P2PStore:
633-
def __init__(self, device_manager : DeviceManager):
597+
def __init__(self, device_manager: DeviceManager):
634598
from mooncake.engine import TransferEngine
635599

636600
self.rank = int(os.getenv("RANK"))
@@ -747,7 +711,7 @@ def __init__(
747711

748712
device_index = self._local_rank
749713
self.device_manager.device_module.set_device(device_index)
750-
self._device_uuid = _get_physical_gpu_id(self.device_manager, self._rank, device_index)
714+
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
751715

752716
def _logger_rank0(self, msg: str):
753717
if self._local_rank == 0:
@@ -961,7 +925,9 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
961925
tensor = torch.tensor(
962926
[
963927
# proportion of current cuda free memory bytes
964-
int(float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction),
928+
int(
929+
float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
930+
),
965931
# we use negative value to reuse allreduce min operation
966932
# for getting the max value of zmq_addr_counter in all ranks
967933
-self._zmq_addr_counter,
@@ -1100,7 +1066,9 @@ def _update_per_bucket_p2p(
11001066
dist.barrier()
11011067

11021068
bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
1103-
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type)
1069+
buffer = torch.empty(
1070+
bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
1071+
)
11041072
ipc_buffer_name = "__ipc_buffer___"
11051073
self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
11061074
logger.info(
@@ -1190,7 +1158,9 @@ def _update_per_bucket(
11901158
continue
11911159
owner_rank_buckets.append(bucket)
11921160

1193-
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type)
1161+
buffer = torch.empty(
1162+
bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
1163+
)
11941164
handle = reduce_tensor(buffer)
11951165

11961166
buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)

checkpoint_engine/worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import torch
66
import zmq
7-
from .ps import DeviceManager
7+
8+
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
89

910

1011
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
@@ -105,9 +106,8 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
105106
if current_platform.device_type == "gpu":
106107
device_uuid = current_platform.get_device_uuid(self.device.index)
107108
elif current_platform.device_type == "npu":
108-
device_uuid = (
109-
f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.rank}"
110-
)
109+
serial_number = npu_generate_uuid()
110+
device_uuid = f"NPU-{serial_number}"
111111
update_weights_from_ipc(
112112
self._zmq_ctx,
113113
zmq_handles[device_uuid],

0 commit comments

Comments
 (0)