Commit 17ae909

replace spec strings with Paths
1 parent 8ee2b4a commit 17ae909

File tree

3 files changed: +15, -14 lines changed


src/vivarium_profiling/tools/run_benchmark.py

Lines changed: 11 additions & 10 deletions
@@ -32,18 +32,19 @@
 ]


-def expand_model_specs(model_patterns: list[str]) -> list[str]:
+def expand_model_specs(model_patterns: list[str]) -> list[Path]:
     """Expand glob patterns and validate model spec files."""
     models = []
     for pattern in model_patterns:
         expanded = glob.glob(pattern)
         if expanded:
             # Filter to only include files that exist
-            models.extend([f for f in expanded if Path(f).is_file()])
+            models.extend([Path(f) for f in expanded if Path(f).is_file()])
         else:
             # If no glob match, check if it's a direct file path
-            if Path(pattern).is_file():
-                models.append(pattern)
+            path = Path(pattern)
+            if path.is_file():
+                models.append(path)

     if not models:
         raise click.ClickException(
@@ -53,9 +54,9 @@ def expand_model_specs(model_patterns: list[str]) -> list[str]:
     return models


-def validate_baseline_model(models: list[str]) -> None:
+def validate_baseline_model(models: list[Path]) -> None:
     """Validate that one of the model specs is the baseline."""
-    baseline_found = any("model_spec_baseline.yaml" in model for model in models)
+    baseline_found = "model_spec_baseline.yaml" in [model.name for model in models]
     if not baseline_found:
         raise click.ClickException(
             "Error: One of the model specs must be 'model_spec_baseline.yaml'."
@@ -244,7 +245,7 @@ def run_single_benchmark(


 def run_benchmark_loop(
-    model_specs: list[str],
+    model_specs: list[Path],
     model_runs: int,
     baseline_model_runs: int,
     output_dir: str = ".",
@@ -288,12 +289,12 @@ def run_benchmark_loop(
     for spec in model_specs:
         logger.info(f"Running {spec}...")

-        model_spec_name = Path(spec).stem
+        model_spec_name = spec.stem
         spec_specific_results_dir = Path(results_dir) / model_spec_name
         spec_specific_results_dir.mkdir(parents=True, exist_ok=True)

         # Determine number of runs
-        if "model_spec_baseline.yaml" in spec:
+        if spec.name == "model_spec_baseline.yaml":
             num_runs = baseline_model_runs
         else:
             num_runs = model_runs
@@ -302,7 +303,7 @@ def run_benchmark_loop(
         for run in range(1, num_runs + 1):
             try:
                 results = run_single_benchmark(
-                    spec, run, num_runs, spec_specific_results_dir, model_spec_name
+                    str(spec), run, num_runs, str(spec_specific_results_dir), model_spec_name
                 )
                 result_df = pd.DataFrame([results])
                 result_df.to_csv(results_file, mode="a", header=False, index=False)
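
A minimal usage sketch of the updated helpers after this change, assuming the module is importable as vivarium_profiling.tools.run_benchmark; the glob pattern is hypothetical, not taken from the repository:

from pathlib import Path

from vivarium_profiling.tools.run_benchmark import expand_model_specs, validate_baseline_model

# Hypothetical glob pattern; a direct file path is handled the same way.
specs: list[Path] = expand_model_specs(["model_specs/model_spec_*.yaml"])

# Raises click.ClickException unless one of the specs is model_spec_baseline.yaml.
validate_baseline_model(specs)

for spec in specs:
    # Callers now receive Path objects, so .name and .stem are available directly.
    print(spec.name, "->", spec.stem)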

tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ def pytest_collection_modifyitems(config, items):


 @pytest.fixture
-def test_model_specs(tmp_path):
+def test_model_specs(tmp_path) -> list[Path]:
     """Create minimal test model specification files in a temporary directory."""

     # Baseline model specification
@@ -58,4 +58,4 @@ def test_model_specs(tmp_path):
     baseline_file.write_text(baseline_spec)
     other_spec_file.write_text(other_spec)

-    return [str(baseline_file), str(other_spec_file)]
+    return [baseline_file, other_spec_file]
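
A sketch of how a test can consume the fixture now that it yields Path objects rather than strings; the test name and assertions are illustrative, and the baseline filename check assumes the fixture's baseline file is named model_spec_baseline.yaml (as validate_baseline_model requires):

from pathlib import Path


def test_model_specs_are_paths(test_model_specs: list[Path]):
    # The fixture now yields pathlib.Path objects, so name/suffix checks
    # no longer need to wrap each entry in Path(...).
    assert all(isinstance(spec, Path) for spec in test_model_specs)
    assert all(spec.suffix == ".yaml" for spec in test_model_specs)
    # Assumption: the fixture's baseline file is named model_spec_baseline.yaml.
    assert any(spec.name == "model_spec_baseline.yaml" for spec in test_model_specs)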

tests/test_run_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@


 @pytest.mark.slow
-def test_run_benchmark_loop_integration(test_model_specs: list[str], tmp_path: Path):
+def test_run_benchmark_loop_integration(test_model_specs: list[Path], tmp_path: Path):
     """Integration test for run_benchmark_loop with minimal real model specs.

     This test verifies that:
@@ -86,7 +86,7 @@ def test_run_benchmark_loop_integration(test_model_specs: list[Path], tmp_path: P
         ), f"Model spec directory {spec_path} should contain results"


-def test_run_benchmark_loop_validation_error(test_model_specs, tmp_path):
+def test_run_benchmark_loop_validation_error(test_model_specs: list[Path], tmp_path: Path):
     """Test that benchmark fails appropriately when baseline model is missing."""
     output_dir = str(tmp_path / "validation_test")
     Path(output_dir).mkdir(parents=True, exist_ok=True)
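
For the missing-baseline case, a hedged sketch of how the condition can be exercised directly against validate_baseline_model with the Path-based fixture; the test shown is illustrative and assumes validate_baseline_model raises click.ClickException when no baseline spec is present, as the diff above indicates:

from pathlib import Path

import click
import pytest

from vivarium_profiling.tools.run_benchmark import validate_baseline_model


def test_validate_baseline_model_missing(test_model_specs: list[Path]):
    # Drop the baseline spec and confirm validation fails on the remainder.
    non_baseline = [s for s in test_model_specs if s.name != "model_spec_baseline.yaml"]
    with pytest.raises(click.ClickException):
        validate_baseline_model(non_baseline)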
