Skip to content

Commit d61a046

Browse files
Merge pull request #298 from NVIDIA/am/slurm-cont
Generic Slurm Container job
2 parents 43d89c2 + 63973e3 commit d61a046

File tree

22 files changed

+344
-42
lines changed

22 files changed

+344
-42
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ jobs:
2323
run: pip install -r requirements-dev.txt
2424

2525
- name: Run ruff linter
26-
run: ruff check .
26+
run: ruff check
2727

2828
- name: Run ruff formatter
29-
run: ruff format --check --diff .
29+
run: ruff format --check --diff
3030

3131
- name: Run pyright
32-
run: pyright .
32+
run: pyright
3333

3434
- name: Run vulture check
3535
run: vulture src/ tests/

conf/common/system/example_slurm_cluster.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
name = "example-cluster"
1818
scheduler = "slurm"
1919

20-
install_path = "./install"
20+
install_path = "./install_dir"
2121
output_path = "./results"
2222
default_partition = "partition_1"
2323

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,6 @@ min_confidence = 100
100100

101101
[tool.coverage.report]
102102
exclude_also = ["@abstractmethod"]
103+
104+
[tool.pyright]
105+
include = ["src", "tests"]

src/cloudai/__init__.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@
8181
from .schema.test_template.sleep.slurm_command_gen_strategy import SleepSlurmCommandGenStrategy
8282
from .schema.test_template.sleep.standalone_command_gen_strategy import SleepStandaloneCommandGenStrategy
8383
from .schema.test_template.sleep.template import Sleep
84+
from .schema.test_template.slurm_container.report_generation_strategy import (
85+
SlurmContainerReportGenerationStrategy,
86+
)
87+
from .schema.test_template.slurm_container.slurm_command_gen_strategy import (
88+
SlurmContainerCommandGenStrategy,
89+
)
90+
from .schema.test_template.slurm_container.template import SlurmContainer
8491
from .schema.test_template.ucc_test.grading_strategy import UCCTestGradingStrategy
8592
from .schema.test_template.ucc_test.report_generation_strategy import UCCTestReportGenerationStrategy
8693
from .schema.test_template.ucc_test.slurm_command_gen_strategy import UCCTestSlurmCommandGenStrategy
@@ -98,6 +105,7 @@
98105
SleepTestDefinition,
99106
UCCTestDefinition,
100107
)
108+
from .test_definitions.slurm_container import SlurmContainerTestDefinition
101109

102110
Registry().add_runner("slurm", SlurmRunner)
103111
Registry().add_runner("kubernetes", KubernetesRunner)
@@ -121,14 +129,21 @@
121129
Registry().add_strategy(JobIdRetrievalStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherSlurmJobIdRetrievalStrategy)
122130
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherSlurmCommandGenStrategy)
123131
Registry().add_strategy(ReportGenerationStrategy, [SlurmSystem], [UCCTest], UCCTestReportGenerationStrategy)
132+
Registry().add_strategy(
133+
ReportGenerationStrategy,
134+
[SlurmSystem],
135+
[SlurmContainer],
136+
SlurmContainerReportGenerationStrategy,
137+
)
124138
Registry().add_strategy(GradingStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherGradingStrategy)
139+
125140
Registry().add_strategy(GradingStrategy, [SlurmSystem], [JaxToolbox], JaxToolboxGradingStrategy)
126141
Registry().add_strategy(GradingStrategy, [SlurmSystem], [UCCTest], UCCTestGradingStrategy)
127142
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [JaxToolbox], JaxToolboxSlurmCommandGenStrategy)
128143
Registry().add_strategy(
129144
JobIdRetrievalStrategy,
130145
[SlurmSystem],
131-
[ChakraReplay, JaxToolbox, NcclTest, UCCTest, Sleep],
146+
[ChakraReplay, JaxToolbox, NcclTest, UCCTest, Sleep, SlurmContainer],
132147
SlurmJobIdRetrievalStrategy,
133148
)
134149
Registry().add_strategy(JobIdRetrievalStrategy, [StandaloneSystem], [Sleep], StandaloneJobIdRetrievalStrategy)
@@ -141,13 +156,14 @@
141156
Registry().add_strategy(
142157
JobStatusRetrievalStrategy,
143158
[SlurmSystem],
144-
[ChakraReplay, UCCTest, NeMoLauncher, Sleep],
159+
[ChakraReplay, UCCTest, NeMoLauncher, Sleep, SlurmContainer],
145160
DefaultJobStatusRetrievalStrategy,
146161
)
147162
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [UCCTest], UCCTestSlurmCommandGenStrategy)
148163
Registry().add_strategy(ReportGenerationStrategy, [SlurmSystem], [ChakraReplay], ChakraReplayReportGenerationStrategy)
149164
Registry().add_strategy(GradingStrategy, [SlurmSystem], [ChakraReplay], ChakraReplayGradingStrategy)
150165
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [ChakraReplay], ChakraReplaySlurmCommandGenStrategy)
166+
Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [SlurmContainer], SlurmContainerCommandGenStrategy)
151167

