Skip to content

Commit b645a1e

Browse files
authored
Merge pull request #140 from NVIDIA/am/slurm-common
Move parts of srun CLI generation into base class
2 parents d35361c + 2383452 commit b645a1e

File tree

8 files changed

+129
-96
lines changed

8 files changed

+129
-96
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
[project]
1717
name = "cloudai"
18-
version = "0.7.12"
18+
version = "0.7.13"
1919
dependencies = [
2020
"bokeh==3.4.1",
2121
"pandas==2.2.1",

src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def gen_exec_command(
4949

5050
job_name_prefix = "chakra_replay"
5151
slurm_args = self._parse_slurm_args(job_name_prefix, final_env_vars, final_cmd_args, num_nodes, nodes)
52-
srun_command = self._generate_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
52+
srun_command = self.generate_full_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
5353
return self._write_sbatch_script(slurm_args, env_vars_str, srun_command, output_path)
5454

5555
def _parse_slurm_args(
@@ -69,23 +69,15 @@ def _parse_slurm_args(
6969

7070
return base_args
7171

72-
def _generate_srun_command(
73-
self,
74-
slurm_args: Dict[str, Any],
75-
env_vars: Dict[str, str],
76-
cmd_args: Dict[str, str],
77-
extra_cmd_args: str,
78-
) -> str:
72+
def generate_test_command(
73+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
74+
) -> List[str]:
7975
srun_command_parts = [
80-
"srun",
81-
f"--mpi={slurm_args['mpi']}",
82-
f'--container-image={slurm_args["image_path"]}',
83-
f'--container-mounts={slurm_args["container_mounts"]}',
8476
"python /workspace/param/train/comms/pt/commsTraceReplay.py",
8577
f'--trace-type {cmd_args["trace_type"]}',
8678
f'--trace-path {cmd_args["trace_path"]}',
8779
f'--backend {cmd_args["backend"]}',
8880
f'--device {cmd_args["device"]}',
8981
extra_cmd_args,
9082
]
91-
return " \\\n".join(srun_command_parts)
83+
return srun_command_parts

src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def gen_exec_command(
5252
env_vars_str = self._format_env_vars(final_env_vars)
5353

5454
slurm_args = self._parse_slurm_args("JaxToolbox", final_env_vars, final_cmd_args, num_nodes, nodes)
55-
srun_command = self._generate_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
55+
srun_command = self.generate_full_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
5656
return self._write_sbatch_script(slurm_args, env_vars_str, srun_command, output_path)
5757

5858
def _format_xla_flags(self, cmd_args: Dict[str, str]) -> str:
@@ -131,18 +131,14 @@ def _parse_slurm_args(
131131

132132
return base_args
133133

134-
def _generate_srun_command(
135-
self,
136-
slurm_args: Dict[str, Any],
137-
env_vars: Dict[str, str],
138-
cmd_args: Dict[str, str],
139-
extra_cmd_args: str,
134+
def generate_full_srun_command(
135+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
140136
) -> str:
141137
self._create_run_script(slurm_args, env_vars, cmd_args, extra_cmd_args)
142138

143139
srun_command_parts = [
144140
"srun",
145-
f"--mpi={slurm_args['mpi']}",
141+
f"--mpi={self.slurm_system.mpi}",
146142
"--export=ALL",
147143
f"-o {slurm_args['output']}",
148144
f"-e {slurm_args['error']}",

src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def gen_exec_command(
4545
raise KeyError("Subtest name not specified or unsupported.")
4646

4747
slurm_args = self._parse_slurm_args(subtest_name, final_env_vars, final_cmd_args, num_nodes, nodes)
48-
srun_command = self._generate_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
48+
srun_command = self.generate_full_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
4949
return self._write_sbatch_script(slurm_args, env_vars_str, srun_command, output_path)
5050

5151
def _parse_slurm_args(
@@ -76,24 +76,10 @@ def _parse_slurm_args(
7676

7777
return base_args
7878

79-
def _generate_srun_command(
80-
self,
81-
slurm_args: Dict[str, Any],
82-
env_vars: Dict[str, str],
83-
cmd_args: Dict[str, str],
84-
extra_cmd_args: str,
85-
) -> str:
86-
srun_command_parts = [
87-
"srun",
88-
f"--mpi={slurm_args['mpi']}",
89-
f"--container-image={slurm_args['image_path']}",
90-
]
91-
92-
if slurm_args.get("container_mounts"):
93-
srun_command_parts.append(f"--container-mounts={slurm_args['container_mounts']}")
94-
95-
srun_command_parts.append(f"/usr/local/bin/{cmd_args['subtest_name']}")
96-
79+
def generate_test_command(
80+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
81+
) -> List[str]:
82+
srun_command_parts = [f"/usr/local/bin/{cmd_args['subtest_name']}"]
9783
nccl_test_args = [
9884
"nthreads",
9985
"ngpus",
@@ -119,4 +105,4 @@ def _generate_srun_command(
119105
if extra_cmd_args:
120106
srun_command_parts.append(extra_cmd_args)
121107

122-
return " \\\n".join(srun_command_parts)
108+
return srun_command_parts

src/cloudai/schema/test_template/sleep/slurm_command_gen_strategy.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,13 @@ def gen_exec_command(
3737
env_vars_str = self._format_env_vars(final_env_vars)
3838

3939
slurm_args = self._parse_slurm_args("sleep", final_env_vars, final_cmd_args, num_nodes, nodes)
40-
srun_command = self._generate_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
40+
srun_command = self.generate_full_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
4141
return self._write_sbatch_script(slurm_args, env_vars_str, srun_command, output_path)
4242

43-
def _generate_srun_command(
44-
self,
45-
slurm_args: Dict[str, Any],
46-
env_vars: Dict[str, str],
47-
cmd_args: Dict[str, str],
48-
extra_cmd_args: str,
43+
def generate_full_srun_command(
44+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
4945
) -> str:
50-
srun_command_parts = [
51-
"srun",
52-
f"--mpi={slurm_args['mpi']}",
53-
]
46+
srun_command_parts = ["srun", f"--mpi={self.slurm_system.mpi}"]
5447

5548
sec = cmd_args["seconds"]
5649
srun_command_parts.append(f"sleep {sec}")

src/cloudai/schema/test_template/ucc_test/slurm_command_gen_strategy.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def gen_exec_command(
4444
raise KeyError("Collective name not specified or unsupported.")
4545

4646
slurm_args = self._parse_slurm_args(collective, final_env_vars, final_cmd_args, num_nodes, nodes)
47-
srun_command = self._generate_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
47+
srun_command = self.generate_full_srun_command(slurm_args, final_env_vars, final_cmd_args, extra_cmd_args)
4848
return self._write_sbatch_script(slurm_args, env_vars_str, srun_command, output_path)
4949

5050
def _parse_slurm_args(
@@ -69,19 +69,10 @@ def _parse_slurm_args(
6969

7070
return base_args
7171

72-
def _generate_srun_command(
73-
self,
74-
slurm_args: Dict[str, Any],
75-
env_vars: Dict[str, str],
76-
cmd_args: Dict[str, str],
77-
extra_cmd_args: str,
78-
) -> str:
79-
srun_command_parts = [
80-
"srun",
81-
f"--mpi={slurm_args['mpi']}",
82-
f"--container-image={slurm_args['image_path']}",
83-
"/opt/hpcx/ucc/bin/ucc_perftest",
84-
]
72+
def generate_test_command(
73+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
74+
) -> List[str]:
75+
srun_command_parts = ["/opt/hpcx/ucc/bin/ucc_perftest"]
8576

8677
# Add collective, minimum bytes, and maximum bytes options if available
8778
if "collective" in cmd_args:
@@ -99,4 +90,4 @@ def _generate_srun_command(
9990
if extra_cmd_args:
10091
srun_command_parts.append(extra_cmd_args)
10192

102-
return " \\\n".join(srun_command_parts)
93+
return srun_command_parts

src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,7 @@ class SlurmCommandGenStrategy(CommandGenStrategy):
3131
properties and methods.
3232
"""
3333

34-
def __init__(
35-
self,
36-
system: SlurmSystem,
37-
env_vars: Dict[str, Any],
38-
cmd_args: Dict[str, Any],
39-
) -> None:
34+
def __init__(self, system: SlurmSystem, env_vars: Dict[str, Any], cmd_args: Dict[str, Any]) -> None:
4035
"""
4136
Initialize a new SlurmCommandGenStrategy instance.
4237
@@ -125,8 +120,6 @@ def _parse_slurm_args(
125120
slurm_args["account"] = self.slurm_system.account
126121
if self.slurm_system.distribution:
127122
slurm_args["distribution"] = self.slurm_system.distribution
128-
if self.slurm_system.mpi:
129-
slurm_args["mpi"] = self.slurm_system.mpi
130123
if self.slurm_system.gpus_per_node:
131124
slurm_args["gpus_per_node"] = self.slurm_system.gpus_per_node
132125
if self.slurm_system.ntasks_per_node:
@@ -136,27 +129,28 @@ def _parse_slurm_args(
136129

137130
return slurm_args
138131

139-
def _generate_srun_command(
140-
self,
141-
slurm_args: Dict[str, Any],
142-
env_vars: Dict[str, str],
143-
cmd_args: Dict[str, str],
144-
extra_cmd_args: str,
132+
def generate_full_srun_command(
133+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
145134
) -> str:
146-
"""
147-
Generate the srun command string for executing the test.
148-
149-
Args:
150-
slurm_args (Dict[str, Any]): Arguments containing Slurm job settings including image path and container
151-
mounts.
152-
env_vars (Dict[str, str]): Environment variables.
153-
cmd_args (Dict[str, str]): Command-line arguments.
154-
extra_cmd_args (str): Additional command-line arguments to be included in the srun command.
155-
156-
Returns:
157-
str: The complete srun command to execute the test.
158-
"""
159-
return ""
135+
srun_command_parts = self.generate_srun_command(slurm_args, env_vars, cmd_args, extra_cmd_args)
136+
test_command_parts = self.generate_test_command(slurm_args, env_vars, cmd_args, extra_cmd_args)
137+
return " \\\n".join(srun_command_parts + test_command_parts)
138+
139+
def generate_srun_command(
140+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
141+
) -> List[str]:
142+
srun_command_parts = ["srun", f"--mpi={self.slurm_system.mpi}"]
143+
if slurm_args.get("image_path"):
144+
srun_command_parts.append(f'--container-image={slurm_args["image_path"]}')
145+
if slurm_args.get("container_mounts"):
146+
srun_command_parts.append(f'--container-mounts={slurm_args["container_mounts"]}')
147+
148+
return srun_command_parts
149+
150+
def generate_test_command(
151+
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
152+
) -> List[str]:
153+
return []
160154

161155
def _write_sbatch_script(self, args: Dict[str, Any], env_vars_str: str, srun_command: str, output_path: str) -> str:
162156
"""

tests/test_slurm_command_gen_strategy.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pathlib import Path
1717

1818
import pytest
19+
from cloudai.schema.test_template.nccl_test.slurm_command_gen_strategy import NcclTestSlurmCommandGenStrategy
1920
from cloudai.schema.test_template.nemo_launcher.slurm_command_gen_strategy import (
2021
NeMoLauncherSlurmCommandGenStrategy,
2122
)
@@ -39,6 +40,7 @@ def slurm_system(tmp_path: Path) -> SlurmSystem:
3940
SlurmNode(name="node4", partition="main", state=SlurmNodeState.IDLE),
4041
]
4142
},
43+
mpi="fake-mpi",
4244
)
4345
Path(slurm_system.install_path).mkdir()
4446
Path(slurm_system.output_path).mkdir()
@@ -112,6 +114,51 @@ def test_only_nodes(strategy_fixture: SlurmCommandGenStrategy):
112114
assert slurm_args["num_nodes"] == len(nodes)
113115

114116

117+
class TestGenerateSrunCommand__CmdGeneration:
118+
def test_generate_test_command(self, strategy_fixture: SlurmCommandGenStrategy):
119+
test_command = strategy_fixture.generate_test_command({}, {}, {}, "")
120+
assert test_command == []
121+
122+
def test_generate_srun_command(self, strategy_fixture: SlurmCommandGenStrategy):
123+
srun_command = strategy_fixture.generate_srun_command({}, {}, {}, "")
124+
assert srun_command == ["srun", f"--mpi={strategy_fixture.slurm_system.mpi}"]
125+
126+
def test_generate_srun_command_with_container_image(self, strategy_fixture: SlurmCommandGenStrategy):
127+
slurm_args = {"image_path": "fake_image_path"}
128+
srun_command = strategy_fixture.generate_srun_command(slurm_args, {}, {}, "")
129+
assert srun_command == [
130+
"srun",
131+
f"--mpi={strategy_fixture.slurm_system.mpi}",
132+
"--container-image=fake_image_path",
133+
]
134+
135+
def test_generate_srun_command_with_container_image_and_mounts(self, strategy_fixture: SlurmCommandGenStrategy):
136+
slurm_args = {"image_path": "fake_image_path", "container_mounts": "fake_mounts"}
137+
srun_command = strategy_fixture.generate_srun_command(slurm_args, {}, {}, "")
138+
assert srun_command == [
139+
"srun",
140+
f"--mpi={strategy_fixture.slurm_system.mpi}",
141+
"--container-image=fake_image_path",
142+
"--container-mounts=fake_mounts",
143+
]
144+
145+
def test_generate_srun_empty_str(self, strategy_fixture: SlurmCommandGenStrategy):
146+
slurm_args = {"image_path": "", "container_mounts": ""}
147+
srun_command = strategy_fixture.generate_srun_command(slurm_args, {}, {}, "")
148+
assert srun_command == ["srun", f"--mpi={strategy_fixture.slurm_system.mpi}"]
149+
150+
slurm_args = {"image_path": "fake", "container_mounts": ""}
151+
srun_command = strategy_fixture.generate_srun_command(slurm_args, {}, {}, "")
152+
assert srun_command == ["srun", f"--mpi={strategy_fixture.slurm_system.mpi}", "--container-image=fake"]
153+
154+
def test_generate_full_srun_command(self, strategy_fixture: SlurmCommandGenStrategy):
155+
strategy_fixture.generate_srun_command = lambda *_, **__: ["srun", "--test", "test_arg"]
156+
strategy_fixture.generate_test_command = lambda *_, **__: ["test_command"]
157+
158+
full_srun_command = strategy_fixture.generate_full_srun_command({}, {}, {}, "")
159+
assert full_srun_command == " \\\n".join(["srun", "--test", "test_arg", "test_command"])
160+
161+
115162
class TestNeMoLauncherSlurmCommandGenStrategy__GenExecCommand:
116163
@pytest.fixture
117164
def nemo_cmd_gen(self, slurm_system: SlurmSystem) -> NeMoLauncherSlurmCommandGenStrategy:
@@ -305,3 +352,37 @@ def test_disable_output_and_error(self, add_arg: str, strategy_fixture: SlurmCom
305352

306353
self.assert_positional_lines(file_contents.splitlines())
307354
assert f"--{add_arg}=" not in file_contents
355+
356+
357+
class TestNCCLSlurmCommandGen:
358+
def get_cmd(self, slurm_system: SlurmSystem, slurm_args: dict, cmd_args: dict) -> str:
359+
return NcclTestSlurmCommandGenStrategy(slurm_system, {}, {}).generate_full_srun_command(
360+
slurm_args, {}, cmd_args, ""
361+
)
362+
363+
def test_only_mandatory(self, slurm_system: SlurmSystem) -> None:
364+
slurm_args = {"image_path": "fake_image_path"}
365+
cmd_args = {"subtest_name": "fake_subtest_name"}
366+
cmd = self.get_cmd(slurm_system, slurm_args, cmd_args)
367+
assert cmd == " \\\n".join(
368+
[
369+
"srun",
370+
f"--mpi={slurm_system.mpi}",
371+
f"--container-image={slurm_args['image_path']}",
372+
f"/usr/local/bin/{cmd_args['subtest_name']}",
373+
]
374+
)
375+
376+
def test_with_container_mounts(self, slurm_system: SlurmSystem) -> None:
377+
slurm_args = {"image_path": "fake_image_path", "container_mounts": "fake_mounts"}
378+
cmd_args = {"subtest_name": "fake_subtest_name"}
379+
cmd = self.get_cmd(slurm_system, slurm_args, cmd_args)
380+
assert cmd == " \\\n".join(
381+
[
382+
"srun",
383+
f"--mpi={slurm_system.mpi}",
384+
f"--container-image={slurm_args['image_path']}",
385+
f"--container-mounts={slurm_args['container_mounts']}",
386+
f"/usr/local/bin/{cmd_args['subtest_name']}",
387+
]
388+
)

0 commit comments

Comments
 (0)