66import pickle
77import random
88import socket
9- import subprocess
109import threading
1110import time
1211from collections import defaultdict
@@ -242,16 +241,8 @@ def _concat_tp_weights(
242241 return torch .cat ([w for w in tp_weights ], dim = tp_concat_dim )
243242
244243
245- def _get_physical_gpu_id (rank : int ) -> str :
246- result = subprocess .run (["nvidia-smi" , "-L" ], capture_output = True , text = True ) # noqa: S607
247- if result .returncode != 0 :
248- raise ValueError (result .stdout )
249- lines = result .stdout .strip ().split ("\n " )
250- for line in lines :
251- if f"GPU { rank } " in line :
252- uuid = line .split ("UUID: " )[1 ].strip (")" )
253- return uuid
254- raise ValueError (f"not found gpu{ rank } uuid" )
def _get_physical_gpu_id(device_index: int | None = None) -> str:
    """Return the stable physical identifier of a CUDA device.

    The identifier is the device UUID reported by the CUDA driver, prefixed
    with ``GPU-`` so it matches the format emitted by ``nvidia-smi -L``.

    Args:
        device_index: CUDA device index; ``None`` selects the current device.

    Returns:
        A string of the form ``"GPU-<uuid>"``.
    """
    properties = torch.cuda.get_device_properties(device_index)
    return "GPU-" + str(properties.uuid)
255246
256247
257248@lru_cache (maxsize = 1 )
@@ -610,7 +601,6 @@ def __init__(self, *, auto_pg: bool = False):
610601 assert self ._rank is not None and self ._rank >= 0 , self ._rank
611602 assert self ._world_size and self ._world_size > 0 , self ._world_size
612603
613- self ._device_uuid = _get_physical_gpu_id (self ._local_rank )
614604 self ._zmq_ctx = zmq .Context ()
615605 self ._zmq_addr_counter = 0
616606
@@ -623,7 +613,9 @@ def __init__(self, *, auto_pg: bool = False):
623613 logger .warning (f"[rank{ self ._rank } ] fail to initialize p2p store due to { e } " )
624614 self ._p2p_store = None
625615
626- torch .cuda .set_device (self ._local_rank )
616+ device_index = self ._local_rank
617+ torch .cuda .set_device (device_index )
618+ self ._device_uuid = _get_physical_gpu_id (device_index )
627619
628620 def _logger_rank0 (self , msg : str ):
629621 if self ._local_rank == 0 :
0 commit comments