Skip to content

Commit dae6dcf

Browse files
committed
Add benchmarking script
1 parent b35fc76 commit dae6dcf

2 files changed

Lines changed: 353 additions & 0 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ htmlcov
2626

2727
# Local data and scratch
2828
.scratch
29+
benchmarks/
2930

3031
# Local virtual environment
3132
.venv

scripts/benchmark.py

Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
# /// script
2+
# requires-python = ">=3.13"
3+
# dependencies = []
4+
# ///
5+
"""Perform benchmarking of bids2table against last tag, main and feature branches.
6+
7+
Run with:
8+
uv run --with <repo> scripts/benchmark.py -b <feature_branch> [-o <output_dir>]
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import argparse
14+
import json
15+
import logging
16+
import statistics
17+
import subprocess
18+
import sys
19+
from contextlib import contextmanager
20+
from datetime import datetime, timezone
21+
from pathlib import Path
22+
from typing import Literal, NamedTuple
23+
24+
import pytest
25+
26+
# Configure the root logger so INFO-level messages from this script are emitted.
logging.basicConfig(level=logging.INFO)
# Logger shared by every function in this script.
_logger = logging.getLogger("bids2table.benchmark")
28+
29+
30+
@contextmanager
def _suppress_log_exceptions():
    """Temporarily set ``logging.raiseExceptions`` to ``False``.

    While the ``with`` block runs, errors raised inside logging handlers are
    ignored by the logging module instead of being reported on stderr.

    Yields:
        None. Control is handed to the body of the ``with`` statement.
    """
    # Remember the caller's setting rather than assuming it was ``True`` so that
    # nested use (or a flag that was already disabled) is left untouched on exit.
    previous = logging.raiseExceptions
    logging.raiseExceptions = False
    try:
        yield
    finally:
        logging.raiseExceptions = previous
37+
38+
39+
def _reset_logger():
    """Drop this script's log handlers and rebuild root logging on stderr.

    Every handler attached directly to ``_logger`` is removed and closed, then
    the root logger is reconfigured with a fresh INFO-level handler that writes
    to the current ``sys.stderr``.
    """
    for handler in _logger.handlers[:]:
        _logger.removeHandler(handler)
        handler.close()
    # ``basicConfig`` does nothing when the root logger already has handlers, which
    # is always the case here because the module configures logging at import time.
    # ``force=True`` makes it remove the existing handlers and install a new one.
    logging.basicConfig(stream=sys.stderr, level=logging.INFO, force=True)
44+
45+
46+
class Git:
    """Class to simplify git calls via subprocess.

    Intended to be used as a context manager: entering requires a clean working
    tree, then pulls and updates submodules; exiting checks out the reference
    that was current when the object was created.
    """

    def __init__(self):
        """Locate the repository root and record the currently checked-out ref."""
        self.repo_path = self._root()
        self._head_ref = self._run("rev-parse", "--abbrev-ref", "HEAD")
        if self._head_ref == "HEAD":
            # Detached HEAD: the abbreviated name is the literal "HEAD", which
            # cannot be used to return here later, so remember the commit instead.
            self._head_ref = self._run("rev-parse", "HEAD")

    def __enter__(self):
        """Verify the tree is clean, pull, and update submodules.

        Exits the process with status 1 when ``git status --porcelain`` reports
        any change.

        Returns:
            This ``Git`` instance.
        """
        if self._run("status", "--porcelain"):
            _logger.error("Please stash or commit changes before benchmarking.")
            sys.exit(1)
        self.pull()
        self.submodule_update()
        return self

    def __exit__(self, *_):
        """On context closure, checkout the HEAD ref."""
        self.checkout(self._head_ref)

    @staticmethod
    def _root() -> Path:
        """Return the top-level directory of the repository containing the cwd.

        Logs git's error output and exits with git's return code when the
        command fails (for example when run outside a repository), instead of
        silently returning an empty path.
        """
        result = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"], capture_output=True, text=True
        )
        if result.returncode != 0:
            _logger.error(result.stderr.strip())
            sys.exit(result.returncode)
        return Path(result.stdout.strip())

    def _run(self, *args: str) -> str:
        """Run ``git -C <repo_path> <args>`` and return its stripped stdout.

        On a non-zero return code the stripped stderr is logged and the process
        exits with that return code.
        """
        result = subprocess.run(
            ["git", "-C", str(self.repo_path), *args], capture_output=True, text=True
        )
        if result.returncode != 0:
            _logger.error(result.stderr.strip())
            sys.exit(result.returncode)
        return result.stdout.strip()

    def checkout(self, ref: str) -> None:
        """Checkout reference.

        Args:
            ref: Reference to checkout (e.g. branch, SHA, tag)
        """
        self._run("checkout", ref)

    def pull(self) -> None:
        """Pull from the remote repository."""
        self._run("pull")

    def submodule_update(self) -> None:
        """Update submodules of the repo, initializing if necessary."""
        self._run("submodule", "update", "--init", "--recursive")

    def last_tag(self) -> str:
        """Get last tag.

        Returns:
            A string value of the last tag
        """
        return self._run("describe", "--tags", "--abbrev=0")
105+
106+
107+
class BenchmarkResult(NamedTuple):
108+
fullname: str
109+
kind: Literal["index", "query"]
110+
locality: Literal["local", "remote"] | None = None
111+
workers: int = 1
112+
median: float = 0.0
113+
mean: float = 0.0
114+
stddev: float = 0.0
115+
116+
117+
def parse_file(path: Path) -> dict[str, BenchmarkResult]:
    """Read a benchmark JSON file and summarise every benchmark in it.

    Args:
        path: JSON file with a top-level ``"benchmarks"`` list whose entries
            provide ``"fullname"``, ``"stats"["data"]`` and ``"extra_info"``.

    Returns:
        Mapping of benchmark full name to its ``BenchmarkResult``.

    Raises:
        statistics.StatisticsError: if a benchmark has no recorded samples.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    results: dict[str, BenchmarkResult] = {}
    for benchmark in data["benchmarks"]:
        fullname: str = benchmark["fullname"]
        samples = benchmark["stats"]["data"]
        # The first sample is discarded (presumably a cold first round - TODO
        # confirm), but only when that still leaves something to summarise.
        data_trimmed = samples[1:] or samples
        median = statistics.median(data_trimmed)
        mean = statistics.mean(data_trimmed)
        # ``statistics.stdev`` needs at least two points; report no spread otherwise.
        stddev = statistics.stdev(data_trimmed) if len(data_trimmed) > 1 else 0.0

        if "query" in fullname:
            result = BenchmarkResult(
                fullname=fullname, kind="query", median=median, mean=mean, stddev=stddev
            )
        else:
            # Anything mentioning openneuro or s3 in its name is treated as remote.
            locality: Literal["local", "remote"] = (
                "remote" if "openneuro" in fullname or "s3" in fullname else "local"
            )
            workers = benchmark.get("extra_info", {}).get("workers", "Unknown")
            result = BenchmarkResult(
                fullname=fullname,
                kind="index",
                locality=locality,
                workers=workers,
                median=median,
                mean=mean,
                stddev=stddev,
            )
        results[fullname] = result
    return results
