@@ -271,63 +271,6 @@ def _get_ip() -> str:
271271 return socket .gethostbyname (socket .gethostname ())
272272
273273
274- def _ibv_get_device_list () -> list [str ]:
275- lib = ctypes .CDLL ("libibverbs.so.1" )
276- lib .ibv_get_device_list .argtypes = [ctypes .POINTER (ctypes .c_int )] # int *num_devices
277- lib .ibv_get_device_list .restype = ctypes .POINTER (ctypes .c_void_p ) # struct ibv_device **
278-
279- lib .ibv_free_device_list .argtypes = [ctypes .POINTER (ctypes .c_void_p )]
280- lib .ibv_get_device_name .argtypes = [ctypes .c_void_p ] # struct ibv_device *
281- lib .ibv_get_device_name .restype = ctypes .c_char_p # const char *
282-
283- num = ctypes .c_int ()
284- dev_array = lib .ibv_get_device_list (ctypes .byref (num ))
285- if not dev_array or num .value <= 0 :
286- return []
287-
288- devices = []
289- for i in range (num .value ):
290- dev_ptr = dev_array [i ] # struct ibv_device *
291- name = lib .ibv_get_device_name (dev_ptr ) # const char *
292- devices .append (name .decode ())
293- lib .ibv_free_device_list (dev_array )
294- return devices
295-
296-
297- def _get_rdma_devices () -> list [str ]:
298- """
299- use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
300- """
301- devices_str = os .getenv ("PS_P2P_STORE_RDMA_DEVICES" )
302- if devices_str :
303- return devices_str .split ("," )
304- # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
305- hca = os .getenv ("NCCL_IB_HCA" , None )
306- if hca :
307- hca_list = hca .split ("," )
308- if len (hca_list ) > 1 :
309- # if NCCL_IB_HCA has multiple values, just return
310- return hca_list
311- else :
312- hca = hca_list [0 ]
313- return [device for device in sorted (_ibv_get_device_list ()) if hca is None or hca in device ]
314-
315-
316- def _get_my_rdma_device (local_rank : int , gpu_count : int , devices : list [str ]) -> str :
317- """
318- implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
319- """
320- if not devices :
321- raise RuntimeError ("no rdma devices found" )
322- assert len (devices ) <= gpu_count , (
323- f"rdma devices count { len (devices )} should be less than or equal to gpu count { gpu_count } "
324- )
325- assert gpu_count % len (devices ) == 0 , (
326- f"gpu count { gpu_count } should be divisible by rdma devices count { len (devices )} "
327- )
328- return devices [local_rank // (gpu_count // len (devices ))]
329-
330-
331274def _load_checkpoint (files : list [str ]) -> dict [str , torch .Tensor ]:
332275 class TPMeta (BaseModel ):
333276 concat_dim : int
@@ -525,14 +468,140 @@ def _get_master_port(master_port: int | None = None) -> int:
525468 return master_port
526469
527470
471+ class NCCLIBHCAParser :
472+ def __init__ (self ):
473+ self .max_hcas = 32
474+ self .available_devices = self ._ibv_get_device_list ()
475+ logger .info (f"Available RDMA Devices: { self .available_devices } " )
476+
477+ def parse (self , value : str ) -> list [str ]:
478+ if not value or value .strip () == "" :
479+ return self .available_devices [: self .max_hcas ]
480+
481+ value = value .strip ()
482+ result = []
483+ is_exclude = value .startswith ("^" )
484+ is_exact_match = value .startswith ("=" )
485+
486+ cnt = 0
487+ while value and value [0 ] in ("^" , "=" ) and cnt < 2 :
488+ if value [0 ] == "^" :
489+ is_exclude = True
490+ elif value [0 ] == "=" :
491+ is_exact_match = True
492+ value = value [1 :]
493+ cnt += 1
494+
495+ device_specs = [spec .strip () for spec in value .split ("," ) if spec .strip ()]
496+
497+ if is_exclude :
498+ excluded_devices = self ._resolve_device_specs (device_specs , is_exact_match )
499+ for excluded in excluded_devices :
500+ if excluded not in self .available_devices :
501+ logger .warning (f"device '{ excluded } ' not found in available devices." )
502+ excluded_devices .remove (excluded )
503+ result = [dev for dev in self .available_devices if dev not in excluded_devices ]
504+ else :
505+ result = self ._resolve_device_specs (device_specs , is_exact_match )
506+
507+ if len (result ) > self .max_hcas :
508+ result = result [: self .max_hcas ]
509+
510+ logger .info (f"RDMA Devices from 'NCCL_IB_HCA': { result } " )
511+
512+ return result
513+
514+ def _resolve_device_specs (self , device_specs : list [str ], is_exact_match : bool ) -> list [str ]:
515+ devices = set ()
516+ for spec in device_specs :
517+ device_name , port = (
518+ map (str .strip , spec .split (":" , 1 )) if ":" in spec else (spec .strip (), None )
519+ )
520+ base_devices = (
521+ [device_name ]
522+ if is_exact_match
523+ else [dev for dev in self .available_devices if dev .startswith (device_name )]
524+ )
525+ if is_exact_match and device_name not in self .available_devices :
526+ logger .warning (f"Device '{ device_name } ' not found in available devices." )
527+ continue
528+
529+ if not base_devices :
530+ logger .warning (f"No devices match the prefix '{ device_name } '." )
531+ continue
532+
533+ for base_dev in base_devices :
534+ devices .add (f"{ base_dev } :{ port } " if port else f"{ base_dev } " )
535+
536+ return sorted (devices )
537+
538+ def _ibv_get_device_list (self ) -> list [str ]:
539+ lib = ctypes .CDLL ("libibverbs.so.1" )
540+ lib .ibv_get_device_list .argtypes = [ctypes .POINTER (ctypes .c_int )] # int *num_devices
541+ lib .ibv_get_device_list .restype = ctypes .POINTER (ctypes .c_void_p ) # struct ibv_device **
542+
543+ lib .ibv_free_device_list .argtypes = [ctypes .POINTER (ctypes .c_void_p )]
544+ lib .ibv_get_device_name .argtypes = [ctypes .c_void_p ] # struct ibv_device *
545+ lib .ibv_get_device_name .restype = ctypes .c_char_p # const char *
546+
547+ num = ctypes .c_int ()
548+ dev_array = lib .ibv_get_device_list (ctypes .byref (num ))
549+ if not dev_array or num .value <= 0 :
550+ return []
551+
552+ devices = []
553+ for i in range (num .value ):
554+ dev_ptr = dev_array [i ] # struct ibv_device *
555+ name = lib .ibv_get_device_name (dev_ptr ) # const char *
556+ devices .append (name .decode ())
557+ lib .ibv_free_device_list (dev_array )
558+ return devices
559+
560+ def _get_rdma_devices (self ) -> list [str ]:
561+ """
562+ use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
563+ """
564+ devices_str = os .getenv ("PS_P2P_STORE_RDMA_DEVICES" )
565+ if devices_str :
566+ return devices_str .split ("," )
567+ # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
568+ hca = os .getenv ("NCCL_IB_HCA" , None )
569+
570+ if hca :
571+ hca_list = self .parse (hca )
572+ if len (hca_list ) > 1 :
573+ # if NCCL_IB_HCA has multiple values, just return
574+ return hca_list
575+ else :
576+ hca = hca_list [0 ]
577+ return [
578+ device for device in sorted (self ._ibv_get_device_list ()) if hca is None or hca in device
579+ ]
580+
581+ def _get_my_rdma_device (self , local_rank : int , gpu_count : int , devices : list [str ]) -> str :
582+ """
583+ implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
584+ if some NICs are down, causing the number of NICs is undivisible by the number of GPUs, assign the remaining GPUs to the closest NIC.
585+ """
586+ if not devices :
587+ raise RuntimeError ("no rdma devices found" )
588+ assert len (devices ) <= gpu_count , (
589+ f"rdma devices count { len (devices )} should be less than or equal to gpu count { gpu_count } "
590+ )
591+ return devices [local_rank // (gpu_count // len (devices ))]
592+
593+
528594class P2PStore :
529595 def __init__ (self ):
530596 from mooncake .engine import TransferEngine
531597
532598 self .rank = int (os .getenv ("RANK" ))
533599 gpu_count = torch .cuda .device_count ()
534600 local_rank = self .rank % gpu_count
535- device = _get_my_rdma_device (local_rank , gpu_count , _get_rdma_devices ())
601+ rdma_parser = NCCLIBHCAParser ()
602+ device = rdma_parser ._get_my_rdma_device (
603+ local_rank , gpu_count , rdma_parser ._get_rdma_devices ()
604+ )
536605 self .ip = _get_ip ()
537606
538607 # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
0 commit comments