Skip to content

Commit 570181a

Browse files
authored
Merge pull request #141 from srivatsankrishnan/main
Hierarchical Test template for support Grok/GPT via PAXML
2 parents 383d512 + d18176b commit 570181a

File tree

3 files changed

+311
-55
lines changed

3 files changed

+311
-55
lines changed

src/cloudai/_core/test_parser.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616

1717
from pathlib import Path
18-
from typing import Any, Dict, List, Set
18+
from typing import Any, Dict, List, Optional, Set
1919

2020
from .base_multi_file_parser import BaseMultiFileParser
2121
from .test import Test
@@ -47,6 +47,16 @@ def __init__(
4747
super().__init__(directory_path)
4848
self.test_template_mapping: Dict[str, TestTemplate] = test_template_mapping
4949

50+
def _extract_name_keyword(self, name: Optional[str]) -> Optional[str]:
51+
if name is None:
52+
return None
53+
lower_name = name.lower()
54+
if "grok" in lower_name:
55+
return "Grok"
56+
elif "gpt" in lower_name:
57+
return "GPT"
58+
return None
59+
5060
def _parse_data(self, data: Dict[str, Any]) -> Test:
5161
"""
5262
Parse data for a Test object.
@@ -57,6 +67,7 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
5767
Returns:
5868
Test: Parsed Test object.
5969
"""
70+
test_name = self._extract_name_keyword(data.get("name"))
6071
test_template_name = data.get("test_template_name", "")
6172
test_template = self.test_template_mapping.get(test_template_name)
6273

@@ -76,10 +87,13 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
7687
extra_cmd_args = data.get("extra_cmd_args", "")
7788

7889
flattened_template_cmd_args = self._flatten_template_dict_keys(test_template.cmd_args)
79-
self._validate_args(cmd_args, flattened_template_cmd_args)
90+
91+
# Ensure test_name is not None by providing a default value if necessary
92+
test_name_str = test_name if test_name is not None else ""
93+
self._validate_args(cmd_args, flattened_template_cmd_args, test_name_str)
8094

8195
flattened_template_env_vars = self._flatten_template_dict_keys(test_template.env_vars)
82-
self._validate_args(env_vars, flattened_template_env_vars)
96+
self._validate_args(env_vars, flattened_template_env_vars, test_name_str)
8397

8498
return Test(
8599
name=data.get("name", ""),
@@ -133,17 +147,32 @@ def _flatten_template_dict_keys(self, nested_args: Dict[str, Any], parent_key: s
133147

134148
return keys
135149

136-
def _validate_args(self, args: Dict[str, Any], valid_keys: Set[str]) -> None:
150+
def _validate_args(self, args: Dict[str, Any], valid_keys: Set[str], test_name: str) -> None:
137151
"""
138152
Validate the provided arguments against a set of valid keys.
139153
140154
Args:
141155
args (Dict[str, Any]): Arguments provided in the TOML configuration.
142156
valid_keys (Set[str]): Set of valid keys from the flattened template arguments.
157+
test_name (str): The name of the test for which arguments are being validated.
143158
144159
Raises:
145160
ValueError: If an argument is not defined in the TestTemplate's arguments.
146161
"""
147162
for arg_key in args:
148-
if arg_key not in valid_keys:
149-
raise ValueError(f"Argument '{arg_key}' is not defined in the TestTemplate's arguments.")
163+
# Check if the arg_key directly exists in valid_keys
164+
if arg_key in valid_keys:
165+
continue
166+
167+
# Check if arg_key with test_name prefix exists in valid_keys
168+
test_specific_key = f"{test_name}.{arg_key}"
169+
if test_specific_key in valid_keys:
170+
continue
171+
172+
# Check if arg_key with 'common' prefix exists in valid_keys
173+
common_key = f"common.{arg_key}"
174+
if common_key in valid_keys:
175+
continue
176+
177+
# If none of the conditions above are met, the arg_key is invalid
178+
raise ValueError(f"Argument '{arg_key}' is not defined in the TestTemplate's arguments.")

0 commit comments

Comments
 (0)