Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/cpu-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Runs the CPU-only pytest suite on every push to main.
name: CPU Tests

on:
  push:
    branches: [main]

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Python
        # v5 runs on the Node 20 runtime, the same generation as actions/checkout@v4
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          # quoted so the shell never treats the extras brackets as a glob pattern
          pip install ".[p2p]"
      - name: Do CPU tests with pytest
        run: |
          # only tests carrying the `cpu` marker are selected
          pytest -v -m "cpu" tests/
78 changes: 70 additions & 8 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,7 @@ def _get_rdma_devices() -> list[str]:
return devices_str.split(",")
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
hca = os.getenv("NCCL_IB_HCA", None)
if hca:
hca_list = hca.split(",")
if len(hca_list) > 1:
# if NCCL_IB_HCA has multiple values, just return
return hca_list
else:
hca = hca_list[0]
return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()


def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
Expand All @@ -328,6 +321,75 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
return devices[local_rank // (gpu_count // len(devices))]


def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
"""
The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.

The list is comma-separated; port numbers are NOT supported yet.
An optional prefix '^' indicates the list is an exclude list.
A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.

Examples:
- NCCL_IB_HCA=mlx5: Use all cards starting with mlx5.
- NCCL_IB_HCA==mlx5_0,mlx5_1 : Use specific cards mlx5_0 and mlx5_1.
- NCCL_IB_HCA=^mlx5: Use all cards except those starting with mlx5.
- NCCL_IB_HCA=^=mlx5_0,mlx5_1: Use all cards except mlx5_0 and mlx5_1.
"""
max_hcas = 32
if not value or value.strip() == "":
return available_devices[:max_hcas]

value = value.strip()
result = []
is_exclude = value.startswith("^")
if is_exclude:
value = value.removeprefix("^")
is_exact_match = value.startswith("=")
if is_exact_match:
value = value.removeprefix("=")

device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]

result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
if is_exclude:
result = [dev for dev in available_devices if dev not in result]
if len(result) > max_hcas:
result = result[:max_hcas]

logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")

return result


def _resolve_device_specs(
device_specs: list[str], is_exact_match: bool, available_devices: list[str]
) -> list[str]:
devices = set()
for spec in device_specs:
parts = spec.split(":", 1)
device_name = parts[0].strip()
# HACK: mooncake transfer engine does not support port specification yet, so we ignore it
# port = parts[1].strip() if len(parts) > 1 else None
base_devices = (
[device_name]
if device_name in available_devices
else []
if is_exact_match
else [dev for dev in available_devices if dev.startswith(device_name)]
)

if not base_devices:
logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
continue

for base_dev in base_devices:
devices.add(base_dev)

return sorted(devices)


def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
class TPMeta(BaseModel):
concat_dim: int
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,9 @@ inline-quotes = "double"

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

# Custom pytest markers used to split the suite; the CPU workflow selects tests with `-m "cpu"`.
[tool.pytest.ini_options]
markers = [
"cpu: marks tests as CPU test (deselect with '-m \"not cpu\"')",
"gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')",
]
206 changes: 206 additions & 0 deletions tests/test_rdma_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import os
from unittest.mock import patch

import pytest

from checkpoint_engine.ps import (
_get_my_rdma_device,
_get_rdma_devices,
_ibv_get_device_list,
_parse_NCCL_IB_HCA,
)


@pytest.fixture
def mock_available_devices() -> list[str]:
    """Return a fresh list of four fake RDMA device names (two mlx5, two mlx4)."""
    fake_devices = ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]
    return fake_devices


@pytest.mark.cpu
def test_detect_ibv_list():
    """Smoke-test _ibv_get_device_list on hosts that expose infiniband devices."""
    sysfs_dir = "/sys/class/infiniband"
    # Skip this test if no real infiniband devices exist
    if not os.path.exists(sysfs_dir):
        pytest.skip("No infiniband devices found on system")

    if not os.listdir(sysfs_dir):
        # the directory exists but lists no devices: nothing to verify
        return

    assert isinstance(_ibv_get_device_list(), list)


@pytest.mark.cpu
def test_parse_max_hcas_limit():
    """An empty value returns the available devices capped at the first 32."""
    limit = 32
    # more devices than the cap so that truncation is observable
    candidates = [f"device_{idx}" for idx in range(50)]
    selected = _parse_NCCL_IB_HCA("", candidates)
    assert len(selected) == limit
    assert selected == candidates[:limit]


@pytest.mark.cpu
def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]):
    """With an empty environment, _get_rdma_devices returns every listed device."""
    empty_env = patch.dict(os.environ, clear=True)
    fake_ibv_list = patch(
        "checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices
    )
    with empty_env, fake_ibv_list:
        found = _get_rdma_devices()
    assert sorted(found) == sorted(mock_available_devices)


