From 7e522fdc84c90133db27a9f772e60a8a42c9da0c Mon Sep 17 00:00:00 2001 From: weixiao-huang Date: Sat, 20 Sep 2025 12:47:41 +0800 Subject: [PATCH] feat: use torch.cuda.get_device_properties() to get device_uuid instead of nvidia-smi -L --- checkpoint_engine/ps.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 3ad81b6..a8d75a0 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -7,7 +7,6 @@ import pickle import random import socket -import subprocess import threading import time from collections import defaultdict @@ -243,16 +242,11 @@ def _concat_tp_weights( return torch.cat([w for w in tp_weights], dim=tp_concat_dim) -def _get_physical_gpu_id(rank: int) -> str: - result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True) # noqa: S607 - if result.returncode != 0: - raise ValueError(result.stdout) - lines = result.stdout.strip().split("\n") - for line in lines: - if f"GPU {rank}" in line: - uuid = line.split("UUID: ")[1].strip(")") - return uuid - raise ValueError(f"not found gpu{rank} uuid") +def _get_physical_gpu_id(device_index: int | None = None) -> str: + try: + return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}" + except AssertionError as e: + raise ValueError(f"fail to get physical gpu id {device_index}") from e @lru_cache(maxsize=1) @@ -613,7 +607,6 @@ def __init__( assert self._rank is not None and self._rank >= 0, self._rank assert self._world_size and self._world_size > 0, self._world_size - self._device_uuid = _get_physical_gpu_id(self._local_rank) self._zmq_ctx = zmq.Context() self._zmq_addr_counter = 0 @@ -626,7 +619,9 @@ def __init__( logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}") self._p2p_store = None - torch.cuda.set_device(self._local_rank) + device_index = self._local_rank + torch.cuda.set_device(device_index) + self._device_uuid = _get_physical_gpu_id(device_index) def _logger_rank0(self, msg: str): if self._local_rank == 0: