Skip to content

Commit 6892fa4

Browse files
authored
Merge pull request #535 from NVIDIA/am/upd-before-access
Optimize slurm updates
2 parents cff6e71 + 756fb17 commit 6892fa4

File tree

4 files changed

+173
-7
lines changed

4 files changed

+173
-7
lines changed

src/cloudai/parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,18 @@ def __init__(self, system_config_path: Path) -> None:
5050
"""
5151
logging.debug(f"Initializing parser with: {system_config_path=}")
5252
self.system_config_path = system_config_path
53+
self._system: Optional[System] = None
5354

5455
@property
5556
def system(self) -> System:
57+
if self._system:
58+
return self._system
59+
5660
try:
57-
return self.parse_system(self.system_config_path)
61+
self._system = self.parse_system(self.system_config_path)
5862
except SystemConfigParsingError:
5963
exit(1) # exit right away to keep error message readable for users
64+
return self._system
6065

6166
def parse(
6267
self,

src/cloudai/systems/slurm/slurm_system.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,10 @@ def get_available_nodes_from_group(
415415
ValueError: If the partition or group is not found, or if the requested number of nodes exceeds the
416416
available nodes.
417417
"""
418-
self.validate_partition_and_group(partition_name, group_name)
419-
420418
self.update()
421419

420+
self.validate_partition_and_group(partition_name, group_name)
421+
422422
grouped_nodes = self.group_nodes_by_state(partition_name, group_name)
423423

424424
try:

src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import logging
1718
from abc import abstractmethod
1819
from datetime import datetime
1920
from pathlib import Path
@@ -44,6 +45,8 @@ def __init__(self, system: SlurmSystem, cmd_args: Dict[str, Any]) -> None:
4445
self.system = system
4546
self.docker_image_url = self.cmd_args.get("docker_image_url", "")
4647