152168
Registry().add_installer("slurm", SlurmInstaller)
153169
Registry().add_installer("standalone", StandaloneInstaller)
@@ -165,6 +181,7 @@
165181
Registry().add_test_definition("JaxToolboxGPT", GPTTestDefinition)
166182
Registry().add_test_definition("JaxToolboxGrok", GrokTestDefinition)
167183
Registry().add_test_definition("JaxToolboxNemotron", NemotronTestDefinition)
184+
Registry().add_test_definition("SlurmContainer", SlurmContainerTestDefinition)
168185

169186
Registry().add_test_template("ChakraReplay", ChakraReplay)
170187
Registry().add_test_template("NcclTest", NcclTest)
@@ -174,6 +191,7 @@
174191
Registry().add_test_template("JaxToolboxGPT", JaxToolbox)
175192
Registry().add_test_template("JaxToolboxGrok", JaxToolbox)
176193
Registry().add_test_template("JaxToolboxNemotron", JaxToolbox)
194+
Registry().add_test_template("SlurmContainer", SlurmContainer)
177195

178196
__all__ = [
179197
"BaseInstaller",

src/cloudai/installer/slurm_installer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def install_one(self, item: Installable) -> InstallStatusResult:
118118
if isinstance(item, DockerImage):
119119
res = self._install_docker_image(item)
120120
return InstallStatusResult(res.success, res.message)
121+
elif isinstance(item, GitRepo):
122+
return self._install_one_git_repo(item)
121123
elif isinstance(item, PythonExecutable):
122124
return self._install_python_executable(item)
123125

@@ -139,6 +141,8 @@ def uninstall_one(self, item: Installable) -> InstallStatusResult:
139141
return InstallStatusResult(res.success, res.message)
140142
elif isinstance(item, PythonExecutable):
141143
return self._uninstall_python_executable(item)
144+
elif isinstance(item, GitRepo):
145+
return self._uninstall_git_repo(item)
142146

143147
return InstallStatusResult(False, f"Unsupported item type: {type(item)}")
144148

@@ -148,6 +152,12 @@ def is_installed_one(self, item: Installable) -> InstallStatusResult:
148152
if res.success and res.docker_image_path:
149153
item.installed_path = res.docker_image_path
150154
return InstallStatusResult(res.success, res.message)
155+
elif isinstance(item, GitRepo):
156+
repo_path = self.system.install_path / item.repo_name
157+
if repo_path.exists():
158+
item.installed_path = repo_path
159+
return InstallStatusResult(True)
160+
return InstallStatusResult(False, f"Git repository {item.git_url} not cloned")
151161
elif isinstance(item, PythonExecutable):
152162
return self._is_python_executable_installed(item)
153163

src/cloudai/report_generator/report_generator.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@ def _generate_test_report(self, directory_path: Path, tr: TestRun) -> None:
7070
tr (TestRun): The test run object.
7171
"""
7272
for subdir in directory_path.iterdir():
73-
if subdir.is_dir() and tr.test.test_template.can_handle_directory(subdir):
74-
tr.test.test_template.generate_report(tr.test.name, subdir, tr.sol)
75-
else:
76-
logging.warning(f"Skipping directory '{subdir}' for test '{tr.test.name}'")
73+
if not subdir.is_dir():
74+
logging.debug(f"Skipping file '{subdir}', not a directory.")
75+
continue
76+
if not tr.test.test_template.can_handle_directory(subdir):
77+
logging.warning(
78+
f"Skipping '{subdir}', can't hande with "
79+
f"strategy={tr.test.test_template.report_generation_strategy}."
80+
)
81+
continue
82+
83+
tr.test.test_template.generate_report(tr.test.name, subdir, tr.sol)

src/cloudai/runner/slurm/slurm_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,5 @@ def _submit_test(self, tr: TestRun) -> SlurmJob:
6868
stderr=stderr,
6969
message="Failed to retrieve job ID from command output.",
7070
)
71+
logging.info(f"Submitted slurm job: {job_id}")
7172
return SlurmJob(self.mode, self.system, tr, job_id)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2+
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import re
18+
from pathlib import Path
19+
from typing import Optional
20+
21+
from cloudai import ReportGenerationStrategy
22+
23+
24+
class SlurmContainerReportGenerationStrategy(ReportGenerationStrategy):
    """
    Report generation strategy for a generic Slurm container test.

    Scans ``stdout.txt`` in a result directory for training-step log lines and
    exports the captured values to ``report.csv`` in the same directory.
    """

    # Training-step log line. Every "|" separator is escaped: an unescaped "|" is
    # regex alternation, which would make each " | "-separated fragment match on its
    # own (for example a bare " lr: 0.1 ") instead of requiring the whole line.
    _STEP_PATTERN = (
        r"Training epoch (\d+), iteration (\d+)/\d+ \| lr: ([\d.]+) \| global_batch_size: (\d+) \| "
        r"global_step: (\d+) \| reduced_train_loss: ([\d.]+) \| train_step_timing in s: ([\d.]+)"
    )
    # Compiled once at class creation instead of being rebuilt for every log line.
    _STEP_RE = re.compile(_STEP_PATTERN)
    _STEP_WITH_SAMPLES_RE = re.compile(_STEP_PATTERN + r" \| consumed_samples: (\d+)")

    def can_handle_directory(self, directory_path: Path) -> bool:
        """
        Check whether the directory holds output this strategy can report on.

        Args:
            directory_path (Path): Directory expected to contain ``stdout.txt``.

        Returns:
            bool: True if ``stdout.txt`` is a regular file and contains at least one
                complete training-step log line, False otherwise.
        """
        stdout_path = directory_path / "stdout.txt"
        if not stdout_path.is_file():
            return False

        with stdout_path.open("r") as file:
            # The pattern cannot span a newline, so a per-line search is equivalent to
            # searching the whole text and avoids loading the entire log into memory.
            return any(self._STEP_RE.search(line) for line in file)

    def generate_report(self, test_name: str, directory_path: Path, sol: Optional[float] = None) -> None:
        """
        Write ``report.csv`` with one row per training-step line found in ``stdout.txt``.

        Nothing is written when ``stdout.txt`` is not a regular file.

        Args:
            test_name (str): Name of the test (not used by this strategy).
            directory_path (Path): Directory containing ``stdout.txt``; ``report.csv`` is created here.
            sol (Optional[float]): Speed-of-light value (not used by this strategy).
        """
        stdout_path = directory_path / "stdout.txt"
        if not stdout_path.is_file():
            return

        with stdout_path.open("r") as file:
            lines = file.readlines()

        with open(directory_path / "report.csv", "w") as csv_file:
            csv_file.write(
                "epoch,iteration,lr,global_batch_size,global_step,reduced_train_loss,train_step_timing,consumed_samples\n"
            )
            for line in lines:
                # Lines that report consumed_samples must match the extended pattern so the
                # extra column is captured. Lines without it yield seven fields, i.e. the
                # trailing consumed_samples column is simply absent from that row.
                step_re = self._STEP_WITH_SAMPLES_RE if " | consumed_samples:" in line else self._STEP_RE

                # Anchored at the start of the line, unlike the search in can_handle_directory.
                match = step_re.match(line)
                if match:
                    csv_file.write(",".join(match.groups()) + "\n")
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2+
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from typing import Any, cast
18+
19+
from cloudai import TestRun
20+
from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy
21+
from cloudai.test_definitions.slurm_container import SlurmContainerTestDefinition
22+
23+
24+
class SlurmContainerCommandGenStrategy(SlurmCommandGenStrategy):
    """Build Slurm commands for generic container tests."""

    def gen_srun_prefix(self, slurm_args: dict[str, Any], tr: TestRun) -> list[str]:
        """
        Return the parent's srun prefix with ``--no-container-mount-home`` appended.

        Before delegating to the parent, ``slurm_args`` is updated in place with the
        installed image path and the comma-joined container mounts taken from the
        test definition.

        Args:
            slurm_args (dict[str, Any]): Slurm arguments; ``image_path`` and
                ``container_mounts`` are set on this dictionary.
            tr (TestRun): Test run whose test definition supplies the image and mounts.

        Returns:
            list[str]: The srun prefix parts.
        """
        definition = cast(SlurmContainerTestDefinition, tr.test.test_definition)
        slurm_args["image_path"] = definition.docker_image.installed_path
        mounts = definition.container_mounts(self.system.install_path)
        slurm_args["container_mounts"] = ",".join(mounts)

        prefix = super().gen_srun_prefix(slurm_args, tr)
        return prefix + ["--no-container-mount-home"]

    def generate_test_command(self, env_vars: dict[str, str], cmd_args: dict[str, str], tr: TestRun) -> list[str]:
        """
        Return the command to run inside the container.

        The command consists solely of the test's ``extra_cmd_args``; ``env_vars``
        and ``cmd_args`` are not consulted.

        Args:
            env_vars (dict[str, str]): Environment variables (unused).
            cmd_args (dict[str, str]): Command-line arguments (unused).
            tr (TestRun): Test run providing ``extra_cmd_args``.

        Returns:
            list[str]: A single-element list with ``extra_cmd_args`` when it is set
                and non-empty, otherwise an empty list.
        """
        if not tr.test.extra_cmd_args:
            return []
        return [tr.test.extra_cmd_args]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2+
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from cloudai import TestTemplate
18+
19+
20+
class SlurmContainer(TestTemplate):
    """Test template for a generic Slurm container job; adds nothing beyond ``TestTemplate``."""

0 commit comments

Comments
 (0)