147+
148+
149+
class Value(NamedTuple):
    """A number rescaled for display, with the scaling that was applied."""

    # The rescaled number.
    value: float
    # Multiplier that was applied to the original number to obtain ``value``.
    factor: float
    # Unit suffix that goes with ``value``.
    unit: str
153+
154+
155+
def _scale(val: float) -> Value:
    """Choose a display unit for a duration expressed in seconds.

    Args:
        val: Duration in seconds.

    Returns:
        ``val`` unchanged in ``s`` when it is at least one second, in ``ms``
        when it is at least one millisecond, and in ``µs`` otherwise, together
        with the multiplier that was applied.
    """
    if val >= 1.0:
        return Value(val, 1, "s")

    to_milli = 1e3
    if val >= 1e-3:
        return Value(val * to_milli, to_milli, "ms")

    to_micro = 1e6
    return Value(val * to_micro, to_micro, "µs")
162+
163+
164+
def _fmt(res: BenchmarkResult) -> str:
    """Render a result as ``median (mean ± stddev) unit``.

    The unit is chosen from the median; the mean and standard deviation are
    scaled by the same factor so all three numbers share that unit.
    """
    shown = _scale(res.median)
    factor = shown.factor
    scaled_mean = res.mean * factor
    scaled_stddev = res.stddev * factor
    return f"{shown.value:.3f} ({scaled_mean:.3f} ± {scaled_stddev:.3f}) {shown.unit}"
169+
170+
171+
def _ratio(pr: BenchmarkResult, ref: BenchmarkResult) -> str:
    """Format the median of ``pr`` relative to the median of ``ref``.

    Args:
        pr: Result for the branch under test.
        ref: Result it is compared against.

    Returns:
        An icon followed by the ratio to two decimals: red when ``pr`` is
        slower (ratio above 1), green when faster, white when equal. Returns
        ``"—"`` when the reference median is zero and no ratio can be computed.
    """
    if not ref.median:
        # A zero reference (the BenchmarkResult default) would divide by zero.
        return "—"
    ratio = pr.median / ref.median
    icon = "🔴" if ratio > 1 else "🟢" if ratio < 1 else "⚪"
    return f"{icon} {ratio:.2f}"