48+
self._node_spec_cache: dict[str, tuple[int, list[str]]] = {}
49+
4750
@abstractmethod
4851
def _container_mounts(self, tr: TestRun) -> list[str]:
4952
"""Return CommandGenStrategy specific container mounts for the test run."""
@@ -132,7 +135,7 @@ def _parse_slurm_args(
132135
KeyError: If partition or essential node settings are missing.
133136
"""
134137
job_name = self.job_name(job_name_prefix)
135-
num_nodes, node_list = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
138+
num_nodes, node_list = self.get_cached_nodes_spec(tr)
136139

137140
slurm_args = {
138141
"job_name": job_name,
@@ -311,7 +314,7 @@ def _ranks_mapping_cmd(self, slurm_args: dict[str, Any], tr: TestRun) -> str:
311314

312315
def _metadata_cmd(self, slurm_args: dict[str, Any], tr: TestRun) -> str:
313316
(tr.output_path.absolute() / "metadata").mkdir(parents=True, exist_ok=True)
314-
num_nodes, _ = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
317+
num_nodes, _ = self.get_cached_nodes_spec(tr)
315318
metadata_script_path = "/cloudai_install"
316319
if "image_path" not in slurm_args:
317320
metadata_script_path = str(self.system.install_path.absolute())
@@ -432,7 +435,7 @@ def _append_sbatch_directives(self, batch_script_content: List[str], args: Dict[
432435
)
433436

434437
def _append_nodes_related_directives(self, content: List[str], args: Dict[str, Any], tr: TestRun) -> Optional[Path]:
435-
num_nodes, node_list = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
438+
num_nodes, node_list = self.get_cached_nodes_spec(tr)
436439

437440
if node_list:
438441
content.append("#SBATCH --distribution=arbitrary")
@@ -480,3 +483,19 @@ def gen_srun_success_check(self, tr: TestRun) -> str:
480483
str: The generated command to check the success of the test run.
481484
"""
482485
return ""
486+
487+
def get_cached_nodes_spec(self, tr: TestRun) -> tuple[int, list[str]]:
488+
"""
489+
Get nodes for a test run, using cache when available.
490+
491+
It is needed to avoid multiple calls to the system.get_nodes_by_spec method which in turn queries the Slurm API.
492+
For a single test run it is not required, we can get actual nodes status only once.
493+
"""
494+
cache_key = f"{tr.current_iteration}:{tr.step}:{tr.num_nodes}:{','.join(tr.nodes)}"
495+
496+
if cache_key in self._node_spec_cache:
497+
logging.debug(f"Using cached node allocation for {cache_key}: {self._node_spec_cache[cache_key]}")
498+
return self._node_spec_cache[cache_key]
499+
500+
self._node_spec_cache[cache_key] = self.system.get_nodes_by_spec(tr.num_nodes, tr.nodes)
501+
return self._node_spec_cache[cache_key]

tests/test_slurm_system.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121

2222
import pytest
2323

24-
from cloudai import BaseJob
24+
from cloudai import BaseJob, Test, TestRun, TestTemplate
2525
from cloudai.systems import SlurmSystem
2626
from cloudai.systems.slurm import SlurmNode, SlurmNodeState
2727
from cloudai.systems.slurm.slurm_system import parse_node_list
28+
from cloudai.systems.slurm.strategy.slurm_command_gen_strategy import SlurmCommandGenStrategy
29+
from cloudai.workloads.nccl_test import NCCLCmdArgs, NCCLTestDefinition
2830

2931

3032
def test_parse_squeue_output(slurm_system):
@@ -388,3 +390,143 @@ def test_with_commas(self, slurm_system: SlurmSystem):
388390
def test_colon_invalid_syntax(self, slurm_system: SlurmSystem, spec: str):
389391
with pytest.raises(ValueError):
390392
slurm_system.parse_nodes([spec])
393+
394+
395+
class TestGetNodesBySpec:
396+
def test_empty_nodes_list(self, slurm_system: SlurmSystem):
397+
num_nodes, node_list = slurm_system.get_nodes_by_spec(3, [])
398+
assert num_nodes == 3
399+
assert node_list == []
400+
401+
@pytest.mark.parametrize(
402+
"in_nnodes,in_nodes,exp_nnodes,exp_nodes",
403+
[
404+
(2, ["node0[1-3]"], 3, ["node01", "node02", "node03"]),
405+
(4, ["node01,node02"], 2, ["node01", "node02"]),
406+
(1, ["node01,node02"], 2, ["node01", "node02"]),
407+
],
408+
)
409+
@patch("cloudai.systems.slurm.slurm_system.SlurmSystem.parse_nodes")
410+
def test_explicit_node_names(
411+
self,
412+
mock_parse_nodes: Mock,
413+
slurm_system: SlurmSystem,
414+
in_nnodes: int,
415+
in_nodes: list[str],
416+
exp_nnodes: int,
417+
exp_nodes: list[str],
418+
):
419+
mock_parse_nodes.return_value = exp_nodes
420+
421+
num_nodes, node_list = slurm_system.get_nodes_by_spec(in_nnodes, in_nodes)
422+
423+
mock_parse_nodes.assert_called_once_with(in_nodes)
424+
assert num_nodes == exp_nnodes
425+
assert node_list == exp_nodes
426+
427+
428+
class ConcreteSlurmStrategy(SlurmCommandGenStrategy):
429+
def _container_mounts(self, tr: TestRun) -> list[str]:
430+
return []
431+
432+
def generate_test_command(self, env_vars, cmd_args, tr):
433+
return ["test_command"]
434+
435+
def job_name(self, job_name_prefix: str) -> str:
436+
return "job_name"
437+
438+
439+
@pytest.fixture
440+
def test_run(slurm_system: SlurmSystem) -> TestRun:
441+
test_run = TestRun(
442+
name="test_run",
443+
test=Test(
444+
test_definition=NCCLTestDefinition(
445+
name="test_run", description="test_run", test_template_name="nccl", cmd_args=NCCLCmdArgs()
446+
),
447+
test_template=TestTemplate(slurm_system),
448+
),
449+
num_nodes=2,
450+
nodes=["main:group1:2"],
451+
output_path=slurm_system.output_path,
452+
)
453+
454+
test_run.output_path.mkdir(parents=True, exist_ok=True)
455+
456+
return test_run
457+
458+
459+
class TestSlurmCommandGenStrategyCache:
460+
@patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec")
461+
def test_strategy_caching(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun):
462+
mock_get_nodes.return_value = (2, ["node01", "node02"])
463+
464+
strategy = ConcreteSlurmStrategy(slurm_system, {})
465+
466+
# First call to get nodes
467+
res = strategy.get_cached_nodes_spec(test_run)
468+
assert mock_get_nodes.call_count == 1
469+
assert res == (2, ["node01", "node02"])
470+
471+
# Second call with same parameters should use cache
472+
res = strategy.get_cached_nodes_spec(test_run)
473+
assert mock_get_nodes.call_count == 1
474+
assert res == (2, ["node01", "node02"])
475+
476+
# Different node spec should call get_nodes_by_spec again
477+
test_run.num_nodes = 1
478+
test_run.nodes = []
479+
strategy.get_cached_nodes_spec(test_run)
480+
assert mock_get_nodes.call_count == 2
481+
482+
test_run.num_nodes = 2
483+
test_run.nodes = ["node01", "node03"]
484+
strategy.get_cached_nodes_spec(test_run)
485+
assert mock_get_nodes.call_count == 3
486+
487+
@patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec")
488+
def test_per_test_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun):
489+
mock_get_nodes.side_effect = [(2, ["node01", "node02"]), (2, ["node03", "node04"])]
490+
491+
# Simulate two different test cases
492+
strategy1, strategy2 = ConcreteSlurmStrategy(slurm_system, {}), ConcreteSlurmStrategy(slurm_system, {})
493+
494+
res = strategy1.get_cached_nodes_spec(test_run)
495+
assert mock_get_nodes.call_count == 1
496+
assert res == (2, ["node01", "node02"])
497+
498+
res = strategy2.get_cached_nodes_spec(test_run)
499+
assert mock_get_nodes.call_count == 2
500+
assert res == (2, ["node03", "node04"])
501+
502+
assert strategy1._node_spec_cache != strategy2._node_spec_cache, "Caches should be different"
503+
504+
@patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec")
505+
def test_per_iteration_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun):
506+
mock_get_nodes.side_effect = [(2, ["node01", "node02"]), (2, ["node03", "node04"])]
507+
508+
strategy = ConcreteSlurmStrategy(slurm_system, {})
509+
510+
res = strategy.get_cached_nodes_spec(test_run)
511+
assert mock_get_nodes.call_count == 1
512+
assert res == (2, ["node01", "node02"])
513+
514+
test_run.current_iteration = 1
515+
res = strategy.get_cached_nodes_spec(test_run)
516+
assert mock_get_nodes.call_count == 2
517+
assert res == (2, ["node03", "node04"])
518+
519+
@patch("cloudai.systems.slurm.SlurmSystem.get_nodes_by_spec")
520+
def test_per_step_isolation(self, mock_get_nodes: Mock, slurm_system: SlurmSystem, test_run: TestRun):
521+
mock_get_nodes.side_effect = [(2, ["node01", "node02"]), (2, ["node03", "node04"])]
522+
523+
strategy = ConcreteSlurmStrategy(slurm_system, {})
524+
525+
res = strategy.get_cached_nodes_spec(test_run)
526+
assert mock_get_nodes.call_count == 1
527+
assert res == (2, ["node01", "node02"])
528+
529+
test_run.step = 1
530+
res = strategy.get_cached_nodes_spec(test_run)
531+
assert mock_get_nodes.call_count == 2
532+
assert res == (2, ["node03", "node04"])

0 commit comments

Comments
 (0)