Skip to content

Commit fd6f59b

Browse files
authored
Add extra_srun_args & scripts in SlurmContainerTestDef (#531)
1 parent 6892fa4 commit fd6f59b

File tree

4 files changed

+56
-5
lines changed

4 files changed

+56
-5
lines changed

src/cloudai/workloads/slurm_container/slurm_command_gen_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def gen_srun_prefix(self, slurm_args: dict[str, Any], tr: TestRun, use_pretest_e
3636
tdef: SlurmContainerTestDefinition = cast(SlurmContainerTestDefinition, tr.test.test_definition)
3737
slurm_args["image_path"] = tdef.docker_image.installed_path
3838
cmd = super().gen_srun_prefix(slurm_args, tr)
39-
return [*cmd, "--no-container-mount-home"]
39+
return [*cmd, *tdef.extra_srun_args, "--no-container-mount-home"]
4040

4141
def generate_test_command(
4242
self, env_vars: Dict[str, Union[str, List[str]]], cmd_args: Dict[str, Union[str, List[str]]], tr: TestRun

src/cloudai/workloads/slurm_container/slurm_container.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
from typing import Optional
1818

19-
from cloudai import DockerImage, Installable
19+
from pydantic import Field
20+
21+
from cloudai import DockerImage, File, Installable
2022

2123
from ...models.workload import CmdArgs, TestDefinition
2224

@@ -32,7 +34,8 @@ class SlurmContainerTestDefinition(TestDefinition):
3234
"""Test definition for a generic Slurm container test."""
3335

3436
cmd_args: SlurmContainerCmdArgs
35-
37+
extra_srun_args: list[str] = Field(default_factory=list)
38+
scripts: list[File] = Field(default_factory=list)
3639
_docker_image: Optional[DockerImage] = None
3740

3841
@property
@@ -43,7 +46,7 @@ def docker_image(self) -> DockerImage:
4346

4447
@property
4548
def installables(self) -> list[Installable]:
46-
return [self.docker_image, *self.git_repos]
49+
return [self.docker_image, *self.git_repos, *self.scripts]
4750

4851
@property
4952
def extra_args_str(self) -> str:

tests/slurm_command_gen_strategy/test_slurm_container_slurm_command_gen_strategy.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
from pathlib import Path
18+
from typing import cast
1819

1920
import pytest
2021

@@ -72,3 +73,23 @@ def test_with_nsys(slurm_system: SlurmSystem, test_run: TestRun) -> None:
7273
)
7374

7475
assert cmd == f'{srun_part} bash -c "{" ".join(nsys.cmd_args)} cmd"'
76+
77+
78+
def test_with_extra_srun_args(slurm_system: SlurmSystem, test_run: TestRun) -> None:
    """Check that ``extra_srun_args`` land in the srun command just before ``--no-container-mount-home``."""
    extra_args = ["--ntasks=1", "--ntasks-per-node=1"]
    tdef = cast(SlurmContainerTestDefinition, test_run.test.test_definition)
    tdef.extra_srun_args = extra_args

    strategy = SlurmContainerCommandGenStrategy(slurm_system, {})
    actual = strategy.gen_srun_command(test_run)

    # Results dir and install dir are the two container mounts expected in the command.
    mounts = ",".join(
        [
            f"{Path.cwd().absolute()}:/cloudai_run_results",
            f"{slurm_system.install_path.absolute()}:/cloudai_install",
        ]
    )
    expected_parts = [
        "srun",
        "--export=ALL",
        f"--mpi={slurm_system.mpi}",
        f"--container-image={tdef.cmd_args.docker_image_url}",
        f"--container-mounts={mounts}",
        *extra_args,
        "--no-container-mount-home",
        'bash -c "cmd"',
    ]

    assert actual == " ".join(expected_parts)

tests/test_test_definitions.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import toml
2222
from pydantic import ValidationError
2323

24-
from cloudai import NsysConfiguration, Parser, Registry, TestConfigParsingError, TestDefinition, TestParser
24+
from cloudai import File, NsysConfiguration, Parser, Registry, TestConfigParsingError, TestDefinition, TestParser
2525
from cloudai.workloads.chakra_replay import ChakraReplayCmdArgs, ChakraReplayTestDefinition
2626
from cloudai.workloads.jax_toolbox import (
2727
GPTCmdArgs,
@@ -35,6 +35,7 @@
3535
from cloudai.workloads.nccl_test import NCCLCmdArgs, NCCLTestDefinition
3636
from cloudai.workloads.nemo_launcher import NeMoLauncherCmdArgs, NeMoLauncherTestDefinition
3737
from cloudai.workloads.nemo_run import NeMoRunCmdArgs, NeMoRunTestDefinition
38+
from cloudai.workloads.slurm_container import SlurmContainerCmdArgs, SlurmContainerTestDefinition
3839
from cloudai.workloads.ucc_test import UCCCmdArgs, UCCTestDefinition
3940

4041
TOML_FILES = list(Path("conf").glob("**/*.toml"))
@@ -126,6 +127,12 @@ def test_chakra_docker_image_is_required():
126127
test_template_name="chakra",
127128
cmd_args=ChakraReplayCmdArgs(docker_image_url="fake://url/chakra"),
128129
),
130+
SlurmContainerTestDefinition(
131+
name="sc",
132+
description="desc",
133+
test_template_name="sc",
134+
cmd_args=SlurmContainerCmdArgs(docker_image_url="fake://url/sc", cmd="cmd"),
135+
),
129136
],
130137
)
131138
def test_docker_installable_persists(
@@ -137,6 +144,7 @@ def test_docker_installable_persists(
137144
NeMoLauncherTestDefinition,
138145
NemotronTestDefinition,
139146
UCCTestDefinition,
147+
SlurmContainerTestDefinition,
140148
],
141149
tmp_path: Path,
142150
):
@@ -159,6 +167,25 @@ def test_python_executable_installable_persists(test: NeMoLauncherTestDefinition
159167
assert test.python_executable.venv_path == tmp_path
160168

161169

170+
@pytest.mark.parametrize(
    "test",
    [
        SlurmContainerTestDefinition(
            name="sc",
            description="desc",
            test_template_name="sc",
            cmd_args=SlurmContainerCmdArgs(docker_image_url="fake://url/sc", cmd="cmd"),
            scripts=[File(src=Path("./script1")), File(src=Path("./script2"))],
        )
    ],
)
def test_slurm_container_installables(test: SlurmContainerTestDefinition):
    """Check that the docker image and every configured script are reported as installables."""
    installables = test.installables

    # At least the docker image plus the two scripts must be present.
    assert len(installables) >= 3
    assert test.docker_image in installables
    for script_path in ("./script1", "./script2"):
        assert File(src=Path(script_path)) in installables
187+
188+
162189
class TestNsysConfiguration:
163190
def test_default(self):
164191
nsys = NsysConfiguration()

0 commit comments

Comments
 (0)