175+
176+
177+
def _label(result: BenchmarkResult) -> str:
    """Build the column heading for a benchmark.

    Query benchmarks are labelled from the last ``::`` segment of their full
    name with ``test_`` removed and underscores turned into spaces. Index
    benchmarks are labelled by locality and worker count.

    Args:
        result: Benchmark to label.

    Returns:
        Human-readable label for the benchmark.
    """
    if result.kind == "query":
        test_name = result.fullname.split("::")[-1]
        return test_name.replace("test_", "").replace("_", " ").capitalize()

    if result.locality is None:
        # ``locality`` is optional and defaults to None; calling ``.capitalize()``
        # on it would raise AttributeError, so fall back to a generic label.
        return f"Index ({result.workers} workers)"
    return f"{result.locality.capitalize()} index ({result.workers} workers)"
186+
187+
188+
def build_table(
    branch_name: str,
    branch: dict[str, BenchmarkResult],
    main: dict[str, BenchmarkResult],
    tag: dict[str, BenchmarkResult] | None = None,
) -> str:
    """Render the benchmark comparison as a markdown table.

    One column is produced per benchmark found in any of the inputs, index
    benchmarks first, then query benchmarks, then the rest, each group sorted
    by name. The table has a row for the branch, a row for main, an empty
    separator row and a row with the branch-vs-main ratio.

    Args:
        branch_name: Name shown for the feature branch row.
        branch: Results for the feature branch, keyed by benchmark full name.
        main: Results for main, keyed by benchmark full name.
        tag: Optional results for the last tag; only contributes columns.

    Returns:
        The markdown report as a single string.
    """
    tag_results = tag if tag else {}

    def _group_then_name(name: str) -> tuple[int, str]:
        # Index benchmarks sort first, query benchmarks second, others last.
        if "index" in name:
            group = 0
        elif "query" in name:
            group = 1
        else:
            group = 2
        return group, name

    columns = sorted({*branch, *main, *tag_results}, key=_group_then_name)

    def _render(first_cell: str, cells: list[str]) -> str:
        # Shared layout for the header and every data row.
        return first_cell + " | ".join(f" {cell} " for cell in cells) + " |"

    def _result_cells(results: dict[str, BenchmarkResult]) -> list[str]:
        return [_fmt(results[col]) if col in results else "—" for col in columns]

    def _ratio_cells(ref: dict[str, BenchmarkResult]) -> list[str]:
        return [
            _ratio(branch[col], ref[col]) if col in branch and col in ref else "—"
            for col in columns
        ]

    headings = [
        _label(branch.get(col) or main.get(col) or tag_results.get(col))
        for col in columns
    ]
    header = _render("| |", [f"**{heading}**" for heading in headings])
    divider = "|-|" + "|".join(["---"] * len(columns)) + "|"

    report = [
        "## Benchmark Results",
        "",
        header,
        divider,
        _render("| **" + branch_name + "** |", _result_cells(branch)),
        _render("| **" + "main" + "** |", _result_cells(main)),
        # Same shape as the divider but with empty cells: a blank spacer row.
        divider.replace("-", ""),
        _render("| *" + f"{branch_name} vs main ratio" + "* |", _ratio_cells(main)),
        "",
        "> `median (mean ± std)`",
        "> ",
        "> 🔴 Slower &nbsp; ⚪ No change &nbsp; 🟢 Faster",
    ]
    return "\n".join(report)