@pytest.mark.cpu
@pytest.mark.parametrize(
    "input_value,expected",
    [
        # blank values select every device
        pytest.param("", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="empty string"),
        pytest.param(" \t\n ", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="whitespace"),
        # the literal text "None" is an ordinary prefix that matches nothing
        pytest.param("None", [], id="None string"),
        # exclude lists that name no existing device exclude nothing
        pytest.param("^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret"),
        pytest.param("^=", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret-equals"),
        pytest.param("=^", [], id="equals-caret"),
        pytest.param("^^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="double-caret"),
        # include lists that name no existing device select nothing
        pytest.param("=", [], id="equals"),
        pytest.param("==", [], id="double-equals"),
    ],
)
def test_parse_basic_cases(
    input_value: str, expected: list[str], mock_available_devices: list[str]
):
    """Test degenerate inputs: blank values, the string 'None', and bare '^' / '=' markers."""
    assert _parse_NCCL_IB_HCA(input_value, mock_available_devices) == expected


@pytest.mark.cpu
@pytest.mark.parametrize(
    "input_value,expected",
    [
        # prefix
        ("mlx5_0", ["mlx5_0"]),
        ("mlx5", ["mlx5_0", "mlx5_1"]),
        # exact match
        ("=mlx5_0", ["mlx5_0"]),
        ("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
        # ignore ports, whitespace and duplicated commas
        ("mlx5_0:1,mlx5_1:2", ["mlx5_0", "mlx5_1"]),
        ("mlx5_0:1,mlx5_1", ["mlx5_0", "mlx5_1"]),
        (" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]),
        ("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]),
        # exclusion (remaining devices keep the order of the available list)
        ("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]),
        ("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]),
        ("^mlx5", ["mlx4_0", "mlx4_1"]),
        ("^=mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]),
        # exact-match exclusion of a bare prefix excludes nothing
        ("^=mlx4", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
    ],
)
def test_parse_various_patterns(
    input_value: str, expected: list[str], mock_available_devices: list[str]
):
    """Test prefix, exact-match, port-suffixed and exclusion patterns."""
    parsed = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
    assert parsed == expected


@pytest.mark.cpu
@pytest.mark.parametrize(
    "input_value,expected_result,expected_warning",
    [
        ("=mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=True."),
        ("mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=False."),
        (
            # excluding a device that does not exist leaves every device selected
            "^mlx5_100",
            ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"],
            "No RDMA device match device_name='mlx5_100' where is_exact_match=False.",
        ),
        ("mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=False."),
        ("=mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=True."),
    ],
)
def test_parse_exact_match_with_nonexistent_device(
    input_value: str,
    expected_result: list[str],
    expected_warning: str,
    mock_available_devices: list[str],
):
    """Test that a spec matching no device (exact, prefix or exclude mode) warns exactly once."""
    with patch("checkpoint_engine.ps.logger") as mock_logger:
        parsed = _parse_NCCL_IB_HCA(input_value, mock_available_devices)

    assert parsed == expected_result
    mock_logger.warning.assert_called_once_with(expected_warning)


@pytest.mark.cpu
@pytest.mark.parametrize(
    "env_var_name,env_var_value,expected_devices",
    [
        ("PS_P2P_STORE_RDMA_DEVICES", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
        ("NCCL_IB_HCA", "mlx5", ["mlx5_0", "mlx5_1"]),
        ("NCCL_IB_HCA", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
        ("NCCL_IB_HCA", "^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]),
        # a value matching no device falls back to every available device
        ("NCCL_IB_HCA", "mlx6", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
        ("NCCL_IB_HCA", "", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
    ],
)
def test_get_rdma_devices_with_env_vars(
    env_var_name: str,
    env_var_value: str,
    expected_devices: list[str],
    mock_available_devices: list[str],
):
    """Test _get_rdma_devices with exactly one of the supported environment variables set."""
    env_dict = {env_var_name: env_var_value}
    with (
        # clear=True isolates the case from the host environment: an ambient
        # PS_P2P_STORE_RDMA_DEVICES or NCCL_IB_HCA would otherwise leak into the result
        patch.dict(os.environ, env_dict, clear=True),
        patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices),
    ):
        devices = _get_rdma_devices()
        assert sorted(devices) == sorted(expected_devices)


@pytest.mark.cpu
@pytest.mark.parametrize(
    "local_rank,gpu_count,expected_device",
    [
        (0, 4, "mlx5_0"),
        (3, 4, "mlx5_3"),
        (4, 8, "mlx5_2"),
        (7, 8, "mlx5_3"),
    ],
)
def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_device: str):
    """Test the rank-to-device mapping of _get_my_rdma_device for valid configurations."""
    # Four devices, so the parametrized GPU counts (4 and 8) split evenly across them
    nic_names = [f"mlx5_{idx}" for idx in range(4)]
    assert _get_my_rdma_device(local_rank, gpu_count, nic_names) == expected_device


@pytest.mark.cpu
@pytest.mark.parametrize(
    "local_rank,gpu_count,devices,error",
    [
        # Too many devices
        (0, 4, ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx5_4"], AssertionError),
        # GPU count not divisible by device count
        (0, 8, ["mlx5_0", "mlx5_1", "mlx5_2"], AssertionError),
        # No devices
        (0, 8, [], RuntimeError),
    ],
)
def test_get_my_rdma_device_invalid_config(
    local_rank: int, gpu_count: int, devices: list[str], error: type
):
    """Test that _get_my_rdma_device raises the expected error for each invalid setup."""
    with pytest.raises(error):
        _get_my_rdma_device(local_rank, gpu_count, devices)