Skip to content

Commit 78f325c

Browse files
kip-cxj and cuixiaojin authored
[Hardware] broadcast support for Huawei Ascend NPU (#39)
--------- Signed-off-by: kip-cxj <939544916@qq.com> Co-authored-by: cuixiaojin <c00855547@china.huawei.com>
1 parent a291782 commit 78f325c

File tree

3 files changed

+130
-44
lines changed

3 files changed

+130
-44
lines changed

checkpoint_engine/device_utils.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import os
2+
import re
3+
import socket
4+
import subprocess
5+
from functools import lru_cache
6+
7+
import torch
8+
from loguru import logger
9+
10+
11+
@lru_cache(maxsize=1)
def get_ip() -> str:
    """Return this host's primary IPv4 address (cached after the first call).

    Connects a UDP socket toward a public address to discover which local
    interface the OS would route through — no packet is actually sent.
    Falls back to resolving the hostname when the probe fails.
    """
    try:
        # try to get ip from network interface
        probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
        finally:
            probe.close()
    except Exception as e:  # noqa: BLE001
        # fallback to get ip from hostname
        logger.warning(
            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
        )
        return socket.gethostbyname(socket.gethostname())
24+
25+
26+
def npu_generate_uuid() -> str:
    """Build a stable identifier for the Ascend NPU chip this process runs on.

    Scans ``npu-smi`` per-device process/memory reports for the current PID
    and derives a host-global chip index; the result is ``"<host-ip>-<index>"``.

    Returns:
        The identifier string for the chip hosting this process.

    Raises:
        ValueError: if the process is not found on any NPU device, or
            ``npu-smi`` is missing or fails.
    """
    str_pid = str(os.getpid())
    npu_num = 8  # maximum number of NPU cards probed per host
    try:
        for npu_id in range(npu_num):
            cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
            str_result = str(result.stdout)
            if str_pid not in str_result:
                continue
            # In A3 server, one NPU has two chips.
            match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
            # Only inspect output after the PID so we pick this process's chip.
            search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
            match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
            if match_chip_count is None or match_chip_id is None:
                # PID matched but the report layout is unexpected; keep probing
                # instead of crashing on a None match object.
                continue
            chip_count = int(match_chip_count.group(1))
            chip_id = int(match_chip_id.group(1))
            return f"{get_ip()}-{npu_id * chip_count + chip_id}"
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # npu-smi failed, or is not installed on this machine at all
        # (a missing executable raises FileNotFoundError, not CalledProcessError).
        raise ValueError("The current process is not running on the npu device") from e
    # Fix: the original constructed this ValueError without `raise`, silently
    # returning None in violation of the declared `-> str` return type.
    raise ValueError("The current process is not running on the npu device")
45+
46+
47+
class DeviceManager:
    """Detects the accelerator type (Huawei Ascend NPU or NVIDIA CUDA) and
    exposes the matching torch device module and distributed backend name.
    """

    def __init__(self):
        # Device type is fixed at construction time: "npu" or "cuda".
        self.device_type = self._detect_device_type()
        self._setup_device_module()

    def _is_torch_npu_available(self) -> bool:
        """Return True when torch exposes a usable ``npu`` module.

        ``torch.npu`` only exists after ``torch_npu`` has been imported, so
        this is also an implicit check that the Ascend stack is installed.
        """
        try:
            if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
                return torch.npu.is_available()
            else:
                return False
        except ImportError:
            return False

    def _detect_device_type(self) -> str:
        """Return "npu" or "cuda"; raise TypeError when neither is available."""
        if self._is_torch_npu_available():
            return "npu"
        elif torch.cuda.is_available():
            return "cuda"
        else:
            raise TypeError("The current device type is not supported")

    def _setup_device_module(self):
        """Bind ``self.device_module`` to the torch module for the device."""
        if self.device_type == "npu":
            import torch_npu

            self.device_module = torch_npu.npu
        elif self.device_type == "cuda":
            self.device_module = torch.cuda
        else:
            raise TypeError("The current device type is not supported")

    @property
    def backend(self) -> str:
        """Distributed backend name for the device ("hccl" or "nccl")."""
        if self.device_type == "npu":
            return "hccl"
        elif self.device_type == "cuda":
            return "nccl"
        # Fix: the original fell off the end here and implicitly returned None
        # for an unknown device type; fail loudly like the other methods do.
        raise TypeError("The current device type is not supported")

checkpoint_engine/ps.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
import os
55
import pickle
66
import random
7-
import socket
87
import threading
98
import time
109
from collections import defaultdict
1110
from collections.abc import Callable
1211
from datetime import timedelta
13-
from functools import lru_cache
1412
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
1513

1614
import httpx
@@ -23,6 +21,8 @@
2321
from safetensors.torch import safe_open
2422
from torch.multiprocessing.reductions import reduce_tensor
2523

24+
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
25+
2626

2727
if TYPE_CHECKING:
2828
from typing import TypeVar
@@ -254,28 +254,16 @@ def _concat_tp_weights(
254254
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
255255

256256

257-
def _get_physical_gpu_id(device_index: int | None = None) -> str:
257+
def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
258258
try:
259-
return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
259+
if device_manager.device_type == "npu":
260+
return f"NPU-{npu_generate_uuid()}"
261+
else:
262+
return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
260263
except AssertionError as e:
261264
raise ValueError(f"fail to get physical gpu id {device_index}") from e
262265

263266

264-
@lru_cache(maxsize=1)
265-
def _get_ip() -> str:
266-
try:
267-
# try to get ip from network interface
268-
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
269-
s.connect(("8.8.8.8", 80))
270-
return s.getsockname()[0]
271-
except Exception as e: # noqa: BLE001
272-
# fallback to get ip from hostname
273-
logger.warning(
274-
f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
275-
)
276-
return socket.gethostbyname(socket.gethostname())
277-
278-
279267
def _ibv_get_device_list() -> list[str]:
280268
lib = ctypes.CDLL("libibverbs.so.1")
281269
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
@@ -677,14 +665,14 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, i
677665

678666

679667
class P2PStore:
680-
def __init__(self):
668+
def __init__(self, device_manager: DeviceManager):
681669
from mooncake.engine import TransferEngine
682670

683671
self.rank = int(os.getenv("RANK"))
684-
gpu_count = torch.cuda.device_count()
672+
gpu_count = device_manager.device_module.device_count()
685673
local_rank = self.rank % gpu_count
686674
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
687-
self.ip = _get_ip()
675+
self.ip = get_ip()
688676

689677
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
690678
retry_count = 8
@@ -761,7 +749,8 @@ def __init__(
761749
"""
762750
self._rank = rank or int(os.environ.get("RANK", None))
763751
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
764-
self._gpu_count = gpu_count or torch.cuda.device_count()
752+
self.device_manager = DeviceManager()
753+
self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
765754
self._local_rank = self._rank % self._gpu_count
766755
self._auto_pg = auto_pg
767756
self._all_hosts = []
@@ -775,7 +764,7 @@ def __init__(
775764
assert (
776765
self._gpu_count is not None
777766
and self._gpu_count > 0
778-
and self._gpu_count <= torch.cuda.device_count()
767+
and self._gpu_count <= self.device_manager.device_module.device_count()
779768
), self._gpu_count
780769
assert (
781770
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -788,14 +777,14 @@ def __init__(
788777
# dict key is owner_rank, value is a bucket metas list in owner_rank
789778
self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
790779
try:
791-
self._p2p_store = P2PStore()
780+
self._p2p_store = P2PStore(self.device_manager)
792781
except ImportError as e:
793782
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
794783
self._p2p_store = None
795784

796785
device_index = self._local_rank
797-
torch.cuda.set_device(device_index)
798-
self._device_uuid = _get_physical_gpu_id(device_index)
786+
self.device_manager.device_module.set_device(device_index)
787+
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
799788
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
800789

801790
def _logger_rank0(self, msg: str):
@@ -885,7 +874,7 @@ def gather_metas(self, checkpoint_name: str):
885874
for x in self._memory_pool.get(checkpoint_name, [])
886875
],
887876
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
888-
host_ip=_get_ip(),
877+
host_ip=get_ip(),
889878
device_uuid=self._device_uuid,
890879
rdma_device=self._rdma_device or "",
891880
)
@@ -948,7 +937,7 @@ def init_process_group(
948937
is_master=self._rank == 0,
949938
)
950939
dist.init_process_group(
951-
backend="nccl",
940+
backend=self.device_manager.backend,
952941
world_size=self._world_size,
953942
rank=self._rank,
954943
timeout=timeout,
@@ -994,12 +983,12 @@ def update(
994983
if self._auto_pg:
995984
dist.destroy_process_group()
996985

997-
torch.cuda.empty_cache()
986+
self.device_manager.device_module.empty_cache()
998987

999988
logger.info(
1000989
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
1001-
f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
1002-
f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
990+
f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
991+
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
1003992
)
1004993
except Exception as e:
1005994
logger.exception(
@@ -1023,13 +1012,15 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
10231012
tensor = torch.tensor(
10241013
[
10251014
# proportion of current cuda free memory bytes
1026-
int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
1015+
int(
1016+
float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
1017+
),
10271018
# we use negative value to reuse allreduce min operation
10281019
# for getting the max value of zmq_addr_counter in all ranks
10291020
-self._zmq_addr_counter,
10301021
],
10311022
dtype=torch.int64,
1032-
device="cuda",
1023+
device=self.device_manager.device_type,
10331024
)
10341025
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
10351026
tensor = tensor.cpu()
@@ -1092,7 +1083,7 @@ def _copy_to_buffer(
10921083
assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
10931084
if owner_rank is not None:
10941085
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
1095-
torch.cuda.synchronize()
1086+
self.device_manager.device_module.synchronize()
10961087

10971088
def init_process_group_for_ranks(
10981089
self,
@@ -1199,7 +1190,7 @@ def _update_per_bucket(
11991190
h2d_buffer: torch.Tensor | None = (
12001191
None
12011192
if disable_h2d_buffer
1202-
else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1193+
else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
12031194
)
12041195
# p2p store need to register h2d_buffer to let other ranks read
12051196
if ranks:
@@ -1212,7 +1203,9 @@ def _update_per_bucket(
12121203
continue
12131204
receiver_rank_buckets.append((owner_rank, bucket))
12141205

1215-
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
1206+
buffer = torch.empty(
1207+
bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
1208+
)
12161209
handle = reduce_tensor(buffer)
12171210

12181211
buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -1245,8 +1238,8 @@ def _update_per_bucket(
12451238
continue
12461239
bucket = _buckets[i]
12471240
alloc, reserved = (
1248-
torch.cuda.memory_allocated() / 1024 / 1024,
1249-
torch.cuda.memory_reserved() / 1024 / 1024,
1241+
self.device_manager.device_module.memory_allocated() / 1024 / 1024,
1242+
self.device_manager.device_module.memory_reserved() / 1024 / 1024,
12501243
)
12511244
self._logger_rank0(
12521245
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
@@ -1276,7 +1269,7 @@ def _update_per_bucket(
12761269
if ranks and h2d_buffer is not None:
12771270
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
12781271

1279-
torch.cuda.empty_cache()
1272+
self.device_manager.device_module.empty_cache()
12801273

12811274

12821275
def _init_api(ps: ParameterServer) -> Any:

checkpoint_engine/worker.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torch
66
import zmq
77

8+
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
9+
810

911
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
1012
func, args = handle
@@ -53,13 +55,14 @@ def update_weights_from_ipc(
5355
socket = zmq_ctx.socket(zmq.REP)
5456
socket.connect(zmq_handle)
5557
buffer: torch.Tensor | None = None
58+
device_mananger = DeviceManager()
5659
while True:
5760
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
5861
if payload is None:
5962
# means the update is done
6063
if post_hook is not None:
6164
post_hook()
62-
torch.cuda.synchronize()
65+
device_mananger.device_module.synchronize()
6366
socket.send(b"")
6467
break
6568
if isinstance(payload, tuple):
@@ -71,13 +74,13 @@ def update_weights_from_ipc(
7174
continue
7275
assert isinstance(payload, list)
7376
run(_extract_weights(payload, buffer))
74-
torch.cuda.synchronize()
77+
device_mananger.device_module.synchronize()
7578
socket.send(b"")
7679

7780
socket.close()
7881
del buffer
7982
gc.collect()
80-
torch.cuda.empty_cache()
83+
device_mananger.device_module.empty_cache()
8184

8285

8386
class VllmColocateWorkerExtension:
@@ -94,10 +97,16 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
9497
from vllm.model_executor.model_loader.utils import process_weights_after_loading
9598
from vllm.platforms import current_platform
9699

100+
# vllm-ascend not init device
101+
if current_platform.device_type == "npu" and self.device is None:
102+
self.device = torch.device(f"npu:{self.local_rank}")
97103
assert self.device is not None
98104
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
99105
self._zmq_ctx = zmq.Context()
100-
device_uuid = current_platform.get_device_uuid(self.device.index)
106+
if current_platform.device_type == "gpu":
107+
device_uuid = current_platform.get_device_uuid(self.device.index)
108+
elif current_platform.device_type == "npu":
109+
device_uuid = f"NPU-{npu_generate_uuid()}"
101110
update_weights_from_ipc(
102111
self._zmq_ctx,
103112
zmq_handles[device_uuid],

0 commit comments

Comments
 (0)