231+
232+
233+
def _parser() -> argparse.Namespace:
    """Parse the command-line arguments of the benchmark script.

    Returns:
        Namespace with ``branch`` (required) and ``output_dir`` (a ``Path``,
        ``benchmarks`` unless overridden).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-b", "--branch", required=True, help="PR branch to benchmark")

    output_dir_options = {
        "default": "benchmarks",
        "type": Path,
        "help": "Output directory to save benchmarks to",
    }
    cli.add_argument("-o", "--output-dir", **output_dir_options)

    namespace = cli.parse_args()
    return namespace
244+
245+
246+
def _sanitize(s: str) -> str:
    """Return ``s`` with every ``/`` replaced by ``-`` so it fits in a file name."""
    segments = s.split("/")
    return "-".join(segments)
248+
249+
250+
def run_benchmark(git: Git, branch: str, out_dir: Path) -> None:
    """Perform benchmarking.

    Each target reference is checked out in turn and the tests marked
    ``benchmark`` are run for it, writing ``benchmark-<name>.json`` into
    ``out_dir``.

    Args:
        git: Representation of current git repository for benchmarking
        branch: Feature branch to benchmark
        out_dir: Output directory to save benchmarks to
    """
    tag = git.last_tag()
    # The tag is listed with no reference, so it is skipped by the loop below.
    targets = {branch: branch, "main": "main", tag: None}

    with _suppress_log_exceptions():
        for name, ref in targets.items():
            # Skip if the reference is not provided
            if ref is None:
                continue
            git.checkout(ref)
            _reset_logger()
            _logger.info("Running benchmarks for '%s'", name)

            fname = out_dir / f"benchmark-{_sanitize(name)}.json"
            if fname.exists():
                _logger.warning(
                    "Existing benchmarks found for %s. File will be overwritten.", fname
                )

            # Run pytest in a fresh interpreter for every reference. Repeated
            # in-process ``pytest.main()`` calls reuse modules cached by the
            # import system, so later references would be measured against the
            # code imported for the first one.
            completed = subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "pytest",
                    "-m",
                    "benchmark",
                    "--benchmark-save-data",
                    f"--benchmark-json={fname}",
                    "--benchmark-time-unit=ms",
                    "--benchmark-warmup=on",
                    f"{git.repo_path}/tests",
                ],
                check=False,
            )
            if completed.returncode != 0:
                # Keep going so the remaining references are still benchmarked,
                # but make the failure visible instead of discarding it.
                _logger.warning(
                    "pytest exited with code %d for '%s'", completed.returncode, name
                )
290+
291+
292+
def generate_report(git: Git, branch: str, out_dir: Path) -> None:
    """Generate markdown report from benchmarks.

    Reads the ``benchmark-*.json`` files in ``out_dir``, matches them to the
    feature branch, ``main`` and the last tag by file name, and writes a
    timestamped ``benchmark-<branch>-<UTC time>.md`` report next to them.

    Args:
        git: Representation of current git repository for benchmarking
        branch: Feature branch benchmarked
        out_dir: Directory benchmarks are saved to / output report to

    Raises:
        AssertionError: if less than 2 benchmark files found, or if the results
            for the feature branch or for main are missing.
    """
    with _suppress_log_exceptions():
        git.checkout(branch)
        _reset_logger()
        _logger.info("Generating benchmark report")

        files = sorted(out_dir.glob("benchmark-*.json"))
        if len(files) < 2:
            raise AssertionError(
                "Expected 2 or more benchmark files to perform comparisons."
            )

        tag = git.last_tag()
        # Match on the complete sanitized name. Splitting the stem on "-" would
        # misread branches or tags that contain "-" or "/" (for example a branch
        # "main-fix" as "main") and would let stale files from other branches
        # overwrite the results of the branch being reported. Later entries win
        # on a name clash, which keeps the precedence tag > main > branch.
        refs_by_stem = {_sanitize(branch): branch, "main": "main", _sanitize(tag): tag}

        parsed: dict[str, dict[str, BenchmarkResult]] = {}
        for f in files:
            ref_name = refs_by_stem.get(f.stem.removeprefix("benchmark-"))
            if ref_name is None:
                _logger.warning("Ignoring unrelated benchmark file %s", f)
                continue
            parsed[ref_name] = parse_file(f)

        if tag not in parsed:
            _logger.warning("Tag '%s' not found in benchmark files.", tag)

        missing = [name for name in (branch, "main") if name not in parsed]
        if missing:
            raise AssertionError(
                "Missing benchmark results for: " + ", ".join(missing)
            )

        report_contents = build_table(
            branch,
            parsed[branch],
            parsed["main"],
            None,  # tag results are not included in the report
        )
        dt = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M")
        report_file = out_dir / f"benchmark-{_sanitize(branch)}-{dt}.md"
        # The report contains emoji and "±"; do not depend on the locale encoding.
        report_file.write_text(report_contents, encoding="utf-8")
        _logger.info("Report written to %s", report_file)
340+
341+
342+
def main() -> None:
    """Entry point: parse arguments, run the benchmarks, then write the report."""
    cli_args = _parser()
    destination = cli_args.output_dir
    destination.mkdir(parents=True, exist_ok=True)

    with Git() as repo:
        # Both steps take the same arguments and must run in this order.
        for step in (run_benchmark, generate_report):
            step(git=repo, branch=cli_args.branch, out_dir=destination)
349+
350+
351+
# Run the benchmark workflow only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)