Commit bc3b0bb

Adjust how model specs are passed in

1 parent 2ab644a commit bc3b0bb

File tree: 4 files changed, +45 −52 lines

README.rst

Lines changed: 1 addition & 1 deletion

@@ -178,7 +178,7 @@ depending on the profiler backend provided. By default, runtime profiling is per
 you can also use ``scalene`` for more detailed call stack analysis.

 The ``run_benchmark`` command runs multiple iterations of one or more model specifications, in order to compare
-the results. It requires at least one baseline model (specified as ``model_spec_baseline.yaml``) for comparison,
+the results. It requires at least one baseline model for comparison,
 and any number of 'experiment' models to benchmark against the baseline, which can be passed via glob patterns.
 You can separately configure the sample size of runs for the baseline and experiment models. The command aggregates
 the profiling results and generates summary statistics and visualizations for a default set of important function calls
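Concretely, under the revised interface the baseline spec is a positional argument and experiment specs are passed with -a. The first invocation below is taken from the updated CLI docstring; the glob form in the second is illustrative, based on the option's help text:

    run_benchmark model_spec_baseline.yaml -b 20
    run_benchmark model_spec_baseline.yaml -a "model_spec_*x.yaml" -r 10 -b 20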

src/vivarium_profiling/tools/cli.py

Lines changed: 41 additions & 15 deletions
@@ -200,12 +200,16 @@ def make_artifacts(


 @click.command()
+@click.argument(
+    "model_specification",
+    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
+)
 @click.option(
-    "-m",
-    "--model_specifications",
+    "-a",
+    "--additional-model-specifications",
     multiple=True,
-    required=True,
-    help="Model specification files (supports glob patterns). Can be specified multiple times.",
+    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
+    help="Additional model specification files (supports glob patterns). Can be specified multiple times.",
 )
 @click.option(
     "-r",
@@ -217,8 +221,9 @@ def make_artifacts(
 @click.option(
     "-b",
     "--baseline-model-runs",
+    default=3,
+    show_default=True,
     type=int,
-    required=True,
     help="Number of runs for baseline model.",
 )
 @click.option(
@@ -242,7 +247,8 @@ def make_artifacts(
     help="Drop into python debugger if an error occurs.",
 )
 def run_benchmark(
-    model_specifications: tuple[str, ...],
+    model_specification: str,
+    additional_model_specifications: tuple[str, ...],
     model_runs: int,
     baseline_model_runs: int,
     output_dir: str,
@@ -255,11 +261,21 @@ def run_benchmark(
     This command profiles multiple model specifications and collects runtime
     and memory usage statistics. Results are saved to a timestamped CSV file.

+    The baseline model specification is provided as a positional argument.
+    Additional model specifications can be provided with -a.
+
     Example usage:
-        run_benchmark -m "model_spec_baseline.yaml" -m "model_spec_*.yaml" -r 10 -b 20
+        run_benchmark model_spec_baseline.yaml -b 20
+
+        run_benchmark model_spec_baseline.yaml -a model_spec_2x.yaml -a model_spec_4x.yaml -r 10 -b 20
     """
-    # Expand model patterns
-    model_specifications = _expand_model_specs(list(model_specifications))
+    configure_logging_to_terminal(verbose)
+
+    baseline_path = Path(model_specification)
+
+    # Expand additional model specs (supporting glob patterns)
+    additional_paths = _expand_model_specs(list(additional_model_specifications))
+    model_specifications = [str(baseline_path)] + [str(p) for p in additional_paths]

     # Run benchmarks with error handling
     main = handle_exceptions(run_benchmark_loop, logger, with_debugger=with_debugger)
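The move from a required -m option to a positional argument plus a repeatable -a option is a standard click pattern. A minimal, self-contained sketch of how it behaves, assuming nothing beyond click itself (the command name "demo" is illustrative, not the project's code):

import click


@click.command()
@click.argument("model_specification", type=click.Path(exists=True, dir_okay=False))
@click.option(
    "-a",
    "--additional-model-specifications",
    multiple=True,  # repeated -a flags are collected into a tuple
)
def demo(model_specification: str, additional_model_specifications: tuple[str, ...]) -> None:
    # The positional argument is always required; the tuple may be empty,
    # so a lone baseline spec is a valid invocation.
    click.echo(f"baseline: {model_specification}")
    for extra in additional_model_specifications:
        click.echo(f"additional: {extra}")


if __name__ == "__main__":
    demo()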
@@ -274,7 +290,22 @@ def run_benchmark(


 def _expand_model_specs(model_patterns: list[str]) -> list[Path]:
-    """Expand glob patterns and validate model spec files."""
+    """Expand glob patterns and validate model spec files.
+
+    Parameters
+    ----------
+    model_patterns
+        List of file paths or glob patterns.
+
+    Returns
+    -------
+    List of resolved Path objects for existing files. Returns empty list
+    if no patterns provided.
+
+    """
+    if not model_patterns:
+        return []
+
     models = []
     for pattern in model_patterns:
         expanded = glob.glob(pattern)
@@ -287,11 +318,6 @@ def _expand_model_specs(model_patterns: list[str]) -> list[Path]:
         if path.is_file():
             models.append(path)

-    if not models:
-        raise click.ClickException(
-            f"No model specification files found for patterns: {model_patterns}"
-        )
-
     return models

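With the early return added above, an empty -a list is no longer an error; _expand_model_specs simply yields no extra specs. A rough standalone sketch of the expansion behavior, assuming the same glob-plus-pathlib approach (the function name expand_patterns is hypothetical):

import glob
from pathlib import Path


def expand_patterns(patterns: list[str]) -> list[Path]:
    # Expand each glob pattern and keep only paths that are existing files.
    if not patterns:
        return []  # empty input: nothing to benchmark beyond the baseline
    found: list[Path] = []
    for pattern in patterns:
        for match in glob.glob(pattern):
            path = Path(match).resolve()
            if path.is_file():
                found.append(path)
    return found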

src/vivarium_profiling/tools/run_benchmark.py

Lines changed: 3 additions & 15 deletions
@@ -22,15 +22,6 @@
 RESULTS_SUMMARY_NAME = "benchmark_results.csv"


-def validate_baseline_model(models: list[Path]) -> None:
-    """Validate that one of the model specs is the baseline."""
-    baseline_found = "model_spec_baseline.yaml" in [model.name for model in models]
-    if not baseline_found:
-        raise click.ClickException(
-            "Error: One of the model specs must be 'model_spec_baseline.yaml'."
-        )
-
-
 def create_results_directory(output_dir: str = ".") -> str:
     """Create a timestamped results directory."""
     timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
@@ -180,9 +171,6 @@ def run_benchmark_loop(

     configure_logging_to_terminal(verbose)

-    # Validate inputs
-    validate_baseline_model(model_specifications)
-
     # Create results directory and initialize results file
     results_dir = create_results_directory(output_dir)
     results_file = initialize_results_file(results_dir, config)
@@ -193,15 +181,15 @@ def run_benchmark_loop(
     logger.info(f" Results Directory: {results_dir}")

     # Run benchmarks for each specification
-    for spec in model_specifications:
+    for i, spec in enumerate(model_specifications):
         logger.info(f"Running {spec}...")

         model_spec_name = spec.stem
         spec_specific_results_dir = Path(results_dir) / model_spec_name
         spec_specific_results_dir.mkdir(parents=True, exist_ok=True)

-        # Determine number of runs
-        if spec.name == "model_spec_baseline.yaml":
+        # Determine number of runs - first spec is baseline
+        if i == 0:
             num_runs = baseline_model_runs
         else:
             num_runs = model_runs
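The baseline is now identified by position (the CLI places the positional spec at index 0) rather than by the filename model_spec_baseline.yaml. A toy illustration of the run-count selection; the file names and counts here are made up:

# Illustrative only: the first spec is treated as the baseline.
specs = ["my_baseline.yaml", "model_spec_2x.yaml", "model_spec_4x.yaml"]
baseline_model_runs, model_runs = 20, 10  # hypothetical -b / -r values

for i, spec in enumerate(specs):
    num_runs = baseline_model_runs if i == 0 else model_runs
    print(f"{spec}: {num_runs} run(s)")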

tests/test_run_benchmark.py

Lines changed: 0 additions & 21 deletions
@@ -104,24 +104,3 @@ def test_run_benchmark_loop_integration(test_model_specs: list[Path], tmp_path:
     assert (
         len(spec_contents) > 0
     ), f"Model spec directory {spec_path} should contain results"
-
-
-def test_run_benchmark_loop_validation_error(test_model_specs: list[Path], tmp_path: Path):
-    """Test that benchmark fails appropriately when baseline model is missing."""
-    output_dir = str(tmp_path / "validation_test")
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
-
-    # Try to run without baseline model - should raise exception
-    model_specs = test_model_specs[1:]
-
-    with pytest.raises(
-        click.ClickException,
-        match="Error: One of the model specs must be 'model_spec_baseline.yaml'.",
-    ):  # Should raise ClickException about missing baseline
-        run_benchmark_loop(
-            model_specifications=model_specs,
-            model_runs=2,
-            baseline_model_runs=2,
-            output_dir=output_dir,
-            verbose=0,
-        )
