
Commit 95b3681

Test hook support (#263)
1 parent 12a807c commit 95b3681

27 files changed: +653 -406 lines changed

conf/common/test_scenario/nccl_test.toml

Lines changed: 4 additions & 0 deletions
```diff
@@ -15,6 +15,10 @@
 # limitations under the License.
 
 name = "nccl-test"
+
+pre_test = "nccl_test"
+post_test = "nccl_test"
+
 [[Tests]]
 id = "Tests.1"
 test_name = "nccl_test_all_reduce"
```

conf/hook/nccl_test.toml

Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
+# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name = "nccl_test"
+
+[[Tests]]
+id = "Tests.1"
+test_name = "nccl_test_all_gather"
+time_limit = "00:20:00"
```

Lines changed: 33 additions & 0 deletions
```diff
@@ -0,0 +1,33 @@
+# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
+# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name = "nccl_test_all_gather"
+description = "all_gather"
+test_template_name = "NcclTest"
+
+[cmd_args]
+"subtest_name" = "all_gather_perf_mpi"
+"ngpus" = "1"
+"minbytes" = "128"
+"maxbytes" = "4G"
+"iters" = "100"
+"warmup_iters" = "50"
+
+[extra_cmd_args]
+"--stepfactor" = "2"
+
+[extra_env_vars]
+"NCCL_TEST_SPLIT_MASK" = "0x7"
```

src/cloudai/_core/command_gen_strategy.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -39,3 +39,29 @@ def gen_exec_command(self, tr: TestRun) -> str:
             str: The generated execution command.
         """
         pass
+
+    @abstractmethod
+    def gen_srun_command(self, tr: TestRun) -> str:
+        """
+        Generate the Slurm srun command for a test based on the given parameters.
+
+        Args:
+            tr (TestRun): Contains the test and its run-specific configurations.
+
+        Returns:
+            str: The generated Slurm srun command.
+        """
+        pass
+
+    @abstractmethod
+    def gen_srun_success_check(self, tr: TestRun) -> str:
+        """
+        Generate the Slurm success check command to verify if a test run was successful.
+
+        Args:
+            tr (TestRun): Contains the test and its run-specific configurations.
+
+        Returns:
+            str: The generated command to check the success of the test run.
+        """
+        pass
```
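Any system-specific command generation strategy must now implement both hook-facing methods. A standalone sketch mirroring the interface; the subclass name, command shape, and success marker below are illustrative assumptions, not the repository's real Slurm strategy:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class TestRun:  # simplified stand-in for cloudai's TestRun
    name: str
    num_nodes: int
    output_path: str


class CommandGenStrategy(ABC):  # mirrors the two new abstract methods
    @abstractmethod
    def gen_srun_command(self, tr: TestRun) -> str: ...

    @abstractmethod
    def gen_srun_success_check(self, tr: TestRun) -> str: ...


class ToySlurmCommandGenStrategy(CommandGenStrategy):
    def gen_srun_command(self, tr: TestRun) -> str:
        # Hypothetical command shape; real strategies build this from the test template.
        return f"srun --nodes={tr.num_nodes} --output={tr.output_path}/stdout.txt ./{tr.name}"

    def gen_srun_success_check(self, tr: TestRun) -> str:
        # A shell command that exits 0 on success; the stdout marker is an assumption.
        return f'grep -q "Avg bus bandwidth" {tr.output_path}/stdout.txt'


strategy = ToySlurmCommandGenStrategy()
tr = TestRun(name="nccl_test_all_gather", num_nodes=2, output_path="/tmp/run0")
print(strategy.gen_srun_command(tr))
print(strategy.gen_srun_success_check(tr))
```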

src/cloudai/_core/test_scenario.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -58,6 +58,8 @@ class TestRun:
     weight: float = 0.0
     ideal_perf: float = 1.0
     dependencies: dict[str, TestDependency] = field(default_factory=dict)
+    pre_test: Optional["TestScenario"] = None
+    post_test: Optional["TestScenario"] = None
 
     def __hash__(self) -> int:
         return hash(self.name + self.test.name + str(self.iterations) + str(self.current_iteration))
```
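The hook fields default to `None`, and `__hash__` (shown in the context above) still ignores them, so a TestRun's identity is unchanged by hooks. A simplified standalone sketch of the pattern, reduced to the fields relevant here:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TestScenario:  # stand-in for cloudai's TestScenario
    name: str


@dataclass
class TestRun:  # simplified: only the fields relevant to hooks
    name: str
    iterations: int = 1
    current_iteration: int = 0
    pre_test: Optional["TestScenario"] = None
    post_test: Optional["TestScenario"] = None

    def __hash__(self) -> int:
        # Hooks are intentionally not part of the hash.
        return hash(self.name + str(self.iterations) + str(self.current_iteration))


tr = TestRun(name="Tests.1", pre_test=TestScenario(name="nccl_test"))
assert hash(tr) == hash(TestRun(name="Tests.1"))  # adding a hook does not change identity
```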

src/cloudai/_core/test_scenario_parser.py

Lines changed: 39 additions & 3 deletions
```diff
@@ -54,6 +54,8 @@ class _TestScenarioTOML(BaseModel):
     name: str
     job_status_check: bool = True
     tests: list[_TestRunTOML] = Field(alias="Tests", min_length=1)
+    pre_test: Optional[str] = None
+    post_test: Optional[str] = None
 
     @model_validator(mode="after")
     def check_no_self_dependency(self):
@@ -99,9 +101,10 @@ class TestScenarioParser:
 
     __test__ = False
 
-    def __init__(self, file_path: Path, test_mapping: Dict[str, Test]) -> None:
+    def __init__(self, file_path: Path, test_mapping: Dict[str, Test], hook_mapping: Dict[str, TestScenario]) -> None:
         self.file_path = file_path
         self.test_mapping = test_mapping
+        self.hook_mapping = hook_mapping
 
     def parse(self) -> TestScenario:
         """
@@ -136,8 +139,31 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario:
         total_weight = sum(tr.weight for tr in ts_model.tests)
         normalized_weight = 0 if total_weight == 0 else 100 / total_weight
 
+        pre_test, post_test = None, None
+        if ts_model.pre_test:
+            pre_test = self.hook_mapping.get(ts_model.pre_test)
+            if pre_test is None:
+                msg = (
+                    f"Pre-test hook '{ts_model.pre_test}' not found in hook mapping. "
+                    "A corresponding hook should exist under 'conf/hook'. "
+                    "Ensure that a proper hook directory is set under the working directory."
+                )
+                logging.error(msg)
+                raise TestScenarioParsingError(msg)
+
+        if ts_model.post_test:
+            post_test = self.hook_mapping.get(ts_model.post_test)
+            if post_test is None:
+                msg = (
+                    f"Post-test hook '{ts_model.post_test}' not found in hook mapping. "
+                    "A corresponding hook should exist under 'conf/hook'. "
+                    "Ensure that a proper hook directory is set under the working directory."
+                )
+                logging.error(msg)
+                raise TestScenarioParsingError(msg)
+
         test_runs_by_id: dict[str, TestRun] = {
-            tr.id: self._create_test_run(tr, normalized_weight) for tr in ts_model.tests
+            tr.id: self._create_test_run(tr, normalized_weight, pre_test, post_test) for tr in ts_model.tests
         }
 
         tests_data: dict[str, _TestRunTOML] = {tr.id: tr for tr in ts_model.tests}
@@ -153,13 +179,21 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario:
             job_status_check=ts_model.job_status_check,
         )
 
-    def _create_test_run(self, test_info: _TestRunTOML, normalized_weight: float) -> TestRun:
+    def _create_test_run(
+        self,
+        test_info: _TestRunTOML,
+        normalized_weight: float,
+        pre_test: Optional[TestScenario] = None,
+        post_test: Optional[TestScenario] = None,
+    ) -> TestRun:
         """
         Create a section-specific Test object by copying from the test mapping.
 
         Args:
             test_info (Dict[str, Any]): Information of the test.
             normalized_weight (float): Normalized weight for the test.
+            pre_test (Optional[TestScenario]): TestScenario object representing the pre-test sequence.
+            post_test (Optional[TestScenario]): TestScenario object representing the post-test sequence.
 
         Returns:
             Test: Copied and updated Test object for the section.
@@ -192,5 +226,7 @@ def _create_test_run(self, test_info: _TestRunTOML, normalized_weight: float) ->
             sol=test_info.sol,
             weight=test_info.weight * normalized_weight,
             ideal_perf=test_info.ideal_perf,
+            pre_test=pre_test,
+            post_test=post_test,
         )
         return tr
```

src/cloudai/_core/test_template.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -93,6 +93,40 @@ def gen_exec_command(self, tr: TestRun) -> str:
             )
         return self.command_gen_strategy.gen_exec_command(tr)
 
+    def gen_srun_command(self, tr: TestRun) -> str:
+        """
+        Generate a Slurm srun command for a test using the provided command generation strategy.
+
+        Args:
+            tr (TestRun): Contains the test and its run-specific configurations.
+
+        Returns:
+            str: The generated Slurm srun command.
+        """
+        if self.command_gen_strategy is None:
+            raise ValueError(
+                "command_gen_strategy is missing. Ensure the strategy is registered in the Registry "
+                "by calling the appropriate registration function for the system type."
+            )
+        return self.command_gen_strategy.gen_srun_command(tr)
+
+    def gen_srun_success_check(self, tr: TestRun) -> str:
+        """
+        Generate a Slurm success check command for a test using the provided command generation strategy.
+
+        Args:
+            tr (TestRun): Contains the test and its run-specific configurations.
+
+        Returns:
+            str: The generated command to check the success of the test run.
+        """
+        if self.command_gen_strategy is None:
+            raise ValueError(
+                "command_gen_strategy is missing. Ensure the strategy is registered in the Registry "
+                "by calling the appropriate registration function for the system type."
+            )
+        return self.command_gen_strategy.gen_srun_success_check(tr)
+
     def gen_json(self, tr: TestRun) -> Dict[Any, Any]:
         """
         Generate a JSON string representing the Kubernetes job specification for this test using this template.
```
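These TestTemplate methods only expose per-test srun fragments; composing them into a batch script is the scheduler's job and is not shown in this hunk. A hypothetical sketch of how a Slurm runner might chain a pre-test hook's command and success check before the main test (all command strings below are illustrative):

```python
def build_script(pre_cmds: list[str], pre_checks: list[str], main_cmd: str) -> str:
    # Hypothetical composition: run hook steps, gate the main test on their success.
    lines = []
    for cmd, check in zip(pre_cmds, pre_checks):
        lines.append(cmd)
        lines.append(f"if ! {check}; then echo 'pre-test hook failed' >&2; exit 1; fi")
    lines.append(main_cmd)
    return "\n".join(lines)


script = build_script(
    pre_cmds=["srun --nodes=2 ./all_gather_perf_mpi -b 128 -e 4G"],
    pre_checks=['grep -q "Avg bus bandwidth" stdout.txt'],
    main_cmd="srun --nodes=2 ./all_reduce_perf_mpi -b 128 -e 4G",
)
print(script)
```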

src/cloudai/cli/handlers.py

Lines changed: 39 additions & 4 deletions
```diff
@@ -23,6 +23,8 @@
 
 from cloudai import Installable, Parser, Registry, ReportGenerator, Runner, System
 
+from ..parser import HOOK_ROOT
+
 
 def handle_install_and_uninstall(args: argparse.Namespace) -> int:
     """
@@ -212,7 +214,11 @@ def verify_test_configs(test_tomls: List[Path]) -> int:
 
 
 def verify_test_scenarios(
-    scenario_tomls: List[Path], test_tomls: list[Path], system_config: Optional[Path] = None
+    scenario_tomls: List[Path],
+    test_tomls: list[Path],
+    hook_tomls: List[Path],
+    hook_test_tomls: list[Path],
+    system_config: Optional[Path] = None,
 ) -> int:
     system = Mock(spec=System)
     if system_config:
@@ -225,7 +231,9 @@
         logging.debug(f"Verifying Test Scenario: {scenario_file}...")
         try:
             tests = Parser.parse_tests(test_tomls, system)
-            Parser.parse_test_scenario(scenario_file, {t.name: t for t in tests})
+            hook_tests = Parser.parse_tests(hook_test_tomls, system)
+            hooks = Parser.parse_hooks(hook_tomls, {t.name: t for t in hook_tests})
+            Parser.parse_test_scenario(scenario_file, {t.name: t for t in tests}, hooks)
         except Exception:
             nfailed += 1
 
@@ -243,6 +251,9 @@ def handle_verify_all_configs(args: argparse.Namespace) -> int:
     if err:
         return err
 
+    err, hook_tomls = expand_file_list(HOOK_ROOT, glob="**/*.toml")
+    tomls += hook_tomls
+
     files = load_tomls_by_type(tomls)
 
     test_tomls = files["test"]
@@ -259,7 +270,9 @@ def handle_verify_all_configs(args: argparse.Namespace) -> int:
     if files["test"]:
         nfailed += verify_test_configs(files["test"])
     if files["scenario"]:
-        nfailed += verify_test_scenarios(files["scenario"], test_tomls, args.system_config)
+        nfailed += verify_test_scenarios(
+            files["scenario"], test_tomls, files["hook"], files["hook_test"], args.system_config
+        )
     if files["unknown"]:
         logging.error(f"Unknown configuration files: {[str(f) for f in files['unknown']]}")
         nfailed += len(files["unknown"])
@@ -273,9 +286,31 @@ def handle_verify_all_configs(args: argparse.Namespace) -> int:
 
 
 def load_tomls_by_type(tomls: List[Path]) -> dict[str, List[Path]]:
-    files: dict[str, List[Path]] = {"system": [], "test": [], "scenario": [], "unknown": []}
+    files: dict[str, List[Path]] = {
+        "system": [],
+        "test": [],
+        "scenario": [],
+        "hook_test": [],
+        "hook": [],
+        "unknown": [],
+    }
     for toml_file in tomls:
         content = toml_file.read_text()
+
+        is_in_hook_root = False
+        try:
+            toml_file.relative_to(HOOK_ROOT)
+            is_in_hook_root = True
+        except ValueError:
+            pass
+
+        if is_in_hook_root:
+            if "test" in toml_file.parts:
+                files["hook_test"].append(toml_file)
+            else:
+                files["hook"].append(toml_file)
+            continue
+
         if "scheduler =" in content:
             files["system"].append(toml_file)
         elif "test_template_name =" in content:
```