|
4 | 4 |
|
5 | 5 | from vivarium_profiling.constants import metadata, paths |
6 | 6 | from vivarium_profiling.tools import build_artifacts, configure_logging_to_terminal |
7 | | -from vivarium_profiling.tools.run_benchmark import ( |
8 | | - expand_model_specs, |
9 | | - run_benchmark_loop, |
10 | | -) |
| 7 | +from vivarium_profiling.tools.run_benchmark import run_benchmark_loop |
| 8 | +import glob |
| 9 | +from pathlib import Path |
11 | 10 |
|
12 | 11 |
|
13 | 12 | @click.command() |
@@ -117,8 +116,30 @@ def run_benchmark( |
117 | 116 | run_benchmark -m "model_spec_baseline.yaml" -m "model_spec_*.yaml" -r 10 -b 20 |
118 | 117 | """ |
119 | 118 | # Expand model patterns |
120 | | - model_specifications = expand_model_specs(list(model_specifications)) |
| 119 | + model_specifications = _expand_model_specs(list(model_specifications)) |
121 | 120 |
|
122 | 121 | # Run benchmarks with error handling |
123 | 122 | main = handle_exceptions(run_benchmark_loop, logger, with_debugger=with_debugger) |
124 | 123 | main(model_specifications, model_runs, baseline_model_runs, output_dir, verbose) |
| 124 | + |
| 125 | + |
| 126 | +def _expand_model_specs(model_patterns: list[str]) -> list[Path]: |
| 127 | + """Expand glob patterns and validate model spec files.""" |
| 128 | + models = [] |
| 129 | + for pattern in model_patterns: |
| 130 | + expanded = glob.glob(pattern) |
| 131 | + if expanded: |
| 132 | + # Filter to only include files that exist |
| 133 | + models.extend([Path(f) for f in expanded if Path(f).is_file()]) |
| 134 | + else: |
| 135 | + # If no glob match, check if it's a direct file path |
| 136 | + path = Path(pattern) |
| 137 | + if path.is_file(): |
| 138 | + models.append(path) |
| 139 | + |
| 140 | + if not models: |
| 141 | + raise click.ClickException( |
| 142 | + f"No model specification files found for patterns: {model_patterns}" |
| 143 | + ) |
| 144 | + |
| 145 | + return models |
0 commit comments