Skip to content

Commit 7d696c1

Browse files
authored
Merge pull request #521 from NVIDIA/am/fix-construction-megatron
Update MegatronRun model dump logic
2 parents 386cdb3 + 13e1108 commit 7d696c1

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

src/cloudai/workloads/megatron_run/megatron_run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
class MegatronRunCmdArgs(CmdArgs):
2929
"""MegatronRun test command arguments."""
3030

31-
docker_image_url: str = Field(exclude=True)
32-
run_script: Path = Field(exclude=True)
31+
docker_image_url: str = Field()
32+
run_script: Path = Field()
3333

3434
global_batch_size: Optional[int] = 16
3535
hidden_size: Optional[int] = 4096
@@ -65,7 +65,7 @@ def no_dashed_args(self):
6565

6666
@property
6767
def cmd_args(self) -> dict[str, Union[str, list[str]]]:
68-
args = self.model_dump(exclude_none=True)
68+
args = self.model_dump(exclude_none=True, exclude={"docker_image_url", "run_script"})
6969
args = {f"--{k.replace('_', '-')}": v for k, v in args.items()}
7070
return args
7171

tests/test_test_definitions.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,6 @@ def test_tokenizer_model(self, megatron_run: MegatronRunTestDefinition):
291291
assert megatron_run.cmd_args_dict["--tokenizer-model"] == Path("/path/to/tokenizer")
292292
assert megatron_run.cmd_args.tokenizer_model == Path("/path/to/tokenizer")
293293

294-
def test_auxiliary_fields_not_in_model_dump(self, megatron_run: MegatronRunTestDefinition):
295-
d = megatron_run.cmd_args.model_dump()
296-
assert "docker_image_url" not in d
297-
assert "run_script" not in d
298-
299294
@pytest.mark.parametrize("field", ["load", "save"])
300295
def test_load_is_set_but_not_mounted(self, field: str):
301296
with pytest.raises(ValueError) as exc_info:

tests/test_test_scenario.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def test_spec_can_set_unknown_args_no_base(self, test_scenario_parser: TestScena
412412
)
413413
_, tdef = test_scenario_parser._prepare_tdef(model.tests[0])
414414
assert tdef.cmd_args_dict["unknown"] == 42
415+
assert isinstance(tdef.cmd_args, NCCLCmdArgs)
415416

416417

417418
class TestReporters:

0 commit comments

Comments
 (0)