Skip to content

Commit 215c010

Browse files
committed
Fix how test-in-scenario is merge with test-in-toml
1 parent 0e26c88 commit 215c010

File tree

3 files changed

+54
-8
lines changed

3 files changed

+54
-8
lines changed

src/cloudai/_core/test_scenario_parser.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,30 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]:
275275
if test_info.test_name not in self.test_mapping:
276276
raise ValueError(f"Test '{test_info.test_name}' is not defined. Was tests directory correctly set?")
277277
test = self.test_mapping[test_info.test_name]
278-
elif test_info.test_template_name:
278+
279+
test_defined = test.test_definition.model_dump()
280+
tc_defined = test_info.tdef_model_dump()
281+
merged_data = deep_merge(test_defined, tc_defined)
282+
test.test_definition = tp.load_test_definition(merged_data, self.strict)
283+
elif test_info.test_template_name: # test fully defined in the scenario
279284
test = tp._parse_data(test_info.tdef_model_dump(), self.strict)
280285
else:
281286
# this should never happen, because we check for this in the modelvalidator
282287
raise ValueError(
283288
f"Cannot configure test case '{test_info.id}' with both 'test_name' and 'test_template_name'."
284289
)
285290

286-
merged_data = test.test_definition.model_dump()
287-
merged_data.update(test_info.tdef_model_dump())
288-
tdef = tp.load_test_definition(merged_data, self.strict)
291+
return test, test.test_definition
292+
289293

290-
return test, tdef
294+
def deep_merge(a: dict, b: dict):
295+
result = a.copy()
296+
for key in b:
297+
if key in result:
298+
if isinstance(result[key], dict) and isinstance(b[key], dict):
299+
result[key] = deep_merge(result[key], b[key])
300+
else:
301+
result[key] = b[key]
302+
else:
303+
result[key] = b[key]
304+
return result

src/cloudai/models/scenario.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def tdef_model_dump(self) -> dict:
7575
"agent_steps": self.agent_steps,
7676
"agent_metric": self.agent_metric,
7777
"extra_container_mounts": self.extra_container_mounts,
78-
"cmd_args": self.cmd_args.model_dump() if self.cmd_args else None,
7978
"extra_env_vars": self.extra_env_vars if self.extra_env_vars else None,
79+
"cmd_args": self.cmd_args.model_dump() if self.cmd_args else None,
8080
"git_repos": [repo.model_dump() for repo in self.git_repos] if self.git_repos else None,
8181
"nsys": self.nsys.model_dump() if self.nsys else None,
8282
}

tests/test_test_scenario.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
JaxToolboxReportGenerationStrategy,
4747
NemotronTestDefinition,
4848
)
49-
from cloudai.workloads.megatron_run import CheckpointTimingReportGenerationStrategy, MegatronRunTestDefinition
49+
from cloudai.workloads.megatron_run import (
50+
CheckpointTimingReportGenerationStrategy,
51+
MegatronRunCmdArgs,
52+
MegatronRunTestDefinition,
53+
)
5054
from cloudai.workloads.nccl_test import (
5155
NCCLCmdArgs,
5256
NCCLTestDefinition,
@@ -280,7 +284,7 @@ def test_total_time_limit_with_empty_hooks():
280284
assert result == "01:00:00"
281285

282286

283-
class TestSpec:
287+
class TestInScenario:
284288
@pytest.mark.parametrize("missing_arg", ["test_template_name", "name", "description"])
285289
def test_without_base(self, missing_arg: str):
286290
spec = {
@@ -414,6 +418,34 @@ def test_spec_can_set_unknown_args_no_base(self, test_scenario_parser: TestScena
414418
assert tdef.cmd_args_dict["unknown"] == 42
415419
assert isinstance(tdef.cmd_args, NCCLCmdArgs)
416420

421+
def test_data_is_merge_correctly(self, test_scenario_parser: TestScenarioParser, slurm_system: SlurmSystem):
422+
test_scenario_parser.test_mapping = {
423+
"megatron": Test(
424+
test_definition=MegatronRunTestDefinition(
425+
name="megatron",
426+
description="desc",
427+
test_template_name="MegatronRun",
428+
cmd_args=MegatronRunCmdArgs(docker_image_url="docker://megatron", run_script=Path("run.sh")),
429+
),
430+
test_template=TestTemplate(system=slurm_system),
431+
)
432+
}
433+
model = TestScenarioModel.model_validate(
434+
toml.loads(
435+
"""
436+
name = "test"
437+
438+
[[Tests]]
439+
id = "1"
440+
test_name = "megatron"
441+
cmd_args = { any = 42 }
442+
"""
443+
)
444+
)
445+
_, tdef = test_scenario_parser._prepare_tdef(model.tests[0])
446+
assert isinstance(tdef.cmd_args, MegatronRunCmdArgs)
447+
assert tdef.cmd_args.run_script == Path("run.sh")
448+
417449

418450
class TestReporters:
419451
def test_default(self):

0 commit comments

Comments
 (0)