Skip to content

Commit d4a8629

Browse files
committed
move expand_model_spec
1 parent fc52936 commit d4a8629

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

src/vivarium_profiling/tools/cli.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
from vivarium_profiling.constants import metadata, paths
66
from vivarium_profiling.tools import build_artifacts, configure_logging_to_terminal
7-
from vivarium_profiling.tools.run_benchmark import (
8-
expand_model_specs,
9-
run_benchmark_loop,
10-
)
7+
from vivarium_profiling.tools.run_benchmark import run_benchmark_loop
8+
import glob
9+
from pathlib import Path
1110

1211

1312
@click.command()
@@ -117,8 +116,30 @@ def run_benchmark(
117116
run_benchmark -m "model_spec_baseline.yaml" -m "model_spec_*.yaml" -r 10 -b 20
118117
"""
119118
# Expand model patterns
120-
model_specifications = expand_model_specs(list(model_specifications))
119+
model_specifications = _expand_model_specs(list(model_specifications))
121120

122121
# Run benchmarks with error handling
123122
main = handle_exceptions(run_benchmark_loop, logger, with_debugger=with_debugger)
124123
main(model_specifications, model_runs, baseline_model_runs, output_dir, verbose)
124+
125+
126+
def _expand_model_specs(model_patterns: list[str]) -> list[Path]:
127+
"""Expand glob patterns and validate model spec files."""
128+
models = []
129+
for pattern in model_patterns:
130+
expanded = glob.glob(pattern)
131+
if expanded:
132+
# Filter to only include files that exist
133+
models.extend([Path(f) for f in expanded if Path(f).is_file()])
134+
else:
135+
# If no glob match, check if it's a direct file path
136+
path = Path(pattern)
137+
if path.is_file():
138+
models.append(path)
139+
140+
if not models:
141+
raise click.ClickException(
142+
f"No model specification files found for patterns: {model_patterns}"
143+
)
144+
145+
return models

src/vivarium_profiling/tools/run_benchmark.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,6 @@
3232
]
3333

3434

35-
def expand_model_specs(model_patterns: list[str]) -> list[Path]:
36-
"""Expand glob patterns and validate model spec files."""
37-
models = []
38-
for pattern in model_patterns:
39-
expanded = glob.glob(pattern)
40-
if expanded:
41-
# Filter to only include files that exist
42-
models.extend([Path(f) for f in expanded if Path(f).is_file()])
43-
else:
44-
# If no glob match, check if it's a direct file path
45-
path = Path(pattern)
46-
if path.is_file():
47-
models.append(path)
48-
49-
if not models:
50-
raise click.ClickException(
51-
f"No model specification files found for patterns: {model_patterns}"
52-
)
53-
54-
return models
55-
56-
5735
def validate_baseline_model(models: list[Path]) -> None:
5836
"""Validate that one of the model specs is the baseline."""
5937
baseline_found = "model_spec_baseline.yaml" in [model.name for model in models]

0 commit comments

Comments
 (0)