Skip to content

Commit 69cba47

Browse files
Merge pull request #772 from srivatsankrishnan/verify-configs
Model Name/Model size for verify configs
2 parents 507c141 + e3c4f6f commit 69cba47

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

src/cloudai/workloads/megatron_bridge/megatron_bridge.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
from typing import List, Optional, Union, cast
1919

20-
from pydantic import Field, field_validator
20+
from pydantic import Field, ValidationInfo, field_validator
2121

2222
from cloudai.core import DockerImage, GitRepo, Installable, PythonExecutable
2323
from cloudai.models.workload import CmdArgs, TestDefinition
@@ -40,8 +40,8 @@ class MegatronBridgeCmdArgs(CmdArgs):
4040
detach: Optional[bool] = Field(default=None)
4141

4242
# Model/task
43-
model_name: str = Field(default="")
44-
model_size: str = Field(default="")
43+
model_name: str = Field(min_length=1)
44+
model_size: str = Field(min_length=1)
4545
domain: str = Field(default="llm")
4646
task: str = Field(default="pretrain")
4747
compute_dtype: str = Field(default="bf16")
@@ -88,6 +88,14 @@ def validate_hf_token(cls, v: Optional[str]) -> Optional[str]:
8888
raise ValueError("cmd_args.hf_token is required. Please set it to your literal HF token string.")
8989
return token
9090

91+
@field_validator("model_name", "model_size", mode="after")
92+
@classmethod
93+
def validate_model_fields(cls, v: str, info: ValidationInfo) -> str:
94+
s = v.strip()
95+
if not s:
96+
raise ValueError(f"cmd_args.{info.field_name} cannot be empty.")
97+
return s
98+
9199

92100
class MegatronBridgeTestDefinition(TestDefinition):
93101
"""Megatron-Bridge test definition (CloudAI-managed install + Slurm submission via launcher)."""

tests/slurm_command_gen_strategy/test_megatron_bridge_slurm_command_gen_strategy.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,21 @@ def cmd_gen(self, slurm_system: SlurmSystem, test_run: TestRun) -> MegatronBridg
8585

8686
def test_hf_token_empty_is_rejected_by_schema(self) -> None:
8787
with pytest.raises(Exception, match=r"hf_token"):
88-
MegatronBridgeCmdArgs.model_validate({"hf_token": ""})
88+
MegatronBridgeCmdArgs.model_validate({"hf_token": "", "model_name": "qwen3", "model_size": "30b_a3b"})
89+
90+
@pytest.mark.parametrize("field_name", ["model_name", "model_size"])
91+
def test_model_fields_empty_string_rejected(self, field_name: str) -> None:
92+
data = {"hf_token": "dummy_token", "model_name": "qwen3", "model_size": "30b_a3b"}
93+
data[field_name] = ""
94+
with pytest.raises(Exception, match=field_name):
95+
MegatronBridgeCmdArgs.model_validate(data)
96+
97+
@pytest.mark.parametrize("field_name", ["model_name", "model_size"])
98+
def test_model_fields_whitespace_only_rejected(self, field_name: str) -> None:
99+
data = {"hf_token": "dummy_token", "model_name": "qwen3", "model_size": "30b_a3b"}
100+
data[field_name] = " \t "
101+
with pytest.raises(Exception, match=rf"cmd_args\.{field_name} cannot be empty\."):
102+
MegatronBridgeCmdArgs.model_validate(data)
89103

90104
def test_git_repos_can_pin_megatron_bridge_commit(self) -> None:
91105
args = MegatronBridgeCmdArgs(hf_token="dummy_token", model_name="qwen3", model_size="30b_a3b")

0 commit comments

Comments
 (0)