83 changes: 60 additions & 23 deletions checkpoint_engine/ps.py
@@ -24,6 +24,39 @@
from torch.multiprocessing.reductions import reduce_tensor


def is_torch_npu_available() -> bool:
    try:
        import torch_npu  # noqa: F401  # importing torch_npu registers the torch.npu namespace
    except ImportError:
        return False
    if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
        return torch.npu.is_available()
    return False


class DeviceManager:
    def __init__(self):
        self.device_type = self._detect_device_type()
        self._setup_device_module()

    def _detect_device_type(self) -> str:
        if is_torch_npu_available():
            return "npu"
        elif torch.cuda.is_available():
            return "cuda"
        raise RuntimeError("no NPU or CUDA device available")

    def _setup_device_module(self):
        if self.device_type == "npu":
            import torch_npu

            self.device_module = torch_npu.npu
        elif self.device_type == "cuda":
            self.device_module = torch.cuda

    def get_backend(self) -> str:
        if self.device_type == "npu":
            return "hccl"
        elif self.device_type == "cuda":
            return "nccl"
        raise RuntimeError(f"unsupported device type: {self.device_type}")


if TYPE_CHECKING:
from typing_extensions import TypedDict

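A note on the pattern above: `DeviceManager` centralizes accelerator detection so the rest of `ps.py` can stay backend-agnostic. A minimal usage sketch, not part of the diff, assuming a host with at least one CUDA GPU or Ascend NPU:

```python
# Sketch only: how the rest of ps.py consumes DeviceManager.
dm = DeviceManager()            # "npu" if torch_npu is usable, else "cuda"
dm.device_module.set_device(0)  # dispatches to torch_npu.npu or torch.cuda
backend = dm.get_backend()      # "hccl" on Ascend, "nccl" on NVIDIA
print(dm.device_type, backend)
```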
@@ -249,9 +282,12 @@ def _concat_tp_weights(
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)


-def _get_physical_gpu_id(device_index: int | None = None) -> str:
+def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
try:
return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
if device_manager.device_type == "npu":
return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{device_index}"
else:
return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
except AssertionError as e:
raise ValueError(f"fail to get physical gpu id {device_index}") from e

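The asymmetry in `_get_physical_gpu_id` exists because Ascend devices do not expose a stable UUID the way CUDA devices do, so the device name plus local index stands in for one. Illustrative output, with a hypothetical device name:

```python
# Values are illustrative; the worker side must reproduce the same string.
dm = DeviceManager()
print(_get_physical_gpu_id(dm, 0))
# CUDA:   GPU-8b4d2f6a-1c3e-4f5a-9b7d-0123456789ab   (per-device UUID)
# Ascend: NPU-Ascend910B-0                           (name + local index)
```

The worker-side change in `worker.py` below builds the identical NPU string so `zmq_handles` keys line up across both files.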
@@ -588,11 +624,11 @@ def _get_master_port(master_port: int | None = None) -> int:


class P2PStore:
-    def __init__(self):
+    def __init__(self, device_manager: DeviceManager):
from mooncake.engine import TransferEngine

self.rank = int(os.getenv("RANK"))
-        gpu_count = torch.cuda.device_count()
+        gpu_count = device_manager.device_module.device_count()
local_rank = self.rank % gpu_count
device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
self.ip = _get_ip()
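`P2PStore` now takes the device count from `DeviceManager` rather than `torch.cuda`; the local-rank derivation itself is unchanged. A quick worked example of the mapping, with illustrative numbers:

```python
# local_rank = rank % gpu_count, e.g. with 8 devices per host:
gpu_count = 8
for rank in (0, 7, 8, 11):
    print(rank, "->", rank % gpu_count)   # 0->0, 7->7, 8->0, 11->3
```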
@@ -672,7 +708,8 @@ def __init__(
"""
self._rank = rank or int(os.environ.get("RANK", None))
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self._gpu_count = gpu_count or torch.cuda.device_count()
+        self.device_manager = DeviceManager()
+        self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
self._local_rank = self._rank % self._gpu_count
self._auto_pg = auto_pg
self._all_hosts = []
@@ -684,7 +721,7 @@ def __init__(
assert (
self._gpu_count is not None
and self._gpu_count > 0
-            and self._gpu_count <= torch.cuda.device_count()
+            and self._gpu_count <= self.device_manager.device_module.device_count()
), self._gpu_count
assert (
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -697,14 +734,14 @@
# dict key is owner_rank, value is a bucket metas list in owner_rank
self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
try:
-            self._p2p_store = P2PStore()
+            self._p2p_store = P2PStore(self.device_manager)
except ImportError as e:
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
self._p2p_store = None

device_index = self._local_rank
-        torch.cuda.set_device(device_index)
-        self._device_uuid = _get_physical_gpu_id(device_index)
+        self.device_manager.device_module.set_device(device_index)
+        self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)

def _logger_rank0(self, msg: str):
if self._local_rank == 0:
@@ -842,7 +879,7 @@ def init_process_group(
is_master=self._rank == 0,
)
dist.init_process_group(
backend="nccl",
backend=self.device_manager.get_backend(),
world_size=self._world_size,
rank=self._rank,
timeout=timeout,
@@ -889,12 +926,12 @@ def update(
if self._auto_pg:
dist.destroy_process_group()

-            torch.cuda.empty_cache()
+            self.device_manager.device_module.empty_cache()

logger.info(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
)
except Exception as e:
logger.exception(
@@ -918,13 +955,13 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
tensor = torch.tensor(
[
# proportion of current cuda free memory bytes
-                int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
+                int(float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction),
# we use negative value to reuse allreduce min operation
# for getting the max value of zmq_addr_counter in all ranks
-self._zmq_addr_counter,
],
dtype=torch.int64,
device="cuda",
device=self.device_manager.device_type,
)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
tensor = tensor.cpu()
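The negated counter above is a small trick worth spelling out: a single `ReduceOp.MIN` all-reduce returns both the cluster-wide minimum of free memory and, through negation, the maximum of `_zmq_addr_counter`. A standalone sketch under assumed setup (process group already initialized; values illustrative):

```python
import torch
import torch.distributed as dist

# MIN over negated values == MAX over the originals, so one collective does both.
free_bytes, counter = 1_000_000, 7   # this rank's values (illustrative)
# ps.py puts this tensor on the accelerator for nccl/hccl; with a gloo group a CPU tensor works too.
t = torch.tensor([free_bytes, -counter], dtype=torch.int64)
dist.all_reduce(t, op=dist.ReduceOp.MIN)
min_free, max_counter = int(t[0]), -int(t[1])
```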
@@ -987,7 +1024,7 @@ def _copy_to_buffer(
assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
if owner_rank is not None:
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
-        torch.cuda.synchronize()
+        self.device_manager.device_module.synchronize()

def init_process_group_for_ranks(
self,
@@ -1057,7 +1094,7 @@ def _update_per_bucket_p2p(
dist.barrier()

bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
+        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type)
ipc_buffer_name = "__ipc_buffer___"
self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
logger.info(
@@ -1093,7 +1130,7 @@ def _update_per_bucket_p2p(
dist.barrier()
socket.close()
self._p2p_store.unregister_named_tensors([ipc_buffer_name])
-        torch.cuda.empty_cache()
+        self.device_manager.device_module.empty_cache()

def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
@@ -1138,7 +1175,7 @@ def _update_per_bucket(
h2d_buffer: torch.Tensor | None = (
None
if disable_h2d_buffer
-            else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
+            else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
)

owner_rank_buckets: list[H2DBucket] = []
@@ -1147,7 +1184,7 @@
continue
owner_rank_buckets.append(bucket)

-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
+        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type)
handle = reduce_tensor(buffer)

buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -1175,8 +1212,8 @@ def _update_per_bucket(
continue
bucket = _buckets[i]
alloc, reserved = (
-            torch.cuda.memory_allocated() / 1024 / 1024,
-            torch.cuda.memory_reserved() / 1024 / 1024,
+            self.device_manager.device_module.memory_allocated() / 1024 / 1024,
+            self.device_manager.device_module.memory_reserved() / 1024 / 1024,
)
self._logger_rank0(
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
@@ -1202,7 +1239,7 @@ def _update_per_bucket(
req_thread.join()
dist.barrier()
socket.close()
-        torch.cuda.empty_cache()
+        self.device_manager.device_module.empty_cache()


def _init_api(ps: ParameterServer) -> Any:
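One recurring detail in the hunks above: both update paths allocate their staging buffer at `bucket_size * 2`. The PR does not explain the factor of two; the natural reading is a ping-pong scheme in which one half stages the next bucket while the other is consumed. A sketch of that interpretation, which is an assumption rather than anything this diff confirms:

```python
import torch

# Assumed ping-pong use of the doubled staging buffer (sizes illustrative).
bucket_size = 4 * 1024 * 1024
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8)
halves = (buffer[:bucket_size], buffer[bucket_size:])
for i in range(6):
    staging = halves[i % 2]  # alternate halves so transfer and consumption overlap
```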
18 changes: 14 additions & 4 deletions checkpoint_engine/worker.py
@@ -4,6 +4,7 @@

import torch
import zmq
from checkpoint_engine.ps import DeviceManager


def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
@@ -53,13 +54,14 @@ def update_weights_from_ipc(
socket = zmq_ctx.socket(zmq.REP)
socket.connect(zmq_handle)
buffer: torch.Tensor | None = None
    device_manager = DeviceManager()
while True:
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
if payload is None:
# means the update is done
if post_hook is not None:
post_hook()
-            torch.cuda.synchronize()
+            device_manager.device_module.synchronize()
socket.send(b"")
break
if isinstance(payload, tuple):
@@ -71,13 +73,13 @@
continue
assert isinstance(payload, list)
run(_extract_weights(payload, buffer))
-        torch.cuda.synchronize()
+        device_manager.device_module.synchronize()
socket.send(b"")

socket.close()
del buffer
gc.collect()
-    torch.cuda.empty_cache()
+    device_manager.device_module.empty_cache()


class VllmColocateWorkerExtension:
@@ -94,10 +96,18 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
from vllm.model_executor.model_loader.utils import process_weights_after_loading
from vllm.platforms import current_platform

        # vllm-ascend does not initialize self.device, so set it here
        if current_platform.device_type == "npu" and self.device is None:
            self.device = torch.device(f"npu:{self.local_rank}")
assert self.device is not None
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
self._zmq_ctx = zmq.Context()
-        device_uuid = current_platform.get_device_uuid(self.device.index)
+        if current_platform.device_type == "gpu":
+            device_uuid = current_platform.get_device_uuid(self.device.index)
+        elif current_platform.device_type == "npu":
+            device_uuid = (
+                f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.device.index}"
+            )
+        else:
+            raise RuntimeError(f"unsupported platform device type: {current_platform.device_type}")
update_weights_from_ipc(
self._zmq_ctx,
zmq_handles[device_uuid],
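The NPU branch mirrors `_get_physical_gpu_id` in `ps.py`: the two files must derive byte-identical identifiers, since the worker uses its string as the lookup key into `zmq_handles`. A sketch of the correspondence, with a hypothetical device name:

```python
# Both sides must produce the same key (device name illustrative).
device_index = 0
device_name = "Ascend910B"        # current_platform.get_device_name(device_index)
device_uuid = f"NPU-{device_name}-{device_index}"
# ps.py builds the same string in _get_physical_gpu_id, so the lookup succeeds:
# zmq_handle = zmq_handles[device_uuid]
```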