|
23 | 23 | from safetensors.torch import safe_open |
24 | 24 | from torch.multiprocessing.reductions import reduce_tensor |
25 | 25 |
|
26 | | - |
27 | | -class DeviceManager: |
28 | | - def __init__(self): |
29 | | - self.device_type = self._detect_device_type() |
30 | | - self._setup_device_module() |
31 | | - |
32 | | - def _is_torch_npu_available(self) -> bool: |
33 | | - try: |
34 | | - if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): |
35 | | - return torch.npu.is_available() |
36 | | - else: |
37 | | - return False |
38 | | - except ImportError: |
39 | | - return False |
40 | | - |
41 | | - def _detect_device_type(self) -> str: |
42 | | - if self._is_torch_npu_available(): |
43 | | - return "npu" |
44 | | - elif torch.cuda.is_available(): |
45 | | - return "cuda" |
46 | | - else: |
47 | | - raise TypeError("The current device type is not supported") |
48 | | - |
49 | | - def _setup_device_module(self): |
50 | | - if self.device_type == "npu": |
51 | | - import torch_npu |
52 | | - self.device_module = torch_npu.npu |
53 | | - elif self.device_type == "cuda": |
54 | | - self.device_module = torch.cuda |
55 | | - else: |
56 | | - raise TypeError("The current device type is not supported") |
57 | | - |
58 | | - @property |
59 | | - def backend(self) -> str: |
60 | | - if self.device_type == "npu": |
61 | | - return "hccl" |
62 | | - elif self.device_type == "cuda": |
63 | | - return "nccl" |
| 26 | +from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid |
64 | 27 |
|
65 | 28 |
|
66 | 29 | if TYPE_CHECKING: |
@@ -288,10 +251,11 @@ def _concat_tp_weights( |
288 | 251 | return torch.cat([w for w in tp_weights], dim=tp_concat_dim) |
289 | 252 |
|
290 | 253 |
|
291 | | -def _get_physical_gpu_id(device_manager: DeviceManager, rank_id: int, device_index: int | None = None) -> str: |
| 254 | +def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: |
292 | 255 | try: |
293 | 256 | if device_manager.device_type == "npu": |
294 | | - return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{rank_id}" |
| 257 | + serial_number = npu_generate_uuid() |
| 258 | + return f"NPU-{serial_number}" |
295 | 259 | else: |
296 | 260 | return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}" |
297 | 261 | except AssertionError as e: |
@@ -630,7 +594,7 @@ def _get_master_port(master_port: int | None = None) -> int: |
630 | 594 |
|
631 | 595 |
|
632 | 596 | class P2PStore: |
633 | | - def __init__(self, device_manager : DeviceManager): |
| 597 | + def __init__(self, device_manager: DeviceManager): |
634 | 598 | from mooncake.engine import TransferEngine |
635 | 599 |
|
636 | 600 | self.rank = int(os.getenv("RANK")) |
@@ -747,7 +711,7 @@ def __init__( |
747 | 711 |
|
748 | 712 | device_index = self._local_rank |
749 | 713 | self.device_manager.device_module.set_device(device_index) |
750 | | - self._device_uuid = _get_physical_gpu_id(self.device_manager, self._rank, device_index) |
| 714 | + self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index) |
751 | 715 |
|
752 | 716 | def _logger_rank0(self, msg: str): |
753 | 717 | if self._local_rank == 0: |
@@ -961,7 +925,9 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, |
961 | 925 | tensor = torch.tensor( |
962 | 926 | [ |
963 | 927 | # proportion of current cuda free memory bytes |
964 | | - int(float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction), |
| 928 | + int( |
| 929 | + float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction |
| 930 | + ), |
965 | 931 | # we use negative value to reuse allreduce min operation |
966 | 932 | # for getting the max value of zmq_addr_counter in all ranks |
967 | 933 | -self._zmq_addr_counter, |
@@ -1100,7 +1066,9 @@ def _update_per_bucket_p2p( |
1100 | 1066 | dist.barrier() |
1101 | 1067 |
|
1102 | 1068 | bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True) |
1103 | | - buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) |
| 1069 | + buffer = torch.empty( |
| 1070 | + bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type |
| 1071 | + ) |
1104 | 1072 | ipc_buffer_name = "__ipc_buffer___" |
1105 | 1073 | self._p2p_store.register_named_tensors({ipc_buffer_name: buffer}) |
1106 | 1074 | logger.info( |
@@ -1190,7 +1158,9 @@ def _update_per_bucket( |
1190 | 1158 | continue |
1191 | 1159 | owner_rank_buckets.append(bucket) |
1192 | 1160 |
|
1193 | | - buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) |
| 1161 | + buffer = torch.empty( |
| 1162 | + bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type |
| 1163 | + ) |
1194 | 1164 | handle = reduce_tensor(buffer) |
1195 | 1165 |
|
1196 | 1166 | buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list) |
|
0 commit comments