|
21 | 21 |
|
22 | 22 | import pytest |
23 | 23 |
|
24 | | -from cloudai import BaseJob |
| 24 | +from cloudai import BaseJob, Test, TestRun, TestTemplate |
25 | 25 | from cloudai.systems import SlurmSystem |
26 | 26 | from cloudai.systems.slurm import SlurmNode, SlurmNodeState |
27 | 27 | from cloudai.systems.slurm.slurm_system import parse_node_list |
| 28 | +from cloudai.systems.slurm.strategy.slurm_command_gen_strategy import SlurmCommandGenStrategy |
| 29 | +from cloudai.workloads.nccl_test import NCCLCmdArgs, NCCLTestDefinition |
28 | 30 |
|
29 | 31 |
|
30 | 32 | def test_parse_squeue_output(slurm_system): |
@@ -388,3 +390,143 @@ def test_with_commas(self, slurm_system: SlurmSystem): |
388 | 390 | def test_colon_invalid_syntax(self, slurm_system: SlurmSystem, spec: str): |
389 | 391 | with pytest.raises(ValueError): |
390 | 392 | slurm_system.parse_nodes([spec]) |
| 393 | + |
| 394 | + |
| 395 | +class TestGetNodesBySpec: |
| 396 | + def test_empty_nodes_list(self, slurm_system: SlurmSystem): |
| 397 | + num_nodes, node_list = slurm_system.get_nodes_by_spec(3, []) |
| 398 | + assert num_nodes == 3 |
| 399 | + assert node_list == [] |
| 400 | + |
| 401 | + @pytest.mark.parametrize( |
| 402 | + "in_nnodes,in_nodes,exp_nnodes,exp_nodes", |
| 403 | + [ |
| 404 | + (2, ["node0[1-3]"], 3, ["node01", "node02", "node03"]), |
| 405 | + (4, ["node01,node02"], 2, ["node01", "node02"]), |
| 406 | + (1, ["node01,node02"], 2, ["node01", "node02"]), |
| 407 | + ], |
| 408 | + ) |
| 409 | + @patch("cloudai.systems.slurm.slurm_system.SlurmSystem.parse_nodes") |
| 410 | + def test_explicit_node_names( |
| 411 | + self, |
| 412 | + mock_parse_nodes: Mock, |
| 413 | + slurm_system: SlurmSystem, |
| 414 | + in_nnodes: int, |
| 415 | + in_nodes: list[str], |
| 416 | + exp_nnodes: int, |
| 417 | + exp_nodes: list[str], |
| 418 | + ): |
| 419 | + mock_parse_nodes.return_value = exp_nodes |
| 420 | + |
| 421 | + num_nodes, node_list = slurm_system.get_nodes_by_spec(in_nnodes, in_nodes) |
| 422 | + |
| 423 | + mock_parse_nodes.assert_called_once_with(in_nodes) |
| 424 | + assert num_nodes == exp_nnodes |
| 425 | + assert node_list == exp_nodes |
| 426 | + |
| 427 | + |
| 428 | +class ConcreteSlurmStrategy(SlurmCommandGenStrategy): |
| 429 | + def _container_mounts(self, tr: TestRun) -> list[str]: |
| 430 | + return [] |
| 431 | + |
| 432 | + def generate_test_command(self, env_vars, cmd_args, tr): |
| 433 | + return ["test_command"] |
| 434 | + |
| 435 | + def job_name(self, job_name_prefix: str) -> str: |
| 436 | + return "job_name" |
| 437 | + |
| 438 | + |
| 439 | +@pytest.fixture |
| 440 | +def test_run(slurm_system: SlurmSystem) -> TestRun: |
| 441 | + test_run = TestRun( |
| 442 | + name="test_run", |
| 443 | + test=Test( |
| 444 | + test_definition=NCCLTestDefinition( |
| 445 | + name="test_run", description="test_run", test_template_name="nccl", cmd_args=NCCLCmdArgs() |
| 446 | + ), |
| 447 | + test_template=TestTemplate(slurm_system), |
| 448 | + ), |
| 449 | + num_nodes=2, |
| 450 | + nodes=["main:group1:2"], |
| 451 | + output_path=slurm_system.output_path, |
| 452 | + ) |
| 453 | + |
| 454 | + test_run.output_path.mkdir(parents=True, exist_ok=True) |
| 455 | + |
| 456 | + return test_run |
| 457 | + |
| 458 | + |
| 459 | +class TestSlurmCommandGenStrategyCache: |
| 460 | + @patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec") |
| 461 | + def test_strategy_caching(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun): |
| 462 | + mock_get_nodes.return_value = (2, ["node01", "node02"]) |
| 463 | + |
| 464 | + strategy = ConcreteSlurmStrategy(slurm_system, {}) |
| 465 | + |
| 466 | + # First call to get nodes |
| 467 | + res = strategy.get_cached_nodes_spec(test_run) |
| 468 | + assert mock_get_nodes.call_count == 1 |
| 469 | + assert res == (2, ["node01", "node02"]) |
| 470 | + |
| 471 | + # Second call with same parameters should use cache |
| 472 | + res = strategy.get_cached_nodes_spec(test_run) |
| 473 | + assert mock_get_nodes.call_count == 1 |
| 474 | + assert res == (2, ["node01", "node02"]) |
| 475 | + |
| 476 | + # Different node spec should call get_nodes_by_spec again |
| 477 | + test_run.num_nodes = 1 |
| 478 | + test_run.nodes = [] |
| 479 | + strategy.get_cached_nodes_spec(test_run) |
| 480 | + assert mock_get_nodes.call_count == 2 |
| 481 | + |
| 482 | + test_run.num_nodes = 2 |
| 483 | + test_run.nodes = ["node01", "node03"] |
| 484 | + strategy.get_cached_nodes_spec(test_run) |
| 485 | + assert mock_get_nodes.call_count == 3 |
| 486 | + |
| 487 | + @patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec") |
| 488 | + def test_per_test_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun): |
| 489 | + mock_get_nodes.side_effect = [(2, ["node01", "node02"]), (2, ["node03", "node04"])] |
| 490 | + |
| 491 | + # Simulate two different test cases |
| 492 | + strategy1, strategy2 = ConcreteSlurmStrategy(slurm_system, {}), ConcreteSlurmStrategy(slurm_system, {}) |
| 493 | + |
| 494 | + res = strategy1.get_cached_nodes_spec(test_run) |
| 495 | + assert mock_get_nodes.call_count == 1 |
| 496 | + assert res == (2, ["node01", "node02"]) |
| 497 | + |
| 498 | + res = strategy2.get_cached_nodes_spec(test_run) |
| 499 | + assert mock_get_nodes.call_count == 2 |
| 500 | + assert res == (2, ["node03", "node04"]) |
| 501 | + |
| 502 | + assert strategy1._node_spec_cache != strategy2._node_spec_cache, "Caches should be different" |
| 503 | + |
| 504 | + @patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec") |
| 505 | + def test_per_iteration_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun): |
| 506 | + mock_get_nodes.side_effect = [(2, ["node01", "node02"]), (2, ["node03", "node04"])] |
| 507 | + |
| 508 | + strategy = ConcreteSlurmStrategy(slurm_system, {}) |
| 509 | + |
| 510 | + res = strategy.get_cached_nodes_spec(test_run) |
| 511 | + assert mock_get_nodes.call_count == 1 |
| 512 | + assert res == (2, ["node01", "node02"]) |
| 513 | + |
| 514 | + test_run.current_iteration = 1 |
| 515 | + res = strategy.get_cached_nodes_spec(test_run) |
| 516 | + assert mock_get_nodes.call_count == 2 |
| 517 | + assert res == (2, ["node03", "node04"]) |
| 518 | + |
| 519 | + @patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec") |
| 520 | + def test_per_step_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun): |
| 521 | + mock_get_nodes.side_effect = [(2, ["node01", "node02"]), (2, ["node03", "node04"])] |
| 522 | + |
| 523 | + strategy = ConcreteSlurmStrategy(slurm_system, {}) |
| 524 | + |
| 525 | + res = strategy.get_cached_nodes_spec(test_run) |
| 526 | + assert mock_get_nodes.call_count == 1 |
| 527 | + assert res == (2, ["node01", "node02"]) |
| 528 | + |
| 529 | + test_run.step = 1 |
| 530 | + res = strategy.get_cached_nodes_spec(test_run) |
| 531 | + assert mock_get_nodes.call_count == 2 |
| 532 | + assert res == (2, ["node03", "node04"]) |
0 commit comments