From 03ca80d8763b4c4747c6aab0fad727c7c4a28af9 Mon Sep 17 00:00:00 2001 From: specture724 Date: Wed, 15 Oct 2025 07:10:58 +0000 Subject: [PATCH 01/13] feat: NCCLIBHCAParser class added, supporting exact match, exclude, and port specifications for RDMA devices. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8 --- checkpoint_engine/ps.py | 185 +++++++++++++++++++++----------- tests/test_rdma_parser.py | 214 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 341 insertions(+), 58 deletions(-) create mode 100644 tests/test_rdma_parser.py diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index e61acf1..c170a14 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -271,63 +271,6 @@ def _get_ip() -> str: return socket.gethostbyname(socket.gethostname()) -def _ibv_get_device_list() -> list[str]: - lib = ctypes.CDLL("libibverbs.so.1") - lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices - lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device ** - - lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)] - lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device * - lib.ibv_get_device_name.restype = ctypes.c_char_p # const char * - - num = ctypes.c_int() - dev_array = lib.ibv_get_device_list(ctypes.byref(num)) - if not dev_array or num.value <= 0: - return [] - - devices = [] - for i in range(num.value): - dev_ptr = dev_array[i] # struct ibv_device * - name = lib.ibv_get_device_name(dev_ptr) # const char * - devices.append(name.decode()) - lib.ibv_free_device_list(dev_array) - return devices - - -def _get_rdma_devices() -> list[str]: - """ - use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return - """ - devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES") - if devices_str: - return devices_str.split(",") - # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices - hca = os.getenv("NCCL_IB_HCA", None) - if hca: - hca_list = hca.split(",") - if len(hca_list) > 1: - # if NCCL_IB_HCA has multiple values, just return - return hca_list - else: - hca = hca_list[0] - return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device] - - -def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: - """ - implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc. - """ - if not devices: - raise RuntimeError("no rdma devices found") - assert len(devices) <= gpu_count, ( - f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" - ) - assert gpu_count % len(devices) == 0, ( - f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" - ) - return devices[local_rank // (gpu_count // len(devices))] - - def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]: class TPMeta(BaseModel): concat_dim: int @@ -525,6 +468,129 @@ def _get_master_port(master_port: int | None = None) -> int: return master_port +class NCCLIBHCAParser: + def __init__(self): + self.max_hcas = 32 + self.available_devices = self._ibv_get_device_list() + logger.info(f"Available RDMA Devices: {self.available_devices}") + + def parse(self, value: str) -> list[str]: + if not value or value.strip() == "": + return self.available_devices[: self.max_hcas] + + value = value.strip() + result = [] + is_exclude = value.startswith("^") + is_exact_match = value.startswith("=") + + cnt = 0 + while value and value[0] in ("^", "=") and cnt < 2: + if value[0] == "^": + is_exclude = True + elif value[0] == "=": + is_exact_match = True + value = value[1:] + cnt += 1 + + device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] + + if is_exclude: + excluded_devices = self._resolve_device_specs(device_specs, is_exact_match) + for excluded in excluded_devices: + if excluded not in self.available_devices: + logger.warning(f"device '{excluded}' not found in available devices.") + excluded_devices.remove(excluded) + result = [dev for dev in self.available_devices if dev not in excluded_devices] + else: + result = self._resolve_device_specs(device_specs, is_exact_match) + + if len(result) > self.max_hcas: + result = result[: self.max_hcas] + + logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}") + + return result + + def _resolve_device_specs(self, device_specs: list[str], is_exact_match: bool) -> list[str]: + devices = set() + for spec in device_specs: + device_name, port = ( + map(str.strip, spec.split(":", 1)) if ":" in spec else (spec.strip(), None) + ) + base_devices = ( + [device_name] + if is_exact_match + else [dev for dev in self.available_devices if dev.startswith(device_name)] + ) + if is_exact_match and device_name not in self.available_devices: + logger.warning(f"Device '{device_name}' not found in available devices.") + continue + + if not base_devices: + logger.warning(f"No devices match the prefix '{device_name}'.") + continue + + for base_dev in base_devices: + devices.add(f"{base_dev}:{port}" if port else f"{base_dev}") + + return sorted(devices) + + def _ibv_get_device_list(self) -> list[str]: + lib = ctypes.CDLL("libibverbs.so.1") + lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices + lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device ** + + lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)] + lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device * + lib.ibv_get_device_name.restype = ctypes.c_char_p # const char * + + num = ctypes.c_int() + dev_array = lib.ibv_get_device_list(ctypes.byref(num)) + if not dev_array or num.value <= 0: + return [] + + devices = [] + for i in range(num.value): + dev_ptr = dev_array[i] # struct ibv_device * + name = lib.ibv_get_device_name(dev_ptr) # const char * + devices.append(name.decode()) + lib.ibv_free_device_list(dev_array) + return devices + + def _get_rdma_devices(self) -> list[str]: + """ + use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return + """ + devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES") + if devices_str: + return devices_str.split(",") + # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices + hca = os.getenv("NCCL_IB_HCA", None) + + if hca: + hca_list = self.parse(hca) + if len(hca_list) > 1: + # if NCCL_IB_HCA has multiple values, just return + return hca_list + else: + hca = hca_list[0] + return [ + device for device in sorted(self._ibv_get_device_list()) if hca is None or hca in device + ] + + def _get_my_rdma_device(self, local_rank: int, gpu_count: int, devices: list[str]) -> str: + """ + implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc. + if some NICs are down, causing the number of NICs is undivisible by the number of GPUs, assign the remaining GPUs to the closest NIC. + """ + if not devices: + raise RuntimeError("no rdma devices found") + assert len(devices) <= gpu_count, ( + f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" + ) + return devices[local_rank // (gpu_count // len(devices))] + + class P2PStore: def __init__(self): from mooncake.engine import TransferEngine @@ -532,7 +598,10 @@ def __init__(self): self.rank = int(os.getenv("RANK")) gpu_count = torch.cuda.device_count() local_rank = self.rank % gpu_count - device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) + rdma_parser = NCCLIBHCAParser() + device = rdma_parser._get_my_rdma_device( + local_rank, gpu_count, rdma_parser._get_rdma_devices() + ) self.ip = _get_ip() # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py new file mode 100644 index 0000000..885dec4 --- /dev/null +++ b/tests/test_rdma_parser.py @@ -0,0 +1,214 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from checkpoint_engine.ps import NCCLIBHCAParser +import os + +class TestNCCLIBHCAParser: + """Unit tests for NCCLIBHCAParser class""" + + @pytest.fixture + def mock_available_devices(self): + """Provide mock available device list""" + return [ + 'mlx5_0', 'mlx5_1', 'mlx5_2', 'mlx5_3', + 'mlx4_0', 'mlx4_1', + 'roce_0', 'roce_1' + ] + + @pytest.fixture + def parser_with_mock_devices(self, mock_available_devices): + """Create parser instance with mock devices""" + with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=mock_available_devices): + parser = NCCLIBHCAParser() + return parser + + def test_detect_ibv_list(self): + """Test detection of _ibv_get_device_list function""" + parser = NCCLIBHCAParser() + real_ibv_list = os.listdir('/sys/class/infiniband') if os.path.exists('/sys/class/infiniband') else [] + if real_ibv_list: + assert parser.available_devices == real_ibv_list + + def test_init_with_mock_devices(self, mock_available_devices): + """Test correct device list initialization""" + with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=mock_available_devices): + parser = NCCLIBHCAParser() + assert parser.available_devices == mock_available_devices + assert parser.max_hcas == 32 + + def test_parse_empty_string_returns_all_devices(self, parser_with_mock_devices): + """Test empty string returns all devices""" + result = parser_with_mock_devices.parse("") + assert result == parser_with_mock_devices.available_devices + + def test_parse_whitespace_only_returns_all_devices(self, parser_with_mock_devices): + """Test whitespace-only string returns all devices""" + result = parser_with_mock_devices.parse(" \t\n ") + assert result == parser_with_mock_devices.available_devices + + def test_parse_none_string_returns_all_devices(self, parser_with_mock_devices): + """Test 'None' string returns empty list (no matching devices)""" + result = parser_with_mock_devices.parse("None") + # "None" is treated as a regular string, tries prefix matching + # Since no devices start with "None", should return empty list + assert result == [] + + def test_parse_prefix_match_single_device(self, parser_with_mock_devices): + """Test prefix matching for single device""" + result = parser_with_mock_devices.parse("mlx5_0") + assert result == ['mlx5_0'] + + def test_parse_prefix_match_multiple_devices(self, parser_with_mock_devices): + """Test prefix matching for multiple devices""" + result = parser_with_mock_devices.parse("mlx5") + expected = ['mlx5_0', 'mlx5_1', 'mlx5_2', 'mlx5_3'] + assert result == expected + + def test_parse_exact_match_single_device(self, parser_with_mock_devices): + """Test exact matching for single device""" + result = parser_with_mock_devices.parse("=mlx5_0") + assert result == ['mlx5_0'] + + def test_parse_exact_match_multiple_devices(self, parser_with_mock_devices): + """Test exact matching for multiple devices""" + result = parser_with_mock_devices.parse("=mlx5_0,mlx5_1") + assert result == ['mlx5_0', 'mlx5_1'] + + def test_parse_exact_match_with_nonexistent_device(self, parser_with_mock_devices): + """Test exact matching with non-existent device""" + with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: + result = parser_with_mock_devices.parse("=mlx5_100") + assert result == [] + mock_logger.warning.assert_called_once_with("Device 'mlx5_100' not found in available devices.") + + def test_parse_exclude_single_device(self, parser_with_mock_devices): + """Test excluding single device""" + result = parser_with_mock_devices.parse("^mlx5_0") + expected = [dev for dev in parser_with_mock_devices.available_devices if dev != 'mlx5_0'] + assert result == expected + + def test_parse_exclude_multiple_devices(self, parser_with_mock_devices): + """Test excluding multiple devices""" + result = parser_with_mock_devices.parse("^mlx5_0,mlx5_1") + expected = [dev for dev in parser_with_mock_devices.available_devices + if dev not in ['mlx5_0', 'mlx5_1']] + assert result == expected + + def test_parse_exclude_with_prefix_match(self, parser_with_mock_devices): + """Test exclusion with prefix matching""" + result = parser_with_mock_devices.parse("^mlx5") + expected = ['mlx4_0', 'mlx4_1', 'roce_0', 'roce_1'] + assert result == expected + + def test_parse_exclude_with_exact_match(self, parser_with_mock_devices): + """Test exclusion with exact matching""" + result = parser_with_mock_devices.parse("^=mlx5_0,mlx5_1") + expected = [dev for dev in parser_with_mock_devices.available_devices + if dev not in ['mlx5_0', 'mlx5_1']] + assert result == expected + + def test_parse_exclude_nonexistent_device(self, parser_with_mock_devices): + """Test excluding non-existent device""" + with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: + result = parser_with_mock_devices.parse("^mlx5_100") + expected = parser_with_mock_devices.available_devices + assert result == expected + mock_logger.warning.assert_called_once_with("No devices match the prefix 'mlx5_100'.") + + def test_parse_with_port_specification(self, parser_with_mock_devices): + """Test parsing with port specification""" + result = parser_with_mock_devices.parse("mlx5_0:1,mlx5_1:2") + expected = ['mlx5_0:1', 'mlx5_1:2'] + assert result == expected + + def test_parse_mixed_with_and_without_ports(self, parser_with_mock_devices): + """Test mixed parsing with and without port specifications""" + result = parser_with_mock_devices.parse("mlx5_0:1,mlx5_1") + expected = ['mlx5_0:1', 'mlx5_1'] + assert result == expected + + def test_parse_max_hcas_limit(self, parser_with_mock_devices): + """Test maximum HCA quantity limit""" + # Create mock data with more than 32 devices + many_devices = [f'device_{i}' for i in range(50)] + with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=many_devices): + parser = NCCLIBHCAParser() + result = parser.parse("") + assert len(result) == 32 + assert result == many_devices[:32] + + def test_parse_complex_combination(self, parser_with_mock_devices): + """Test complex combination parsing""" + result = parser_with_mock_devices.parse("^=mlx5_3,mlx4_1") + expected = [dev for dev in parser_with_mock_devices.available_devices + if dev not in ['mlx5_3', 'mlx4_1']] + assert result == expected + + def test_parse_multiple_prefix_operators(self, parser_with_mock_devices): + """Test multiple prefix operators""" + result = parser_with_mock_devices.parse("^=mlx5_0") + expected = [dev for dev in parser_with_mock_devices.available_devices if dev != 'mlx5_0'] + assert result == expected + + def test_parse_edge_case_empty_after_operators(self, parser_with_mock_devices): + """Test edge cases with empty content after operators""" + result = parser_with_mock_devices.parse("^") + # Empty exclusion list means exclude nothing, return all devices + assert result == parser_with_mock_devices.available_devices + + result = parser_with_mock_devices.parse("=") + # Empty exact match list means match nothing, return empty list + assert result == [] + + def test_parse_edge_case_only_operators(self, parser_with_mock_devices): + """Test edge cases with only operators""" + result = parser_with_mock_devices.parse("^=") + assert result == parser_with_mock_devices.available_devices + + result = parser_with_mock_devices.parse("=^") + assert result == parser_with_mock_devices.available_devices + + result = parser_with_mock_devices.parse("^^") + assert result == parser_with_mock_devices.available_devices + + result = parser_with_mock_devices.parse("==") + assert result == [] + + def test_parse_with_spaces_in_input(self, parser_with_mock_devices): + """Test parsing with spaces in input""" + result = parser_with_mock_devices.parse(" mlx5_0 , mlx5_1 ") + assert result == ['mlx5_0', 'mlx5_1'] + + def test_parse_empty_device_spec(self, parser_with_mock_devices): + """Test parsing with empty device specifications""" + result = parser_with_mock_devices.parse("mlx5_0,,mlx5_1") + assert result == ['mlx5_0', 'mlx5_1'] + + def test_ibv_get_device_list_real_implementation_mocked(self): + """Test ibv_get_device_list implementation with complete mocking to avoid ctypes issues""" + # Mock the entire method instead of using ctypes + with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=['mlx5_0', 'mlx5_1']): + parser = NCCLIBHCAParser() + devices = parser._ibv_get_device_list() + assert devices == ['mlx5_0', 'mlx5_1'] + + def test_ibv_get_device_list_no_devices_mocked(self): + """Test no available devices case with complete mocking""" + with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=[]): + parser = NCCLIBHCAParser() + devices = parser._ibv_get_device_list() + assert devices == [] + + def test_resolve_device_specs_no_match(self, parser_with_mock_devices): + """Test _resolve_device_specs with no matching devices""" + with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: + result = parser_with_mock_devices._resolve_device_specs(['nonexistent'], False) + assert result == [] + mock_logger.warning.assert_called_once_with("No devices match the prefix 'nonexistent'.") + + def test_resolve_device_specs_exact_match_not_found(self, parser_with_mock_devices): + """Test _resolve_device_specs with exact match not found""" + with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: + result = parser_with_mock_devices._resolve_device_specs(['nonexistent'], True) + assert result == [] + mock_logger.warning.assert_called_once_with("Device 'nonexistent' not found in available devices.") \ No newline at end of file From 34725ab08bcdbcd65d02a108799ff51596f39daa Mon Sep 17 00:00:00 2001 From: specture724 Date: Wed, 15 Oct 2025 07:22:00 +0000 Subject: [PATCH 02/13] style: ruff format --- tests/test_rdma_parser.py | 168 ++++++++++++++++++++++---------------- 1 file changed, 99 insertions(+), 69 deletions(-) diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index 885dec4..d4430de 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -1,156 +1,178 @@ +import os +from unittest.mock import patch + import pytest -from unittest.mock import Mock, patch, MagicMock + from checkpoint_engine.ps import NCCLIBHCAParser -import os + class TestNCCLIBHCAParser: """Unit tests for NCCLIBHCAParser class""" @pytest.fixture - def mock_available_devices(self): + def mock_available_devices(self) -> list[str]: """Provide mock available device list""" - return [ - 'mlx5_0', 'mlx5_1', 'mlx5_2', 'mlx5_3', - 'mlx4_0', 'mlx4_1', - 'roce_0', 'roce_1' - ] - + return ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx4_0", "mlx4_1", "roce_0", "roce_1"] + @pytest.fixture - def parser_with_mock_devices(self, mock_available_devices): + def parser_with_mock_devices(self, mock_available_devices: list[str]) -> NCCLIBHCAParser: """Create parser instance with mock devices""" - with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=mock_available_devices): + with patch.object( + NCCLIBHCAParser, "_ibv_get_device_list", return_value=mock_available_devices + ): parser = NCCLIBHCAParser() return parser def test_detect_ibv_list(self): """Test detection of _ibv_get_device_list function""" parser = NCCLIBHCAParser() - real_ibv_list = os.listdir('/sys/class/infiniband') if os.path.exists('/sys/class/infiniband') else [] + real_ibv_list = ( + os.listdir("/sys/class/infiniband") if os.path.exists("/sys/class/infiniband") else [] + ) if real_ibv_list: assert parser.available_devices == real_ibv_list - def test_init_with_mock_devices(self, mock_available_devices): + def test_init_with_mock_devices(self, mock_available_devices: list[str]): """Test correct device list initialization""" - with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=mock_available_devices): + with patch.object( + NCCLIBHCAParser, "_ibv_get_device_list", return_value=mock_available_devices + ): parser = NCCLIBHCAParser() assert parser.available_devices == mock_available_devices assert parser.max_hcas == 32 - def test_parse_empty_string_returns_all_devices(self, parser_with_mock_devices): + def test_parse_empty_string_returns_all_devices( + self, parser_with_mock_devices: NCCLIBHCAParser + ): """Test empty string returns all devices""" result = parser_with_mock_devices.parse("") assert result == parser_with_mock_devices.available_devices - def test_parse_whitespace_only_returns_all_devices(self, parser_with_mock_devices): + def test_parse_whitespace_only_returns_all_devices( + self, parser_with_mock_devices: NCCLIBHCAParser + ): """Test whitespace-only string returns all devices""" result = parser_with_mock_devices.parse(" \t\n ") assert result == parser_with_mock_devices.available_devices - def test_parse_none_string_returns_all_devices(self, parser_with_mock_devices): + def test_parse_none_string_returns_all_devices(self, parser_with_mock_devices: NCCLIBHCAParser): """Test 'None' string returns empty list (no matching devices)""" result = parser_with_mock_devices.parse("None") # "None" is treated as a regular string, tries prefix matching # Since no devices start with "None", should return empty list assert result == [] - def test_parse_prefix_match_single_device(self, parser_with_mock_devices): + def test_parse_prefix_match_single_device(self, parser_with_mock_devices: NCCLIBHCAParser): """Test prefix matching for single device""" result = parser_with_mock_devices.parse("mlx5_0") - assert result == ['mlx5_0'] + assert result == ["mlx5_0"] - def test_parse_prefix_match_multiple_devices(self, parser_with_mock_devices): + def test_parse_prefix_match_multiple_devices(self, parser_with_mock_devices: NCCLIBHCAParser): """Test prefix matching for multiple devices""" result = parser_with_mock_devices.parse("mlx5") - expected = ['mlx5_0', 'mlx5_1', 'mlx5_2', 'mlx5_3'] + expected = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"] assert result == expected - def test_parse_exact_match_single_device(self, parser_with_mock_devices): + def test_parse_exact_match_single_device(self, parser_with_mock_devices: NCCLIBHCAParser): """Test exact matching for single device""" result = parser_with_mock_devices.parse("=mlx5_0") - assert result == ['mlx5_0'] + assert result == ["mlx5_0"] - def test_parse_exact_match_multiple_devices(self, parser_with_mock_devices): + def test_parse_exact_match_multiple_devices(self, parser_with_mock_devices: NCCLIBHCAParser): """Test exact matching for multiple devices""" result = parser_with_mock_devices.parse("=mlx5_0,mlx5_1") - assert result == ['mlx5_0', 'mlx5_1'] + assert result == ["mlx5_0", "mlx5_1"] - def test_parse_exact_match_with_nonexistent_device(self, parser_with_mock_devices): + def test_parse_exact_match_with_nonexistent_device( + self, parser_with_mock_devices: NCCLIBHCAParser + ): """Test exact matching with non-existent device""" - with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: + with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: result = parser_with_mock_devices.parse("=mlx5_100") assert result == [] - mock_logger.warning.assert_called_once_with("Device 'mlx5_100' not found in available devices.") + mock_logger.warning.assert_called_once_with( + "Device 'mlx5_100' not found in available devices." + ) - def test_parse_exclude_single_device(self, parser_with_mock_devices): + def test_parse_exclude_single_device(self, parser_with_mock_devices: NCCLIBHCAParser): """Test excluding single device""" result = parser_with_mock_devices.parse("^mlx5_0") - expected = [dev for dev in parser_with_mock_devices.available_devices if dev != 'mlx5_0'] + expected = [dev for dev in parser_with_mock_devices.available_devices if dev != "mlx5_0"] assert result == expected - def test_parse_exclude_multiple_devices(self, parser_with_mock_devices): + def test_parse_exclude_multiple_devices(self, parser_with_mock_devices: NCCLIBHCAParser): """Test excluding multiple devices""" result = parser_with_mock_devices.parse("^mlx5_0,mlx5_1") - expected = [dev for dev in parser_with_mock_devices.available_devices - if dev not in ['mlx5_0', 'mlx5_1']] + expected = [ + dev + for dev in parser_with_mock_devices.available_devices + if dev not in ["mlx5_0", "mlx5_1"] + ] assert result == expected - def test_parse_exclude_with_prefix_match(self, parser_with_mock_devices): + def test_parse_exclude_with_prefix_match(self, parser_with_mock_devices: NCCLIBHCAParser): """Test exclusion with prefix matching""" result = parser_with_mock_devices.parse("^mlx5") - expected = ['mlx4_0', 'mlx4_1', 'roce_0', 'roce_1'] + expected = ["mlx4_0", "mlx4_1", "roce_0", "roce_1"] assert result == expected - def test_parse_exclude_with_exact_match(self, parser_with_mock_devices): + def test_parse_exclude_with_exact_match(self, parser_with_mock_devices: NCCLIBHCAParser): """Test exclusion with exact matching""" result = parser_with_mock_devices.parse("^=mlx5_0,mlx5_1") - expected = [dev for dev in parser_with_mock_devices.available_devices - if dev not in ['mlx5_0', 'mlx5_1']] + expected = [ + dev + for dev in parser_with_mock_devices.available_devices + if dev not in ["mlx5_0", "mlx5_1"] + ] assert result == expected - def test_parse_exclude_nonexistent_device(self, parser_with_mock_devices): + def test_parse_exclude_nonexistent_device(self, parser_with_mock_devices: NCCLIBHCAParser): """Test excluding non-existent device""" - with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: + with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: result = parser_with_mock_devices.parse("^mlx5_100") expected = parser_with_mock_devices.available_devices assert result == expected mock_logger.warning.assert_called_once_with("No devices match the prefix 'mlx5_100'.") - def test_parse_with_port_specification(self, parser_with_mock_devices): + def test_parse_with_port_specification(self, parser_with_mock_devices: NCCLIBHCAParser): """Test parsing with port specification""" result = parser_with_mock_devices.parse("mlx5_0:1,mlx5_1:2") - expected = ['mlx5_0:1', 'mlx5_1:2'] + expected = ["mlx5_0:1", "mlx5_1:2"] assert result == expected - def test_parse_mixed_with_and_without_ports(self, parser_with_mock_devices): + def test_parse_mixed_with_and_without_ports(self, parser_with_mock_devices: NCCLIBHCAParser): """Test mixed parsing with and without port specifications""" result = parser_with_mock_devices.parse("mlx5_0:1,mlx5_1") - expected = ['mlx5_0:1', 'mlx5_1'] + expected = ["mlx5_0:1", "mlx5_1"] assert result == expected - def test_parse_max_hcas_limit(self, parser_with_mock_devices): + def test_parse_max_hcas_limit(self): """Test maximum HCA quantity limit""" # Create mock data with more than 32 devices - many_devices = [f'device_{i}' for i in range(50)] - with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=many_devices): + many_devices = [f"device_{i}" for i in range(50)] + with patch.object(NCCLIBHCAParser, "_ibv_get_device_list", return_value=many_devices): parser = NCCLIBHCAParser() result = parser.parse("") assert len(result) == 32 assert result == many_devices[:32] - def test_parse_complex_combination(self, parser_with_mock_devices): + def test_parse_complex_combination(self, parser_with_mock_devices: NCCLIBHCAParser): """Test complex combination parsing""" result = parser_with_mock_devices.parse("^=mlx5_3,mlx4_1") - expected = [dev for dev in parser_with_mock_devices.available_devices - if dev not in ['mlx5_3', 'mlx4_1']] + expected = [ + dev + for dev in parser_with_mock_devices.available_devices + if dev not in ["mlx5_3", "mlx4_1"] + ] assert result == expected - def test_parse_multiple_prefix_operators(self, parser_with_mock_devices): + def test_parse_multiple_prefix_operators(self, parser_with_mock_devices: NCCLIBHCAParser): """Test multiple prefix operators""" result = parser_with_mock_devices.parse("^=mlx5_0") - expected = [dev for dev in parser_with_mock_devices.available_devices if dev != 'mlx5_0'] + expected = [dev for dev in parser_with_mock_devices.available_devices if dev != "mlx5_0"] assert result == expected - def test_parse_edge_case_empty_after_operators(self, parser_with_mock_devices): + def test_parse_edge_case_empty_after_operators(self, parser_with_mock_devices: NCCLIBHCAParser): """Test edge cases with empty content after operators""" result = parser_with_mock_devices.parse("^") # Empty exclusion list means exclude nothing, return all devices @@ -160,7 +182,7 @@ def test_parse_edge_case_empty_after_operators(self, parser_with_mock_devices): # Empty exact match list means match nothing, return empty list assert result == [] - def test_parse_edge_case_only_operators(self, parser_with_mock_devices): + def test_parse_edge_case_only_operators(self, parser_with_mock_devices: NCCLIBHCAParser): """Test edge cases with only operators""" result = parser_with_mock_devices.parse("^=") assert result == parser_with_mock_devices.available_devices @@ -174,41 +196,49 @@ def test_parse_edge_case_only_operators(self, parser_with_mock_devices): result = parser_with_mock_devices.parse("==") assert result == [] - def test_parse_with_spaces_in_input(self, parser_with_mock_devices): + def test_parse_with_spaces_in_input(self, parser_with_mock_devices: NCCLIBHCAParser): """Test parsing with spaces in input""" result = parser_with_mock_devices.parse(" mlx5_0 , mlx5_1 ") - assert result == ['mlx5_0', 'mlx5_1'] + assert result == ["mlx5_0", "mlx5_1"] - def test_parse_empty_device_spec(self, parser_with_mock_devices): + def test_parse_empty_device_spec(self, parser_with_mock_devices: NCCLIBHCAParser): """Test parsing with empty device specifications""" result = parser_with_mock_devices.parse("mlx5_0,,mlx5_1") - assert result == ['mlx5_0', 'mlx5_1'] + assert result == ["mlx5_0", "mlx5_1"] def test_ibv_get_device_list_real_implementation_mocked(self): """Test ibv_get_device_list implementation with complete mocking to avoid ctypes issues""" # Mock the entire method instead of using ctypes - with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=['mlx5_0', 'mlx5_1']): + with patch.object( + NCCLIBHCAParser, "_ibv_get_device_list", return_value=["mlx5_0", "mlx5_1"] + ): parser = NCCLIBHCAParser() devices = parser._ibv_get_device_list() - assert devices == ['mlx5_0', 'mlx5_1'] + assert devices == ["mlx5_0", "mlx5_1"] def test_ibv_get_device_list_no_devices_mocked(self): """Test no available devices case with complete mocking""" - with patch.object(NCCLIBHCAParser, '_ibv_get_device_list', return_value=[]): + with patch.object(NCCLIBHCAParser, "_ibv_get_device_list", return_value=[]): parser = NCCLIBHCAParser() devices = parser._ibv_get_device_list() assert devices == [] - def test_resolve_device_specs_no_match(self, parser_with_mock_devices): + def test_resolve_device_specs_no_match(self, parser_with_mock_devices: NCCLIBHCAParser): """Test _resolve_device_specs with no matching devices""" - with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: - result = parser_with_mock_devices._resolve_device_specs(['nonexistent'], False) + with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: + result = parser_with_mock_devices._resolve_device_specs(["nonexistent"], False) assert result == [] - mock_logger.warning.assert_called_once_with("No devices match the prefix 'nonexistent'.") + mock_logger.warning.assert_called_once_with( + "No devices match the prefix 'nonexistent'." + ) - def test_resolve_device_specs_exact_match_not_found(self, parser_with_mock_devices): + def test_resolve_device_specs_exact_match_not_found( + self, parser_with_mock_devices: NCCLIBHCAParser + ): """Test _resolve_device_specs with exact match not found""" - with patch('checkpoint_engine.rdma_parser.logger') as mock_logger: - result = parser_with_mock_devices._resolve_device_specs(['nonexistent'], True) + with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: + result = parser_with_mock_devices._resolve_device_specs(["nonexistent"], True) assert result == [] - mock_logger.warning.assert_called_once_with("Device 'nonexistent' not found in available devices.") \ No newline at end of file + mock_logger.warning.assert_called_once_with( + "Device 'nonexistent' not found in available devices." + ) From 8080d47b4f457bf1f2e267f68825e034fb99a031 Mon Sep 17 00:00:00 2001 From: specture724 Date: Wed, 15 Oct 2025 07:44:53 +0000 Subject: [PATCH 03/13] fix: logger import dir --- tests/test_rdma_parser.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index d4430de..184c4c6 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -27,7 +27,9 @@ def test_detect_ibv_list(self): """Test detection of _ibv_get_device_list function""" parser = NCCLIBHCAParser() real_ibv_list = ( - os.listdir("/sys/class/infiniband") if os.path.exists("/sys/class/infiniband") else [] + os.listdir("/sys/class/infiniband").sort() + if os.path.exists("/sys/class/infiniband") + else [] ) if real_ibv_list: assert parser.available_devices == real_ibv_list @@ -87,7 +89,7 @@ def test_parse_exact_match_with_nonexistent_device( self, parser_with_mock_devices: NCCLIBHCAParser ): """Test exact matching with non-existent device""" - with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: + with patch("checkpoint_engine.ps.logger") as mock_logger: result = parser_with_mock_devices.parse("=mlx5_100") assert result == [] mock_logger.warning.assert_called_once_with( @@ -128,7 +130,7 @@ def test_parse_exclude_with_exact_match(self, parser_with_mock_devices: NCCLIBHC def test_parse_exclude_nonexistent_device(self, parser_with_mock_devices: NCCLIBHCAParser): """Test excluding non-existent device""" - with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: + with patch("checkpoint_engine.ps.logger") as mock_logger: result = parser_with_mock_devices.parse("^mlx5_100") expected = parser_with_mock_devices.available_devices assert result == expected @@ -225,7 +227,7 @@ def test_ibv_get_device_list_no_devices_mocked(self): def test_resolve_device_specs_no_match(self, parser_with_mock_devices: NCCLIBHCAParser): """Test _resolve_device_specs with no matching devices""" - with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: + with patch("checkpoint_engine.ps.logger") as mock_logger: result = parser_with_mock_devices._resolve_device_specs(["nonexistent"], False) assert result == [] mock_logger.warning.assert_called_once_with( @@ -236,7 +238,7 @@ def test_resolve_device_specs_exact_match_not_found( self, parser_with_mock_devices: NCCLIBHCAParser ): """Test _resolve_device_specs with exact match not found""" - with patch("checkpoint_engine.rdma_parser.logger") as mock_logger: + with patch("checkpoint_engine.ps.logger") as mock_logger: result = parser_with_mock_devices._resolve_device_specs(["nonexistent"], True) assert result == [] mock_logger.warning.assert_called_once_with( From 083e1a505cac05ab4c958bf5d8ab5c68d7088b09 Mon Sep 17 00:00:00 2001 From: specture724 Date: Wed, 15 Oct 2025 10:04:59 +0000 Subject: [PATCH 04/13] misc --- checkpoint_engine/ps.py | 72 ++++++++++++-- tests/test_rdma_parser.py | 194 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+), 8 deletions(-) create mode 100644 tests/test_rdma_parser.py diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index e61acf1..cbe5391 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -303,14 +303,7 @@ def _get_rdma_devices() -> list[str]: return devices_str.split(",") # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices hca = os.getenv("NCCL_IB_HCA", None) - if hca: - hca_list = hca.split(",") - if len(hca_list) > 1: - # if NCCL_IB_HCA has multiple values, just return - return hca_list - else: - hca = hca_list[0] - return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device] + return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list() def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str: @@ -328,6 +321,68 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> return devices[local_rank // (gpu_count // len(devices))] +def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: + max_hcas = 32 + if not value or value.strip() == "": + return available_devices[:max_hcas] + + value = value.strip() + result = [] + is_exclude = value.startswith("^") + is_exact_match = value.startswith("=") + + cnt = 0 + while value and value[0] in ("^", "=") and cnt < 2: + if value[0] == "^": + is_exclude = True + elif value[0] == "=": + is_exact_match = True + value = value[1:] + cnt += 1 + + device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] + + if is_exclude: + excluded_devices = _resolve_device_specs(device_specs, is_exact_match, available_devices) + result = [dev for dev in available_devices if dev not in excluded_devices] + else: + result = _resolve_device_specs(device_specs, is_exact_match, available_devices) + + if len(result) > max_hcas: + result = result[:max_hcas] + + logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}") + + return result + + +def _resolve_device_specs( + device_specs: list[str], is_exact_match: bool, available_devices: list[str] +) -> list[str]: + devices = set() + for spec in device_specs: + device_name, port = ( + map(str.strip, spec.split(":", 1)) if ":" in spec else (spec.strip(), None) + ) + base_devices = ( + [device_name] + if is_exact_match + else [dev for dev in available_devices if dev.startswith(device_name)] + ) + if is_exact_match and device_name not in available_devices: + logger.warning(f"Device '{device_name}' not found in available devices.") + continue + + if not base_devices: + logger.warning(f"No devices match the prefix '{device_name}'.") + continue + + for base_dev in base_devices: + devices.add(f"{base_dev}:{port}" if port else f"{base_dev}") + + return sorted(devices) + + def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]: class TPMeta(BaseModel): concat_dim: int @@ -533,6 +588,7 @@ def __init__(self): gpu_count = torch.cuda.device_count() local_rank = self.rank % gpu_count device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) + device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) self.ip = _get_ip() # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py new file mode 100644 index 0000000..31f5e4a --- /dev/null +++ b/tests/test_rdma_parser.py @@ -0,0 +1,194 @@ +import os +from unittest.mock import patch + +import pytest + +from checkpoint_engine.ps import ( + _get_my_rdma_device, + _get_rdma_devices, + _ibv_get_device_list, + _parse_NCCL_IB_HCA, +) + + +@pytest.fixture +def mock_available_devices() -> list[str]: + """Provide mock available device list""" + return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"] + + +def test_detect_ibv_list(): + """Test detection of _ibv_get_device_list function""" + # Skip this test if no real infiniband devices exist + if not os.path.exists("/sys/class/infiniband"): + pytest.skip("No infiniband devices found on system") + + real_ibv_list = sorted(os.listdir("/sys/class/infiniband")) + if real_ibv_list: + devices = _ibv_get_device_list() + assert isinstance(devices, list) + + +def test_parse_max_hcas_limit(): + """Test maximum HCA quantity limit""" + # Create mock data with more than 32 devices + many_devices = [f"device_{i}" for i in range(50)] + result = _parse_NCCL_IB_HCA("", many_devices) + assert len(result) == 32 + assert result == many_devices[:32] + + +def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): + """Test _get_rdma_devices with no environment variables""" + with ( + patch.dict(os.environ, clear=True), + patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices), + ): + devices = _get_rdma_devices() + assert sorted(devices) == sorted(mock_available_devices) + + +@pytest.mark.parametrize( + "input_value,expected", + [ + ("", "mock_available_devices"), # Special marker for fixture + (" \t\n ", "mock_available_devices"), # Special marker for fixture + ("None", []), + ("^", "mock_available_devices"), # Special marker for fixture + ("^=", "mock_available_devices"), + ("=^", "mock_available_devices"), + ("^^", "mock_available_devices"), + ("=", []), + ("==", []), + ], +) +def test_parse_basic_cases(input_value: str, expected: str, mock_available_devices: list[str]): + """Test basic parsing cases: empty string, whitespace, None""" + result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) + if expected == "mock_available_devices": + assert result == mock_available_devices + else: + assert result == expected + + +@pytest.mark.parametrize( + "input_value,expected", + [ + ("mlx5_0", ["mlx5_0"]), + ("mlx5", ["mlx5_0", "mlx5_1"]), + ("=mlx5_0", ["mlx5_0"]), + ("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), + ("mlx5_0:1,mlx5_1:2", ["mlx5_0:1", "mlx5_1:2"]), + ("mlx5_0:1,mlx5_1", ["mlx5_0:1", "mlx5_1"]), + (" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]), + ("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]), + ("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]), + ("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]), + ("^mlx5", ["mlx4_0", "mlx4_1"]), + ("^=mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]), + ("^=mlx4", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]), + ], +) +def test_parse_various_patterns( + input_value: str, expected: list[str], mock_available_devices: list[str] +): + """Test various parsing patterns""" + result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) + assert result == expected + + +@pytest.mark.parametrize( + "input_value,expected_result,expected_warning", + [ + ("=mlx5_100", [], "Device 'mlx5_100' not found in available devices."), + ("mlx5_100", [], "No devices match the prefix 'mlx5_100'."), + ( + "^mlx5_100", + ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], + "No devices match the prefix 'mlx5_100'.", + ), + ("mlx6", [], "No devices match the prefix 'mlx6'."), + ("=mlx6", [], "Device 'mlx6' not found in available devices."), + ], +) +def test_parse_exact_match_with_nonexistent_device( + input_value: str, + expected_result: list[str], + expected_warning: str, + mock_available_devices: list[str], +): + """Test exact matching with non-existent device""" + with patch("checkpoint_engine.ps.logger") as mock_logger: + result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) + assert result == expected_result + mock_logger.warning.assert_called_once_with(expected_warning) + + +@pytest.mark.parametrize( + "env_var_name,env_var_value,expected_devices", + [ + ("PS_P2P_STORE_RDMA_DEVICES", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), + ("NCCL_IB_HCA", "mlx5", ["mlx5_0", "mlx5_1"]), + ("NCCL_IB_HCA", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), + ("NCCL_IB_HCA", "^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]), + ("NCCL_IB_HCA", "mlx6", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]), + ("NCCL_IB_HCA", "", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]), + ], +) +def test_get_rdma_devices_with_env_vars( + env_var_name: str, + env_var_value: str, + expected_devices: list[str], + mock_available_devices: list[str], +): + """Test _get_rdma_devices with various environment variables""" + env_dict = {env_var_name: env_var_value} + with ( + patch.dict(os.environ, env_dict), + patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices), + ): + devices = _get_rdma_devices() + assert sorted(devices) == sorted(expected_devices) + + +@pytest.mark.parametrize( + "local_rank,gpu_count,expected_device", + [ + (0, 4, "mlx5_0"), + (3, 4, "mlx5_3"), + (4, 8, "mlx5_2"), + (7, 8, "mlx5_3"), + ], +) +def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_device: str): + """Test _get_my_rdma_device with basic allocation""" + # Use fewer devices to match the GPU count constraint + devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"] + device = _get_my_rdma_device(local_rank, gpu_count, devices) + assert device == expected_device + + +@pytest.mark.parametrize( + "local_rank,gpu_count,devices,error", + [ + ( + 0, + 4, + ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx5_4"], + AssertionError, + ), # Too many devices + ( + 0, + 8, + ["mlx5_0", "mlx5_1", "mlx5_2"], + AssertionError, + ), # GPU count not divisible by device count + (0, 8, [], RuntimeError), # No devices + ], +) +def test_get_my_rdma_device_invalid_config( + local_rank: int, gpu_count: int, devices: list[str], error: type +): + """Test _get_my_rdma_device with invalid configuration""" + with pytest.raises(error): + _get_my_rdma_device(local_rank, gpu_count, devices) From 5316d47275548b96d14d8f5c62788406d438229a Mon Sep 17 00:00:00 2001 From: specture724 Date: Wed, 15 Oct 2025 11:36:35 +0000 Subject: [PATCH 05/13] misc --- checkpoint_engine/ps.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index f037c5a..a859ab5 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -331,14 +331,12 @@ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: is_exclude = value.startswith("^") is_exact_match = value.startswith("=") - cnt = 0 - while value and value[0] in ("^", "=") and cnt < 2: - if value[0] == "^": - is_exclude = True - elif value[0] == "=": - is_exact_match = True + prefix_chars_processed = 0 + while value and value[0] in ("^", "=") and prefix_chars_processed < 2: value = value[1:] - cnt += 1 + is_exact_match = is_exact_match or value.startswith("=") + is_exclude = is_exclude or value.startswith("^") + prefix_chars_processed += 1 device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] @@ -361,9 +359,9 @@ def _resolve_device_specs( ) -> list[str]: devices = set() for spec in device_specs: - device_name, port = ( - map(str.strip, spec.split(":", 1)) if ":" in spec else (spec.strip(), None) - ) + parts = spec.split(":", 1) + device_name = parts[0].strip() + port = parts[1].strip() if len(parts) > 1 else None base_devices = ( [device_name] if is_exact_match From d9f6f1e1e3c1e59318dea2a296789d258e293926 Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 07:19:09 +0000 Subject: [PATCH 06/13] misc: fix pr issues --- checkpoint_engine/ps.py | 38 ++++++++++++++++++++++----------- tests/test_rdma_parser.py | 45 +++++++++++++++++++++------------------ 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index a859ab5..04d7ba1 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -322,6 +322,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: + """ + The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8. + The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662. + + The list is comma-separated; port numbers are NOT supported yet. + An optional prefix '^' indicates the list is an exclude list. + A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix. + Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported. + + Examples: + - NCCL_IB_HCA=mlx5: Use all cards starting with mlx5. + - NCCL_IB_HCA==mlx5_0,mlx5_1 : Use specific cards mlx5_0 and mlx5_1. + - NCCL_IB_HCA=^mlx5: Use all cards except those starting with mlx5. + - NCCL_IB_HCA=^=mlx5_0,mlx5_1: Use all cards except mlx5_0 and mlx5_1. + """ max_hcas = 32 if not value or value.strip() == "": return available_devices[:max_hcas] @@ -329,14 +344,11 @@ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: value = value.strip() result = [] is_exclude = value.startswith("^") + if is_exclude: + value = value.removeprefix("^") is_exact_match = value.startswith("=") - - prefix_chars_processed = 0 - while value and value[0] in ("^", "=") and prefix_chars_processed < 2: - value = value[1:] - is_exact_match = is_exact_match or value.startswith("=") - is_exclude = is_exclude or value.startswith("^") - prefix_chars_processed += 1 + if is_exact_match: + value = value.removeprefix("=") device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] @@ -361,22 +373,22 @@ def _resolve_device_specs( for spec in device_specs: parts = spec.split(":", 1) device_name = parts[0].strip() - port = parts[1].strip() if len(parts) > 1 else None + # HACK: mooncake transfer engine does not support port specification yet, so we ignore it + # port = parts[1].strip() if len(parts) > 1 else None base_devices = ( [device_name] + if device_name in available_devices + else [] if is_exact_match else [dev for dev in available_devices if dev.startswith(device_name)] ) - if is_exact_match and device_name not in available_devices: - logger.warning(f"Device '{device_name}' not found in available devices.") - continue if not base_devices: - logger.warning(f"No devices match the prefix '{device_name}'.") + logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.") continue for base_dev in base_devices: - devices.add(f"{base_dev}:{port}" if port else f"{base_dev}") + devices.add(base_dev) return sorted(devices) diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index 31f5e4a..9b0951a 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -51,37 +51,40 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): @pytest.mark.parametrize( "input_value,expected", [ - ("", "mock_available_devices"), # Special marker for fixture - (" \t\n ", "mock_available_devices"), # Special marker for fixture - ("None", []), - ("^", "mock_available_devices"), # Special marker for fixture - ("^=", "mock_available_devices"), - ("=^", "mock_available_devices"), - ("^^", "mock_available_devices"), - ("=", []), - ("==", []), + pytest.param("", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="empty string"), + pytest.param(" \t\n ", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="whitespace"), + pytest.param("None", [], id="None string"), + pytest.param("^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret"), + pytest.param("^=", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret-equals"), + pytest.param("=^", [], id="equals-caret"), + pytest.param("^^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="double-caret"), + pytest.param("=", [], id="equals"), + pytest.param("==", [], id="double-equals"), ], ) -def test_parse_basic_cases(input_value: str, expected: str, mock_available_devices: list[str]): +def test_parse_basic_cases( + input_value: str, expected: list[str], mock_available_devices: list[str] +): """Test basic parsing cases: empty string, whitespace, None""" result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) - if expected == "mock_available_devices": - assert result == mock_available_devices - else: - assert result == expected + assert result == expected @pytest.mark.parametrize( "input_value,expected", [ + # prefix ("mlx5_0", ["mlx5_0"]), ("mlx5", ["mlx5_0", "mlx5_1"]), + # exact match ("=mlx5_0", ["mlx5_0"]), ("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), - ("mlx5_0:1,mlx5_1:2", ["mlx5_0:1", "mlx5_1:2"]), - ("mlx5_0:1,mlx5_1", ["mlx5_0:1", "mlx5_1"]), + # ignore ports, whitespace and duplicated commas + ("mlx5_0:1,mlx5_1:2", ["mlx5_0", "mlx5_1"]), + ("mlx5_0:1,mlx5_1", ["mlx5_0", "mlx5_1"]), (" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]), ("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]), + # exclusion ("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]), ("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]), ("^mlx5", ["mlx4_0", "mlx4_1"]), @@ -100,15 +103,15 @@ def test_parse_various_patterns( @pytest.mark.parametrize( "input_value,expected_result,expected_warning", [ - ("=mlx5_100", [], "Device 'mlx5_100' not found in available devices."), - ("mlx5_100", [], "No devices match the prefix 'mlx5_100'."), + ("=mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=True."), + ("mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=False."), ( "^mlx5_100", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], - "No devices match the prefix 'mlx5_100'.", + "No RDMA device match device_name='mlx5_100' where is_exact_match=False.", ), - ("mlx6", [], "No devices match the prefix 'mlx6'."), - ("=mlx6", [], "Device 'mlx6' not found in available devices."), + ("mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=False."), + ("=mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=True."), ], ) def test_parse_exact_match_with_nonexistent_device( From 64c2d1083fc152f3e1c291e33fe887112a076b9c Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 07:28:01 +0000 Subject: [PATCH 07/13] feat: add cpu-gpu markers for tests --- pyproject.toml | 6 ++++++ tests/test_rdma_parser.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index df26fb2..e4b8873 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,3 +158,9 @@ inline-quotes = "double" [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" + +[tool.pytest.ini_options] +markers = [ + "cpu: marks tests as CPU test (deselect with '-m \"not cpu\"')", + "gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')", +] diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index 9b0951a..d21c685 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -17,6 +17,7 @@ def mock_available_devices() -> list[str]: return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"] +@pytest.mark.cpu def test_detect_ibv_list(): """Test detection of _ibv_get_device_list function""" # Skip this test if no real infiniband devices exist @@ -29,6 +30,7 @@ def test_detect_ibv_list(): assert isinstance(devices, list) +@pytest.mark.cpu def test_parse_max_hcas_limit(): """Test maximum HCA quantity limit""" # Create mock data with more than 32 devices @@ -38,6 +40,7 @@ def test_parse_max_hcas_limit(): assert result == many_devices[:32] +@pytest.mark.cpu def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): """Test _get_rdma_devices with no environment variables""" with ( @@ -48,6 +51,7 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): assert sorted(devices) == sorted(mock_available_devices) +@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected", [ @@ -70,6 +74,7 @@ def test_parse_basic_cases( assert result == expected +@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected", [ @@ -100,6 +105,7 @@ def test_parse_various_patterns( assert result == expected +@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected_result,expected_warning", [ @@ -127,6 +133,7 @@ def test_parse_exact_match_with_nonexistent_device( mock_logger.warning.assert_called_once_with(expected_warning) +@pytest.mark.cpu @pytest.mark.parametrize( "env_var_name,env_var_value,expected_devices", [ @@ -154,6 +161,7 @@ def test_get_rdma_devices_with_env_vars( assert sorted(devices) == sorted(expected_devices) +@pytest.mark.cpu @pytest.mark.parametrize( "local_rank,gpu_count,expected_device", [ @@ -171,6 +179,7 @@ def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_devi assert device == expected_device +@pytest.mark.cpu @pytest.mark.parametrize( "local_rank,gpu_count,devices,error", [ From d641312932b17c29fe1d99facd8d7ebed741eb49 Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 07:51:29 +0000 Subject: [PATCH 08/13] feat: ci added cpu tests workflow --- .github/workflows/cpu-tests.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/workflows/cpu-tests.yml diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml new file mode 100644 index 0000000..aafd4e5 --- /dev/null +++ b/.github/workflows/cpu-tests.yml @@ -0,0 +1,27 @@ +name: CPU Tests + +on: + push: + branches: [main] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install .[p2p] + - name: Do CPU tests with pytest + run: | + pytest -v -m "cpu" tests/ From 4520370b9fb61da0adfa386d5df731f1c7cb575d Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 07:19:09 +0000 Subject: [PATCH 09/13] misc: fix pr issues --- checkpoint_engine/ps.py | 45 +++++++++++++++++++++++---------------- tests/test_rdma_parser.py | 45 +++++++++++++++++++++------------------ 2 files changed, 51 insertions(+), 39 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index a859ab5..b5a52e9 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -322,6 +322,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: + """ + The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8. + The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662. + + The list is comma-separated; port numbers are NOT supported yet. + An optional prefix '^' indicates the list is an exclude list. + A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix. + Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported. + + Examples: + - NCCL_IB_HCA=mlx5: Use all cards starting with mlx5. + - NCCL_IB_HCA==mlx5_0,mlx5_1 : Use specific cards mlx5_0 and mlx5_1. + - NCCL_IB_HCA=^mlx5: Use all cards except those starting with mlx5. + - NCCL_IB_HCA=^=mlx5_0,mlx5_1: Use all cards except mlx5_0 and mlx5_1. + """ max_hcas = 32 if not value or value.strip() == "": return available_devices[:max_hcas] @@ -329,23 +344,17 @@ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: value = value.strip() result = [] is_exclude = value.startswith("^") + if is_exclude: + value = value.removeprefix("^") is_exact_match = value.startswith("=") - - prefix_chars_processed = 0 - while value and value[0] in ("^", "=") and prefix_chars_processed < 2: - value = value[1:] - is_exact_match = is_exact_match or value.startswith("=") - is_exclude = is_exclude or value.startswith("^") - prefix_chars_processed += 1 + if is_exact_match: + value = value.removeprefix("=") device_specs = [spec.strip() for spec in value.split(",") if spec.strip()] + result = _resolve_device_specs(device_specs, is_exact_match, available_devices) if is_exclude: - excluded_devices = _resolve_device_specs(device_specs, is_exact_match, available_devices) - result = [dev for dev in available_devices if dev not in excluded_devices] - else: - result = _resolve_device_specs(device_specs, is_exact_match, available_devices) - + result = [dev for dev in available_devices if dev not in result] if len(result) > max_hcas: result = result[:max_hcas] @@ -361,22 +370,22 @@ def _resolve_device_specs( for spec in device_specs: parts = spec.split(":", 1) device_name = parts[0].strip() - port = parts[1].strip() if len(parts) > 1 else None + # HACK: mooncake transfer engine does not support port specification yet, so we ignore it + # port = parts[1].strip() if len(parts) > 1 else None base_devices = ( [device_name] + if device_name in available_devices + else [] if is_exact_match else [dev for dev in available_devices if dev.startswith(device_name)] ) - if is_exact_match and device_name not in available_devices: - logger.warning(f"Device '{device_name}' not found in available devices.") - continue if not base_devices: - logger.warning(f"No devices match the prefix '{device_name}'.") + logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.") continue for base_dev in base_devices: - devices.add(f"{base_dev}:{port}" if port else f"{base_dev}") + devices.add(base_dev) return sorted(devices) diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index 31f5e4a..9b0951a 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -51,37 +51,40 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): @pytest.mark.parametrize( "input_value,expected", [ - ("", "mock_available_devices"), # Special marker for fixture - (" \t\n ", "mock_available_devices"), # Special marker for fixture - ("None", []), - ("^", "mock_available_devices"), # Special marker for fixture - ("^=", "mock_available_devices"), - ("=^", "mock_available_devices"), - ("^^", "mock_available_devices"), - ("=", []), - ("==", []), + pytest.param("", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="empty string"), + pytest.param(" \t\n ", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="whitespace"), + pytest.param("None", [], id="None string"), + pytest.param("^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret"), + pytest.param("^=", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret-equals"), + pytest.param("=^", [], id="equals-caret"), + pytest.param("^^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="double-caret"), + pytest.param("=", [], id="equals"), + pytest.param("==", [], id="double-equals"), ], ) -def test_parse_basic_cases(input_value: str, expected: str, mock_available_devices: list[str]): +def test_parse_basic_cases( + input_value: str, expected: list[str], mock_available_devices: list[str] +): """Test basic parsing cases: empty string, whitespace, None""" result = _parse_NCCL_IB_HCA(input_value, mock_available_devices) - if expected == "mock_available_devices": - assert result == mock_available_devices - else: - assert result == expected + assert result == expected @pytest.mark.parametrize( "input_value,expected", [ + # prefix ("mlx5_0", ["mlx5_0"]), ("mlx5", ["mlx5_0", "mlx5_1"]), + # exact match ("=mlx5_0", ["mlx5_0"]), ("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]), - ("mlx5_0:1,mlx5_1:2", ["mlx5_0:1", "mlx5_1:2"]), - ("mlx5_0:1,mlx5_1", ["mlx5_0:1", "mlx5_1"]), + # ignore ports, whitespace and duplicated commas + ("mlx5_0:1,mlx5_1:2", ["mlx5_0", "mlx5_1"]), + ("mlx5_0:1,mlx5_1", ["mlx5_0", "mlx5_1"]), (" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]), ("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]), + # exclusion ("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]), ("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]), ("^mlx5", ["mlx4_0", "mlx4_1"]), @@ -100,15 +103,15 @@ def test_parse_various_patterns( @pytest.mark.parametrize( "input_value,expected_result,expected_warning", [ - ("=mlx5_100", [], "Device 'mlx5_100' not found in available devices."), - ("mlx5_100", [], "No devices match the prefix 'mlx5_100'."), + ("=mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=True."), + ("mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=False."), ( "^mlx5_100", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], - "No devices match the prefix 'mlx5_100'.", + "No RDMA device match device_name='mlx5_100' where is_exact_match=False.", ), - ("mlx6", [], "No devices match the prefix 'mlx6'."), - ("=mlx6", [], "Device 'mlx6' not found in available devices."), + ("mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=False."), + ("=mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=True."), ], ) def test_parse_exact_match_with_nonexistent_device( From 378d44604cfd9973ffb3ca977d67140081232eeb Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 07:28:01 +0000 Subject: [PATCH 10/13] feat: add cpu-gpu markers for tests --- pyproject.toml | 6 ++++++ tests/test_rdma_parser.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index df26fb2..e4b8873 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,3 +158,9 @@ inline-quotes = "double" [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" + +[tool.pytest.ini_options] +markers = [ + "cpu: marks tests as CPU test (deselect with '-m \"not cpu\"')", + "gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')", +] diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index 9b0951a..d21c685 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -17,6 +17,7 @@ def mock_available_devices() -> list[str]: return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"] +@pytest.mark.cpu def test_detect_ibv_list(): """Test detection of _ibv_get_device_list function""" # Skip this test if no real infiniband devices exist @@ -29,6 +30,7 @@ def test_detect_ibv_list(): assert isinstance(devices, list) +@pytest.mark.cpu def test_parse_max_hcas_limit(): """Test maximum HCA quantity limit""" # Create mock data with more than 32 devices @@ -38,6 +40,7 @@ def test_parse_max_hcas_limit(): assert result == many_devices[:32] +@pytest.mark.cpu def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): """Test _get_rdma_devices with no environment variables""" with ( @@ -48,6 +51,7 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): assert sorted(devices) == sorted(mock_available_devices) +@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected", [ @@ -70,6 +74,7 @@ def test_parse_basic_cases( assert result == expected +@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected", [ @@ -100,6 +105,7 @@ def test_parse_various_patterns( assert result == expected +@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected_result,expected_warning", [ @@ -127,6 +133,7 @@ def test_parse_exact_match_with_nonexistent_device( mock_logger.warning.assert_called_once_with(expected_warning) +@pytest.mark.cpu @pytest.mark.parametrize( "env_var_name,env_var_value,expected_devices", [ @@ -154,6 +161,7 @@ def test_get_rdma_devices_with_env_vars( assert sorted(devices) == sorted(expected_devices) +@pytest.mark.cpu @pytest.mark.parametrize( "local_rank,gpu_count,expected_device", [ @@ -171,6 +179,7 @@ def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_devi assert device == expected_device +@pytest.mark.cpu @pytest.mark.parametrize( "local_rank,gpu_count,devices,error", [ From 51bff640864e12b07d22b2d6a196d35ace118fcf Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 07:51:29 +0000 Subject: [PATCH 11/13] feat: ci added cpu tests workflow --- .github/workflows/cpu-tests.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/workflows/cpu-tests.yml diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml new file mode 100644 index 0000000..aafd4e5 --- /dev/null +++ b/.github/workflows/cpu-tests.yml @@ -0,0 +1,27 @@ +name: CPU Tests + +on: + push: + branches: [main] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install .[p2p] + - name: Do CPU tests with pytest + run: | + pytest -v -m "cpu" tests/ From 308fa798fbcf9ced8aa857afb41c289df4f8d264 Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 08:25:25 +0000 Subject: [PATCH 12/13] feat: enhance CPU test workflow for RDMA device detection --- .github/workflows/cpu-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index aafd4e5..67052eb 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -3,6 +3,9 @@ name: CPU Tests on: push: branches: [main] + pull_request: + types: [opened, synchronize, reopened] + permissions: contents: read From ac709d4a228924fad8fa30116565ba0ab42cd674 Mon Sep 17 00:00:00 2001 From: specture724 Date: Mon, 20 Oct 2025 09:09:21 +0000 Subject: [PATCH 13/13] misc: fix pr issues --- .github/workflows/cpu-tests.yml | 4 ++-- checkpoint_engine/ps.py | 8 ++++---- pyproject.toml | 1 - tests/test_rdma_parser.py | 9 --------- 4 files changed, 6 insertions(+), 16 deletions(-) diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index 67052eb..6219b55 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: "3.x" + python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -27,4 +27,4 @@ jobs: pip install .[p2p] - name: Do CPU tests with pytest run: | - pytest -v -m "cpu" tests/ + pytest -v -m "not gpu" tests/ diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index b5a52e9..1493a69 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -332,10 +332,10 @@ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]: Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported. Examples: - - NCCL_IB_HCA=mlx5: Use all cards starting with mlx5. - - NCCL_IB_HCA==mlx5_0,mlx5_1 : Use specific cards mlx5_0 and mlx5_1. - - NCCL_IB_HCA=^mlx5: Use all cards except those starting with mlx5. - - NCCL_IB_HCA=^=mlx5_0,mlx5_1: Use all cards except mlx5_0 and mlx5_1. + - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`. + - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`. + - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`. + - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`. """ max_hcas = 32 if not value or value.strip() == "": diff --git a/pyproject.toml b/pyproject.toml index e4b8873..c200382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,6 +161,5 @@ ban-relative-imports = "all" [tool.pytest.ini_options] markers = [ - "cpu: marks tests as CPU test (deselect with '-m \"not cpu\"')", "gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')", ] diff --git a/tests/test_rdma_parser.py b/tests/test_rdma_parser.py index d21c685..9b0951a 100644 --- a/tests/test_rdma_parser.py +++ b/tests/test_rdma_parser.py @@ -17,7 +17,6 @@ def mock_available_devices() -> list[str]: return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"] -@pytest.mark.cpu def test_detect_ibv_list(): """Test detection of _ibv_get_device_list function""" # Skip this test if no real infiniband devices exist @@ -30,7 +29,6 @@ def test_detect_ibv_list(): assert isinstance(devices, list) -@pytest.mark.cpu def test_parse_max_hcas_limit(): """Test maximum HCA quantity limit""" # Create mock data with more than 32 devices @@ -40,7 +38,6 @@ def test_parse_max_hcas_limit(): assert result == many_devices[:32] -@pytest.mark.cpu def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): """Test _get_rdma_devices with no environment variables""" with ( @@ -51,7 +48,6 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]): assert sorted(devices) == sorted(mock_available_devices) -@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected", [ @@ -74,7 +70,6 @@ def test_parse_basic_cases( assert result == expected -@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected", [ @@ -105,7 +100,6 @@ def test_parse_various_patterns( assert result == expected -@pytest.mark.cpu @pytest.mark.parametrize( "input_value,expected_result,expected_warning", [ @@ -133,7 +127,6 @@ def test_parse_exact_match_with_nonexistent_device( mock_logger.warning.assert_called_once_with(expected_warning) -@pytest.mark.cpu @pytest.mark.parametrize( "env_var_name,env_var_value,expected_devices", [ @@ -161,7 +154,6 @@ def test_get_rdma_devices_with_env_vars( assert sorted(devices) == sorted(expected_devices) -@pytest.mark.cpu @pytest.mark.parametrize( "local_rank,gpu_count,expected_device", [ @@ -179,7 +171,6 @@ def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_devi assert device == expected_device -@pytest.mark.cpu @pytest.mark.parametrize( "local_rank,gpu_count,devices,error", [