diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 42cac9e..8938ce0 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -2,6 +2,7 @@ import argparse import concurrent.futures +import ctypes import os import pickle import random @@ -269,14 +270,32 @@ def _get_ip() -> str: return socket.gethostbyname(socket.gethostname()) +def _ibv_get_device_list() -> list[str]: + lib = ctypes.CDLL("libibverbs.so.1") + lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices + lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device ** + + lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)] + lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device * + lib.ibv_get_device_name.restype = ctypes.c_char_p # const char * + + num = ctypes.c_int() + dev_array = lib.ibv_get_device_list(ctypes.byref(num)) + if not dev_array or num.value <= 0: + return [] + + devices = [] + for i in range(num.value): + dev_ptr = dev_array[i] # struct ibv_device * + name = lib.ibv_get_device_name(dev_ptr) # const char * + devices.append(name.decode()) + lib.ibv_free_device_list(dev_array) + return devices + + def _get_rdma_devices() -> list[str]: """ - use script like below to get RDMA devices, if NCCL_IB_HCA has multiple values, just return - ```bash - pushd /sys/class/infiniband/ > /dev/null; - for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done; - popd > /dev/null; - ``` + use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return """ devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES") if devices_str: @@ -284,32 +303,15 @@ def _get_rdma_devices() -> list[str]: # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices hca = os.getenv("NCCL_IB_HCA", None) if hca: - l = hca.split(",") # noqa: E741 - if len(l) > 1: + hca_list = hca.split(",") + if len(hca_list) > 1: # if NCCL_IB_HCA has multiple values, just return - return l + return hca_list else: - hca = l[0] - basepath = "/sys/class/infiniband/" - port_path = "ports/1/gid_attrs/types" - devices = [] - for device in sorted(os.listdir(basepath)): - if hca is not None and hca not in device: - continue - path = os.path.join(basepath, device, port_path) - if not os.path.exists(path) or not os.path.isdir(path): - continue - for port in os.listdir(path): - try: - with open(os.path.join(path, port)) as f: - content = f.read() - if "v" in content: - print(f"found rdma device {device} in port {port}: {content.strip()}") - devices.append(device) - break - except Exception: # noqa: BLE001,S110 - pass - return devices + hca = hca_list[0] + return [ + device for device in sorted(_ibv_get_device_list()) if hca is not None and hca in device + ] def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: