diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml new file mode 100644 index 0000000..6219b55 --- /dev/null +++ b/.github/workflows/cpu-tests.yml @@ -0,0 +1,30 @@ +name: CPU Tests + +on: + push: + branches: [main] + pull_request: + types: [opened, synchronize, reopened] + + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install .[p2p] + - name: Do CPU tests with pytest + run: | + pytest -v -m "not gpu" tests/ diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index e61acf1..1493a69 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -303,14 +303,7 @@ def _get_rdma_devices() -> list[str]: return devices_str.split(",") # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices hca = os.getenv("NCCL_IB_HCA", None) - if hca: - hca_list = hca.split(",") - if len(hca_list) > 1: - # if NCCL_IB_HCA has multiple values, just return - return hca_list - else: - hca = hca_list[0] - return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device] + return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list() def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: @@ -328,6 +321,75 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> return devices[local_rank // (gpu_count // len(devices))] +def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: + """ + The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8. + The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662. + + The list is comma-separated; port numbers are NOT supported yet. + An optional prefix '^' indicates the list is an exclude list. + A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix. + Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported. + + Examples: + - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`. + - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`. + - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`. + - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`. + """ + max_hcas = 32 + if not value or value.strip() == "": + return available_devices[:max_hcas] + + value = value.strip() + result = [] + is_exclude = value.startswith("^") + if is_exclude: + value = value.removeprefix("^") + is_exact_match = value.startswith("=") + if is_exact_match: + value = value.removeprefix("=") + + device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] + + result = _resolve_device_specs(device_specs, is_exact_match, available_devices) + if is_exclude: + result = [dev for dev in available_devices if dev not in result] + if len(result) > max_hcas: + result = result[:max_hcas] + + logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}") + + return result + + +def _resolve_device_specs( + device_specs: list[str], is_exact_match: bool, available_devices: list[str] +) -> list[str]: + devices = set() + for spec in device_specs: + parts = spec.split(":", 1) + device_name = parts[0].strip() + # HACK: mooncake transfer engine does not support port specification yet, so we ignore it + # port = parts[1].strip() if len(parts) > 1 else None + base_devices = ( + [device_name] + if device_name in available_devices + else [] + if is_exact_match + else [dev for dev in available_devices if dev.startswith(device_name)] + ) + + if not base_devices: + logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.") + continue + + for base_dev in base_devices: + devices.add(base_dev) + + return sorted(devices) + + def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]: class TPMeta(BaseModel): concat_dim: int diff --git a/pyproject.toml b/pyproject.toml index df26fb2..c200382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,3 +158,8 @@ inline-quotes = "double" [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" + +[tool.pytest.ini_options] +markers = [ + "gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')", +] diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py new file mode 100644 index 0000000..9b0951a --- /dev/null +++ b/tests/test_rdma_parser.py @@ -0,0 +1,197 @@ +import os +from unittest.mock import patch + +import pytest + +from checkpoint_engine.ps import ( + _get_my_rdma_device, + _get_rdma_devices, + _ibv_get_device_list, + _parse_NCCL_IB_HCA, +) + + +@pytest.fixture +def mock_available_devices() -> list[str]: + """Provide mock available device list""" + return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"] + + +def test_detect_ibv_list(): + """Test detection of _ibv_get_device_list function""" + # Skip this test if no real infiniband devices exist + if not os.path.exists("/sys/class/infiniband"): + pytest.skip("No infiniband devices found on system") + + real_ibv_list = sorted(os.listdir("/sys/class/infiniband")) + if real_ibv_list: + devices = _ibv_get_device_list() + assert isinstance(devices, list) + + +def test_parse_max_hcas_limit(): + """Test maximum HCA quantity limit""" + # Create mock data with more than 32 devices + many_devices = [f"device_{i}" for i in range(50)] + result = _parse_NCCL_IB_HCA("", many_devices) + assert len(result) == 32 + assert result == many_devices[:32] + + +def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): + """Test _get_rdma_devices with no environment variables""" + with ( + patch.dict(os.environ, clear=True), + patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices), + ): + devices = _get_rdma_devices() + assert sorted(devices) == sorted(mock_available_devices) + + +@pytest.mark.parametrize( + "input_value,expected", + [ + pytest.param("", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="empty string"), + pytest.param(" \t\n ", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="whitespace"), + pytest.param("None", [], id="None string"), + pytest.param("^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret"), + pytest.param("^=", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret-equals"), + pytest.param("=^", [], id="equals-caret"), + pytest.param("^^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="double-caret"), + pytest.param("=", [], id="equals"), + pytest.param("==", [], id="double-equals"), + ], +) +def test_parse_basic_cases( + input_value: str, expected: list[str], mock_available_devices: list[str] +): + """Test basic parsing cases: empty string, whitespace, None""" + result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) + assert result == expected + + +@pytest.mark.parametrize( + "input_value,expected", + [ + # prefix + ("mlx5_0", ["mlx5_0"]), + ("mlx5", ["mlx5_0", "mlx5_1"]), + # exact match + ("=mlx5_0", ["mlx5_0"]), + ("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), + # ignore ports, whitespace and duplicated commas + ("mlx5_0:1,mlx5_1:2", ["mlx5_0", "mlx5_1"]), + ("mlx5_0:1,mlx5_1", ["mlx5_0", "mlx5_1"]), + (" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]), + ("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]), + # exclusion + ("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]), + ("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]), + ("^mlx5", ["mlx4_0", "mlx4_1"]), + ("^=mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]), + ("^=mlx4", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]), + ], +) +def test_parse_various_patterns( + input_value: str, expected: list[str], mock_available_devices: list[str] +): + """Test various parsing patterns""" + result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) + assert result == expected + + +@pytest.mark.parametrize( + "input_value,expected_result,expected_warning", + [ + ("=mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=True."), + ("mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=False."), + ( + "^mlx5_100", + ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], + "No RDMA device match device_name='mlx5_100' where is_exact_match=False.", + ), + ("mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=False."), + ("=mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=True."), + ], +) +def test_parse_exact_match_with_nonexistent_device( + input_value: str, + expected_result: list[str], + expected_warning: str, + mock_available_devices: list[str], +): + """Test exact matching with non-existent device""" + with patch("checkpoint_engine.ps.logger") as mock_logger: + result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) + assert result == expected_result + mock_logger.warning.assert_called_once_with(expected_warning) + + +@pytest.mark.parametrize( + "env_var_name,env_var_value,expected_devices", + [ + ("PS_P2P_STORE_RDMA_DEVICES", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), + ("NCCL_IB_HCA", "mlx5", ["mlx5_0", "mlx5_1"]), + ("NCCL_IB_HCA", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), + ("NCCL_IB_HCA", "^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]), + ("NCCL_IB_HCA", "mlx6", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]), + ("NCCL_IB_HCA", "", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]), + ], +) +def test_get_rdma_devices_with_env_vars( + env_var_name: str, + env_var_value: str, + expected_devices: list[str], + mock_available_devices: list[str], +): + """Test _get_rdma_devices with various environment variables""" + env_dict = {env_var_name: env_var_value} + with ( + patch.dict(os.environ, env_dict), + patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices), + ): + devices = _get_rdma_devices() + assert sorted(devices) == sorted(expected_devices) + + +@pytest.mark.parametrize( + "local_rank,gpu_count,expected_device", + [ + (0, 4, "mlx5_0"), + (3, 4, "mlx5_3"), + (4, 8, "mlx5_2"), + (7, 8, "mlx5_3"), + ], +) +def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_device: str): + """Test _get_my_rdma_device with basic allocation""" + # Use fewer devices to match the GPU count constraint + devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"] + device = _get_my_rdma_device(local_rank, gpu_count, devices) + assert device == expected_device + + +@pytest.mark.parametrize( + "local_rank,gpu_count,devices,error", + [ + ( + 0, + 4, + ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx5_4"], + AssertionError, + ), # Too many devices + ( + 0, + 8, + ["mlx5_0", "mlx5_1", "mlx5_2"], + AssertionError, + ), # GPU count not divisible by device count + (0, 8, [], RuntimeError), # No devices + ], +) +def test_get_my_rdma_device_invalid_config( + local_rank: int, gpu_count: int, devices: list[str], error: type +): + """Test _get_my_rdma_device with invalid configuration""" + with pytest.raises(error): + _get_my_rdma_device(local_rank, gpu_count, devices)