Skip to content

Commit b5493bf

Browse files
feat: use torch.cuda.get_device_properties() to get device_uuid instead of nvidia-smi -L (#21)
1 parent 109efc0 commit b5493bf

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

checkpoint_engine/ps.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import pickle
88
import random
99
import socket
10-
import subprocess
1110
import threading
1211
import time
1312
from collections import defaultdict
@@ -243,16 +242,11 @@ def _concat_tp_weights(
243242
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
244243

245244

246-
def _get_physical_gpu_id(rank: int) -> str:
247-
result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True) # noqa: S607
248-
if result.returncode != 0:
249-
raise ValueError(result.stdout)
250-
lines = result.stdout.strip().split("\n")
251-
for line in lines:
252-
if f"GPU {rank}" in line:
253-
uuid = line.split("UUID: ")[1].strip(")")
254-
return uuid
255-
raise ValueError(f"not found gpu{rank} uuid")
245+
def _get_physical_gpu_id(device_index: int | None = None) -> str:
246+
try:
247+
return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
248+
except AssertionError as e:
249+
raise ValueError(f"fail to get physical gpu id {device_index}") from e
256250

257251

258252
@lru_cache(maxsize=1)
@@ -613,7 +607,6 @@ def __init__(
613607
assert self._rank is not None and self._rank >= 0, self._rank
614608
assert self._world_size and self._world_size > 0, self._world_size
615609

616-
self._device_uuid = _get_physical_gpu_id(self._local_rank)
617610
self._zmq_ctx = zmq.Context()
618611
self._zmq_addr_counter = 0
619612

@@ -626,7 +619,9 @@ def __init__(
626619
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
627620
self._p2p_store = None
628621

629-
torch.cuda.set_device(self._local_rank)
622+
device_index = self._local_rank
623+
torch.cuda.set_device(device_index)
624+
self._device_uuid = _get_physical_gpu_id(device_index)
630625

631626
def _logger_rank0(self, msg: str):
632627
if self._local_rank == 0:

0 commit comments

Comments
 (0)