Skip to content

Commit 779eb4c

Browse files
committed
feat: use torch.cuda.get_device_properties() to get device_uuid instead of nvidia-smi -L
1 parent 03ff7e7 commit 779eb4c

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

checkpoint_engine/ps.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pickle
77
import random
88
import socket
9-
import subprocess
109
import threading
1110
import time
1211
from collections import defaultdict
@@ -242,16 +241,8 @@ def _concat_tp_weights(
242241
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
243242

244243

245-
def _get_physical_gpu_id(rank: int) -> str:
246-
result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True) # noqa: S607
247-
if result.returncode != 0:
248-
raise ValueError(result.stdout)
249-
lines = result.stdout.strip().split("\n")
250-
for line in lines:
251-
if f"GPU {rank}" in line:
252-
uuid = line.split("UUID: ")[1].strip(")")
253-
return uuid
254-
raise ValueError(f"not found gpu{rank} uuid")
244+
def _get_physical_gpu_id(device_index: int | None = None) -> str:
    """Return the physical GPU identifier for *device_index* in nvidia-smi style.

    Uses ``torch.cuda.get_device_properties`` (no subprocess needed) and
    prefixes the reported UUID with ``GPU-`` so the result matches the
    ``nvidia-smi -L`` format, e.g. ``GPU-xxxxxxxx-xxxx-...``.

    Args:
        device_index: CUDA device index; ``None`` means the current device.

    Returns:
        The ``GPU-<uuid>`` string for the selected device.
    """
    props = torch.cuda.get_device_properties(device_index)
    return "GPU-" + str(props.uuid)
255246

256247

257248
@lru_cache(maxsize=1)
@@ -610,7 +601,6 @@ def __init__(self, *, auto_pg: bool = False):
610601
assert self._rank is not None and self._rank >= 0, self._rank
611602
assert self._world_size and self._world_size > 0, self._world_size
612603

613-
self._device_uuid = _get_physical_gpu_id(self._local_rank)
614604
self._zmq_ctx = zmq.Context()
615605
self._zmq_addr_counter = 0
616606

@@ -623,7 +613,9 @@ def __init__(self, *, auto_pg: bool = False):
623613
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
624614
self._p2p_store = None
625615

626-
torch.cuda.set_device(self._local_rank)
616+
device_index = self._local_rank
617+
torch.cuda.set_device(device_index)
618+
self._device_uuid = _get_physical_gpu_id(device_index)
627619

628620
def _logger_rank0(self, msg: str):
629621
if self._local_rank == 0:

0 commit comments

Comments
 (0)