1515# limitations under the License.
1616
1717from pathlib import Path
18- from typing import Any , Dict , List , Set
18+ from typing import Any , Dict , List , Optional , Set
1919
2020from .base_multi_file_parser import BaseMultiFileParser
2121from .test import Test
@@ -47,6 +47,16 @@ def __init__(
4747 super ().__init__ (directory_path )
4848 self .test_template_mapping : Dict [str , TestTemplate ] = test_template_mapping
4949
50+ def _extract_name_keyword (self , name : Optional [str ]) -> Optional [str ]:
51+ if name is None :
52+ return None
53+ lower_name = name .lower ()
54+ if "grok" in lower_name :
55+ return "Grok"
56+ elif "gpt" in lower_name :
57+ return "GPT"
58+ return None
59+
5060 def _parse_data (self , data : Dict [str , Any ]) -> Test :
5161 """
5262 Parse data for a Test object.
@@ -57,6 +67,7 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
5767 Returns:
5868 Test: Parsed Test object.
5969 """
70+ test_name = self ._extract_name_keyword (data .get ("name" ))
6071 test_template_name = data .get ("test_template_name" , "" )
6172 test_template = self .test_template_mapping .get (test_template_name )
6273
@@ -76,10 +87,13 @@ def _parse_data(self, data: Dict[str, Any]) -> Test:
7687 extra_cmd_args = data .get ("extra_cmd_args" , "" )
7788
7889 flattened_template_cmd_args = self ._flatten_template_dict_keys (test_template .cmd_args )
79- self ._validate_args (cmd_args , flattened_template_cmd_args )
90+
91+ # Ensure test_name is not None by providing a default value if necessary
92+ test_name_str = test_name if test_name is not None else ""
93+ self ._validate_args (cmd_args , flattened_template_cmd_args , test_name_str )
8094
8195 flattened_template_env_vars = self ._flatten_template_dict_keys (test_template .env_vars )
82- self ._validate_args (env_vars , flattened_template_env_vars )
96+ self ._validate_args (env_vars , flattened_template_env_vars , test_name_str )
8397
8498 return Test (
8599 name = data .get ("name" , "" ),
@@ -133,17 +147,32 @@ def _flatten_template_dict_keys(self, nested_args: Dict[str, Any], parent_key: s
133147
134148 return keys
135149
136- def _validate_args (self , args : Dict [str , Any ], valid_keys : Set [str ]) -> None :
150+ def _validate_args (self , args : Dict [str , Any ], valid_keys : Set [str ], test_name : str ) -> None :
137151 """
138152 Validate the provided arguments against a set of valid keys.
139153
140154 Args:
141155 args (Dict[str, Any]): Arguments provided in the TOML configuration.
142156 valid_keys (Set[str]): Set of valid keys from the flattened template arguments.
157+ test_name (str): The name of the test for which arguments are being validated.
143158
144159 Raises:
145160 ValueError: If an argument is not defined in the TestTemplate's arguments.
146161 """
147162 for arg_key in args :
148- if arg_key not in valid_keys :
149- raise ValueError (f"Argument '{ arg_key } ' is not defined in the TestTemplate's arguments." )
163+ # Check if the arg_key directly exists in valid_keys
164+ if arg_key in valid_keys :
165+ continue
166+
167+ # Check if arg_key with test_name prefix exists in valid_keys
168+ test_specific_key = f"{ test_name } .{ arg_key } "
169+ if test_specific_key in valid_keys :
170+ continue
171+
172+ # Check if arg_key with 'common' prefix exists in valid_keys
173+ common_key = f"common.{ arg_key } "
174+ if common_key in valid_keys :
175+ continue
176+
177+ # If none of the conditions above are met, the arg_key is invalid
178+ raise ValueError (f"Argument '{ arg_key } ' is not defined in the TestTemplate's arguments." )
0 commit comments