Skip to content

Commit 03ca80d

Browse files
committed
feat: NCCLIBHCAParser class added, supporting exact match, exclude, and port specifications for RDMA devices.
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8
1 parent 2ef05a4 commit 03ca80d

File tree

2 files changed

+341
-58
lines changed

2 files changed

+341
-58
lines changed

checkpoint_engine/ps.py

Lines changed: 127 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -271,63 +271,6 @@ def _get_ip() -> str:
271271
return socket.gethostbyname(socket.gethostname())
272272

273273

274-
def _ibv_get_device_list() -> list[str]:
    """Return the names of the RDMA devices reported by libibverbs.

    Loads ``libibverbs.so.1`` through ctypes, calls ``ibv_get_device_list`` and
    reads each entry's name with ``ibv_get_device_name``.

    Returns:
        The device names in the order libibverbs reports them, or an empty
        list when the library returns no array or no devices.

    Raises:
        OSError: if ``libibverbs.so.1`` cannot be loaded.
    """
    lib = ctypes.CDLL("libibverbs.so.1")
    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    lib.ibv_free_device_list.restype = None  # void
    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

    num = ctypes.c_int()
    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
    if not dev_array:
        # NULL array: nothing was allocated, so there is nothing to free.
        return []

    try:
        devices = []
        # range() is empty for num.value <= 0, so the "no devices" case falls
        # through to the finally clause and still releases the array.
        for i in range(num.value):
            name = lib.ibv_get_device_name(dev_array[i])  # const char * or NULL
            if name:
                devices.append(name.decode())
        return devices
    finally:
        # Always hand the array back, including on the early/exceptional paths.
        lib.ibv_free_device_list(dev_array)
295-
296-
297-
def _get_rdma_devices() -> list[str]:
    """Return the RDMA device names this process may use.

    Resolution order:

    1. ``PS_P2P_STORE_RDMA_DEVICES``: a comma-separated list, returned as is.
    2. ``NCCL_IB_HCA`` with several comma-separated values: returned as is.
    3. ``NCCL_IB_HCA`` with one value: every device from
       ``_ibv_get_device_list`` whose name contains that value, sorted.
    4. Neither variable set: every device from ``_ibv_get_device_list``, sorted.

    Entries are matched literally; whitespace around entries is ignored and
    empty entries (e.g. from a trailing comma) are dropped.
    """

    def _split(raw: str) -> list[str]:
        # Tolerate "a, b," style input instead of yielding " b" and "".
        return [item.strip() for item in raw.split(",") if item.strip()]

    devices = _split(os.getenv("PS_P2P_STORE_RDMA_DEVICES") or "")
    if devices:
        return devices

    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
    hca = None
    hca_list = _split(os.getenv("NCCL_IB_HCA") or "")
    if len(hca_list) > 1:
        # if NCCL_IB_HCA has multiple values, just return them
        return hca_list
    if hca_list:
        hca = hca_list[0]
    return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
314-
315-
316-
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    """Pick the RDMA device that the GPU at ``local_rank`` should use.

    GPUs are split into ``len(devices)`` equally sized, contiguous groups and
    each group shares one device: with ``["mlx5_0", "mlx5_1"]`` and 8 GPUs,
    ranks 0-3 get ``mlx5_0`` and ranks 4-7 get ``mlx5_1``.

    Raises:
        RuntimeError: if ``devices`` is empty.
        AssertionError: if there are more devices than GPUs, or the GPU count
            is not a multiple of the device count.
    """
    if not devices:
        raise RuntimeError("no rdma devices found")

    gpus_per_device, leftover = divmod(gpu_count, len(devices))
    assert len(devices) <= gpu_count, (
        f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
    )
    assert leftover == 0, (
        f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
    )

    group_index = local_rank // gpus_per_device
    return devices[group_index]
329-
330-
331274
def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
332275
class TPMeta(BaseModel):
333276
concat_dim: int
@@ -525,14 +468,140 @@ def _get_master_port(master_port: int | None = None) -> int:
525468
return master_port
526469

527470

