3232]
3333
3434
35- def expand_model_specs (model_patterns : list [str ]) -> list [str ]:
35+ def expand_model_specs (model_patterns : list [str ]) -> list [Path ]:
3636 """Expand glob patterns and validate model spec files."""
3737 models = []
3838 for pattern in model_patterns :
3939 expanded = glob .glob (pattern )
4040 if expanded :
4141 # Filter to only include files that exist
42- models .extend ([f for f in expanded if Path (f ).is_file ()])
42+ models .extend ([Path ( f ) for f in expanded if Path (f ).is_file ()])
4343 else :
4444 # If no glob match, check if it's a direct file path
45- if Path (pattern ).is_file ():
46- models .append (pattern )
45+ path = Path (pattern )
46+ if path .is_file ():
47+ models .append (path )
4748
4849 if not models :
4950 raise click .ClickException (
@@ -53,9 +54,9 @@ def expand_model_specs(model_patterns: list[str]) -> list[str]:
5354 return models
5455
5556
56- def validate_baseline_model (models : list [str ]) -> None :
57+ def validate_baseline_model (models : list [Path ]) -> None :
5758 """Validate that one of the model specs is the baseline."""
58- baseline_found = any ( "model_spec_baseline.yaml" in model for model in models )
59+ baseline_found = "model_spec_baseline.yaml" in [ model . name for model in models ]
5960 if not baseline_found :
6061 raise click .ClickException (
6162 "Error: One of the model specs must be 'model_spec_baseline.yaml'."
@@ -244,7 +245,7 @@ def run_single_benchmark(
244245
245246
246247def run_benchmark_loop (
247- model_specs : list [str ],
248+ model_specs : list [Path ],
248249 model_runs : int ,
249250 baseline_model_runs : int ,
250251 output_dir : str = "." ,
@@ -288,12 +289,12 @@ def run_benchmark_loop(
288289 for spec in model_specs :
289290 logger .info (f"Running { spec } ..." )
290291
291- model_spec_name = Path ( spec ) .stem
292+ model_spec_name = spec .stem
292293 spec_specific_results_dir = Path (results_dir ) / model_spec_name
293294 spec_specific_results_dir .mkdir (parents = True , exist_ok = True )
294295
295296 # Determine number of runs
296- if "model_spec_baseline.yaml" in spec :
297+ if spec . name == "model_spec_baseline.yaml" :
297298 num_runs = baseline_model_runs
298299 else :
299300 num_runs = model_runs
@@ -302,7 +303,7 @@ def run_benchmark_loop(
302303 for run in range (1 , num_runs + 1 ):
303304 try :
304305 results = run_single_benchmark (
305- spec , run , num_runs , spec_specific_results_dir , model_spec_name
306+ str ( spec ) , run , num_runs , str ( spec_specific_results_dir ) , model_spec_name
306307 )
307308 result_df = pd .DataFrame ([results ])
308309 result_df .to_csv (results_file , mode = "a" , header = False , index = False )
0 commit comments