Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import concurrent.futures
import ctypes
import os
import pickle
import random
Expand Down Expand Up @@ -269,47 +270,48 @@ def _get_ip() -> str:
return socket.gethostbyname(socket.gethostname())


def _ibv_get_device_list() -> list[str]:
lib = ctypes.CDLL("libibverbs.so.1")
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **

lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *

num = ctypes.c_int()
dev_array = lib.ibv_get_device_list(ctypes.byref(num))
if not dev_array or num.value <= 0:
return []

devices = []
for i in range(num.value):
dev_ptr = dev_array[i] # struct ibv_device *
name = lib.ibv_get_device_name(dev_ptr) # const char *
devices.append(name.decode())
lib.ibv_free_device_list(dev_array)
return devices


def _get_rdma_devices() -> list[str]:
"""
use script like below to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
```bash
pushd /sys/class/infiniband/ > /dev/null;
for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done;
popd > /dev/null;
```
use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
"""
devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
if devices_str:
return devices_str.split(",")
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
hca = os.getenv("NCCL_IB_HCA", None)
if hca:
l = hca.split(",") # noqa: E741
if len(l) > 1:
hca_list = hca.split(",")
if len(hca_list) > 1:
# if NCCL_IB_HCA has multiple values, just return
return l
return hca_list
else:
hca = l[0]
basepath = "/sys/class/infiniband/"
port_path = "ports/1/gid_attrs/types"
devices = []
for device in sorted(os.listdir(basepath)):
if hca is not None and hca not in device:
continue
path = os.path.join(basepath, device, port_path)
if not os.path.exists(path) or not os.path.isdir(path):
continue
for port in os.listdir(path):
try:
with open(os.path.join(path, port)) as f:
content = f.read()
if "v" in content:
print(f"found rdma device {device} in port {port}: {content.strip()}")
devices.append(device)
break
except Exception: # noqa: BLE001,S110
pass
return devices
hca = hca_list[0]
return [
device for device in sorted(_ibv_get_device_list()) if hca is not None and hca in device
]


def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
Expand Down