471+
class NCCLIBHCAParser:
    """Select RDMA devices from an ``NCCL_IB_HCA``-style specification.

    The device names are enumerated once through libibverbs when the parser is
    constructed and cached in ``available_devices``; ``parse`` resolves a
    specification string against that cached list.
    """

    def __init__(self):
        # Upper bound on the number of devices parse() returns.
        self.max_hcas = 32
        self.available_devices = self._ibv_get_device_list()
        logger.info(f"Available RDMA Devices: {self.available_devices}")

    def parse(self, value: str) -> list[str]:
        """Resolve a device specification into a list of device names.

        ``value`` is a comma-separated list of ``name`` or ``name:port``
        entries, optionally preceded by up to two modifier characters:

        * ``^`` -- return the available devices *except* the listed ones;
        * ``=`` -- treat each name as an exact device name instead of a prefix.

        An empty or whitespace-only ``value`` selects every available device.
        The result is truncated to ``max_hcas`` entries.
        """
        if not value or value.strip() == "":
            return self.available_devices[: self.max_hcas]

        value = value.strip()
        is_exclude = False
        is_exact_match = False

        # Consume at most two leading modifiers ("^", "=", "^=" or "=^").
        consumed = 0
        while value and value[0] in ("^", "=") and consumed < 2:
            if value[0] == "^":
                is_exclude = True
            else:
                is_exact_match = True
            value = value[1:]
            consumed += 1

        device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
        resolved = self._resolve_device_specs(device_specs, is_exact_match)

        if is_exclude:
            # Build a fresh collection instead of removing from the list being
            # iterated, which would skip the element after each removal.
            excluded_devices = set()
            for excluded in resolved:
                if excluded in self.available_devices:
                    excluded_devices.add(excluded)
                else:
                    # Happens for port-qualified entries ("name:port"), since
                    # available_devices holds bare device names.
                    logger.warning(f"device '{excluded}' not found in available devices.")
            result = [dev for dev in self.available_devices if dev not in excluded_devices]
        else:
            result = resolved

        result = result[: self.max_hcas]

        logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")

        return result

    def _resolve_device_specs(self, device_specs: list[str], is_exact_match: bool) -> list[str]:
        """Expand ``name[:port]`` entries against ``available_devices``.

        With ``is_exact_match`` a name must equal an available device name;
        otherwise it selects every available device whose name starts with it.
        Entries that match nothing are logged and skipped. When an entry
        carries a non-empty port, ``:port`` is appended to each selected name.

        Returns:
            The selected names, de-duplicated and sorted.
        """
        devices = set()
        for spec in device_specs:
            name_part, has_port, port_part = spec.partition(":")
            device_name = name_part.strip()
            port = port_part.strip() if has_port else None

            if is_exact_match:
                if device_name not in self.available_devices:
                    logger.warning(f"Device '{device_name}' not found in available devices.")
                    continue
                base_devices = [device_name]
            else:
                base_devices = [
                    dev for dev in self.available_devices if dev.startswith(device_name)
                ]
                if not base_devices:
                    logger.warning(f"No devices match the prefix '{device_name}'.")
                    continue

            for base_dev in base_devices:
                devices.add(f"{base_dev}:{port}" if port else f"{base_dev}")

        return sorted(devices)

    def _ibv_get_device_list(self) -> list[str]:
        """Return the device names reported by libibverbs.

        Loads ``libibverbs.so.1`` through ctypes, calls
        ``ibv_get_device_list`` and reads each entry's name with
        ``ibv_get_device_name``. Returns an empty list when the library
        returns no array or no devices.

        Raises:
            OSError: if ``libibverbs.so.1`` cannot be loaded.
        """
        lib = ctypes.CDLL("libibverbs.so.1")
        lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
        lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

        lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
        lib.ibv_free_device_list.restype = None  # void
        lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
        lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

        num = ctypes.c_int()
        dev_array = lib.ibv_get_device_list(ctypes.byref(num))
        if not dev_array:
            # NULL array: nothing was allocated, so there is nothing to free.
            return []

        try:
            devices = []
            # range() is empty for num.value <= 0, so that case still reaches
            # the finally clause and releases the array.
            for i in range(num.value):
                name = lib.ibv_get_device_name(dev_array[i])  # const char * or NULL
                if name:
                    devices.append(name.decode())
            return devices
        finally:
            lib.ibv_free_device_list(dev_array)

    def _get_rdma_devices(self) -> list[str]:
        """Return the RDMA device names this process may use.

        Resolution order:

        1. ``PS_P2P_STORE_RDMA_DEVICES``: a comma-separated list, returned as
           given (entries stripped, empty entries dropped).
        2. ``NCCL_IB_HCA``: whatever ``parse`` resolves it to. This may be an
           empty list when nothing matches, and entries may carry a ``:port``
           suffix.
        3. Neither variable set: every available device, sorted.
        """
        devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
        if devices_str:
            explicit = [item.strip() for item in devices_str.split(",") if item.strip()]
            if explicit:
                return explicit

        # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
        hca = os.getenv("NCCL_IB_HCA", None)
        if hca:
            # parse() has already matched the spec against the available
            # devices. Returning its result directly keeps "=name" exact
            # (no second substring filter) and avoids indexing an empty list
            # when nothing matched.
            return self.parse(hca)

        return sorted(self.available_devices)

    def _get_my_rdma_device(self, local_rank: int, gpu_count: int, devices: list[str]) -> str:
        """Pick the RDMA device that the GPU at ``local_rank`` should use.

        GPUs are split into contiguous groups of ``gpu_count // len(devices)``
        and each group shares one device: with ``"mlx5_0,mlx5_1"`` and 8 GPUs,
        ranks 0-3 share ``mlx5_0`` and ranks 4-7 share ``mlx5_1``. When the GPU
        count is not a multiple of the device count (e.g. some NICs are down),
        the GPUs left over after the last full group use the last device.

        Raises:
            RuntimeError: if ``devices`` is empty.
            AssertionError: if there are more devices than GPUs.
        """
        if not devices:
            raise RuntimeError("no rdma devices found")
        assert len(devices) <= gpu_count, (
            f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
        )
        gpus_per_device = gpu_count // len(devices)
        # Clamp so the remainder GPUs map onto the last device instead of
        # indexing past the end of the list.
        index = min(local_rank // gpus_per_device, len(devices) - 1)
        return devices[index]
592+
593+
528594
class P2PStore:
529595
def __init__(self):
530596
from mooncake.engine import TransferEngine
531597

532598
self.rank = int(os.getenv("RANK"))
533599
gpu_count = torch.cuda.device_count()
534600
local_rank = self.rank % gpu_count
535-
device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
601+
rdma_parser = NCCLIBHCAParser()
602+
device = rdma_parser._get_my_rdma_device(
603+
local_rank, gpu_count, rdma_parser._get_rdma_devices()
604+
)
536605
self.ip = _get_ip()
537606

538607
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases

0 commit comments

Comments
 (0)