|
| 1 | +# /// script |
| 2 | +# requires-python = ">=3.13" |
| 3 | +# dependencies = [] |
| 4 | +# /// |
| 5 | +"""Perform benchmarking of bids2table against last tag, main and feature branches. |
| 6 | +
|
| 7 | +Run with: |
| 8 | + uv run --with <repo> scripts/benchmark.py -b <feature_branch> [-o <output_dir>] |
| 9 | +""" |
| 10 | + |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import argparse |
| 14 | +import json |
| 15 | +import logging |
| 16 | +import statistics |
| 17 | +import subprocess |
| 18 | +import sys |
| 19 | +from contextlib import contextmanager |
| 20 | +from datetime import datetime, timezone |
| 21 | +from pathlib import Path |
| 22 | +from typing import Literal, NamedTuple |
| 23 | + |
| 24 | +import pytest |
| 25 | + |
| 26 | +logging.basicConfig(level=logging.INFO) |
| 27 | +_logger = logging.getLogger("bids2table.benchmark") |
| 28 | + |
| 29 | + |
| 30 | +@contextmanager |
| 31 | +def _suppress_log_exceptions(): |
| 32 | + logging.raiseExceptions = False |
| 33 | + try: |
| 34 | + yield |
| 35 | + finally: |
| 36 | + logging.raiseExceptions = True |
| 37 | + |
| 38 | + |
| 39 | +def _reset_logger(): |
| 40 | + for h in _logger.handlers[:]: |
| 41 | + _logger.removeHandler(h) |
| 42 | + h.close() |
| 43 | + logging.basicConfig(stream=sys.stderr, level=logging.INFO) |
| 44 | + |
| 45 | + |
| 46 | +class Git: |
| 47 | + """Class to simplify git calls via subprocess.""" |
| 48 | + |
| 49 | + def __init__(self): |
| 50 | + """Initialize the repository object, pulling in latest changes.""" |
| 51 | + self.repo_path = self._root() |
| 52 | + self._head_ref = self._run("rev-parse", "--abbrev-ref", "HEAD") |
| 53 | + |
| 54 | + def __enter__(self): |
| 55 | + if bool(self._run("status", "--porcelain")): |
| 56 | + _logger.error("Please stash or commit changes before benchmarking.") |
| 57 | + sys.exit(1) |
| 58 | + self.pull() |
| 59 | + self.submodule_update() |
| 60 | + return self |
| 61 | + |
| 62 | + def __exit__(self, *_): |
| 63 | + """On context closure, checkout the HEAD ref.""" |
| 64 | + self.checkout(self._head_ref) |
| 65 | + |
| 66 | + @staticmethod |
| 67 | + def _root() -> Path: |
| 68 | + result = subprocess.run( |
| 69 | + ["git", "rev-parse", "--show-toplevel"], capture_output=True, text=True |
| 70 | + ) |
| 71 | + return Path(result.stdout.strip()) |
| 72 | + |
| 73 | + def _run(self, *args: str) -> str: |
| 74 | + result = subprocess.run( |
| 75 | + ["git", "-C", str(self.repo_path), *args], capture_output=True, text=True |
| 76 | + ) |
| 77 | + if result.returncode != 0: |
| 78 | + _logger.error(result.stderr.strip()) |
| 79 | + sys.exit(result.returncode) |
| 80 | + return result.stdout.strip() |
| 81 | + |
| 82 | + def checkout(self, ref: str) -> None: |
| 83 | + """Checkout reference. |
| 84 | +
|
| 85 | + Args: |
| 86 | + ref: Reference to checkout (e.g. branch, SHA, tag) |
| 87 | + """ |
| 88 | + self._run("checkout", ref) |
| 89 | + |
| 90 | + def pull(self) -> None: |
| 91 | + """Pull from the remote repository.""" |
| 92 | + self._run("pull") |
| 93 | + |
| 94 | + def submodule_update(self) -> None: |
| 95 | + """Update submodules of the repo, initializing if necessary.""" |
| 96 | + self._run("submodule", "update", "--init", "--recursive") |
| 97 | + |
| 98 | + def last_tag(self) -> str: |
| 99 | + """Get last tag. |
| 100 | +
|
| 101 | + Returns: |
| 102 | + A string value of the last tag |
| 103 | + """ |
| 104 | + return self._run("describe", "--tags", "--abbrev=0") |
| 105 | + |
| 106 | + |
| 107 | +class BenchmarkResult(NamedTuple): |
| 108 | + fullname: str |
| 109 | + kind: Literal["index", "query"] |
| 110 | + locality: Literal["local", "remote"] | None = None |
| 111 | + workers: int = 1 |
| 112 | + median: float = 0.0 |
| 113 | + mean: float = 0.0 |
| 114 | + stddev: float = 0.0 |
| 115 | + |
| 116 | + |
| 117 | +def parse_file(path: Path) -> dict[str, BenchmarkResult]: |
| 118 | + data = json.loads(path.read_text()) |
| 119 | + results = {} |
| 120 | + for benchmark in data["benchmarks"]: |
| 121 | + fullname: str = benchmark["fullname"] |
| 122 | + data_trimmed = benchmark["stats"]["data"][1:] |
| 123 | + median = statistics.median(data_trimmed) |
| 124 | + mean = statistics.mean(data_trimmed) |
| 125 | + stddev = statistics.stdev(data_trimmed) |
| 126 | + |
| 127 | + if "query" in fullname: |
| 128 | + result = BenchmarkResult( |
| 129 | + fullname=fullname, kind="query", median=median, mean=mean, stddev=stddev |
| 130 | + ) |
| 131 | + else: |
| 132 | + locality: Literal["local", "remote"] = ( |
| 133 | + "remote" if "openneuro" in fullname or "s3" in fullname else "local" |
| 134 | + ) |
| 135 | + workers = benchmark["extra_info"].get("workers", "Unknown") |
| 136 | + result = BenchmarkResult( |
| 137 | + fullname=fullname, |
| 138 | + kind="index", |
| 139 | + locality=locality, |
| 140 | + workers=workers, |
| 141 | + median=median, |
| 142 | + mean=mean, |
| 143 | + stddev=stddev, |
| 144 | + ) |
| 145 | + results[fullname] = result |
| 146 | + return results |
| 147 | + |
| 148 | + |
| 149 | +class Value(NamedTuple): |
| 150 | + value: float |
| 151 | + factor: float |
| 152 | + unit: str |
| 153 | + |
| 154 | + |
| 155 | +def _scale(val: float) -> Value: |
| 156 | + if val >= 1.0: |
| 157 | + return Value(value=val, factor=1, unit="s") |
| 158 | + elif val >= 1e-3: |
| 159 | + return Value(value=val * 1e3, factor=1e3, unit="ms") |
| 160 | + else: |
| 161 | + return Value(value=val * 1e6, factor=1e6, unit="µs") |
| 162 | + |
| 163 | + |
| 164 | +def _fmt(res: BenchmarkResult) -> str: |
| 165 | + median = _scale(res.median) |
| 166 | + mean = res.mean * median.factor |
| 167 | + stddev = res.stddev * median.factor |
| 168 | + return f"{median.value:.3f} ({mean:.3f} ± {stddev:.3f}) {median.unit}" |
| 169 | + |
| 170 | + |
| 171 | +def _ratio(pr: BenchmarkResult, ref: BenchmarkResult) -> str: |
| 172 | + ratio = pr.median / ref.median |
| 173 | + icon = "🔴" if ratio > 1 else "🟢" if ratio < 1 else "⚪" |
| 174 | + return f"{icon} {ratio:.2f}" |
| 175 | + |
| 176 | + |
| 177 | +def _label(result: BenchmarkResult) -> str: |
| 178 | + if result.kind == "query": |
| 179 | + return ( |
| 180 | + result.fullname.split("::")[-1] |
| 181 | + .replace("test_", "") |
| 182 | + .replace("_", " ") |
| 183 | + .capitalize() |
| 184 | + ) |
| 185 | + return f"{result.locality.capitalize()} index ({result.workers} workers)" |
| 186 | + |
| 187 | + |
| 188 | +def build_table( |
| 189 | + branch_name: str, |
| 190 | + branch: dict[str, BenchmarkResult], |
| 191 | + main: dict[str, BenchmarkResult], |
| 192 | + tag: dict[str, BenchmarkResult] | None = None, |
| 193 | +) -> str: |
| 194 | + tag = tag or {} |
| 195 | + all_keys = sorted( |
| 196 | + set(branch) | set(main) | set(tag), |
| 197 | + key=lambda x: (0 if "index" in x else 1 if "query" in x else 2, x), |
| 198 | + ) |
| 199 | + labels = [_label(branch.get(k) or main.get(k) or tag.get(k)) for k in all_keys] |
| 200 | + |
| 201 | + col_sep = " | " |
| 202 | + header = "| |" + col_sep.join(f" **{label}** " for label in labels) + " |" |
| 203 | + divider = "|-|" + "|".join("---" for _ in all_keys) + "|" |
| 204 | + |
| 205 | + def row(name: str, results: dict[str, BenchmarkResult]) -> str: |
| 206 | + cells = [_fmt(results[k]) if k in results else "—" for k in all_keys] |
| 207 | + return "| **" + name + "** |" + col_sep.join(f" {c} " for c in cells) + " |" |
| 208 | + |
| 209 | + def ratio_row(label: str, ref: dict[str, BenchmarkResult]) -> str: |
| 210 | + cells = [ |
| 211 | + _ratio(branch[k], ref[k]) if k in branch and k in ref else "—" |
| 212 | + for k in all_keys |
| 213 | + ] |
| 214 | + return "| *" + label + "* |" + col_sep.join(f" {c} " for c in cells) + " |" |
| 215 | + |
| 216 | + lines = [ |
| 217 | + "## Benchmark Results", |
| 218 | + "", |
| 219 | + header, |
| 220 | + divider, |
| 221 | + row(branch_name, branch), |
| 222 | + row("main", main), |
| 223 | + divider.replace("-", ""), |
| 224 | + ratio_row(f"{branch_name} vs main ratio", main), |
| 225 | + "", |
| 226 | + "> `median (mean ± std)`", |
| 227 | + "> ", |
| 228 | + "> 🔴 Slower ⚪ No change 🟢 Faster", |
| 229 | + ] |
| 230 | + return "\n".join(lines) |
| 231 | + |
| 232 | + |
| 233 | +def _parser() -> argparse.Namespace: |
| 234 | + parser = argparse.ArgumentParser() |
| 235 | + parser.add_argument("-b", "--branch", required=True, help="PR branch to benchmark") |
| 236 | + parser.add_argument( |
| 237 | + "-o", |
| 238 | + "--output-dir", |
| 239 | + default="benchmarks", |
| 240 | + type=Path, |
| 241 | + help="Output directory to save benchmarks to", |
| 242 | + ) |
| 243 | + return parser.parse_args() |
| 244 | + |
| 245 | + |
| 246 | +def _sanitize(s: str) -> str: |
| 247 | + return s.replace("/", "-") |
| 248 | + |
| 249 | + |
| 250 | +def run_benchmark(git: Git, branch: str, out_dir: Path) -> None: |
| 251 | + """Perform benchmarking. |
| 252 | +
|
| 253 | + Args: |
| 254 | + git: Representation of current git repository for benchmarking |
| 255 | + branch: Feature branch to benchmark |
| 256 | + out_dir: Output directory to save benchmarks to |
| 257 | + """ |
| 258 | + |
| 259 | + tag = git.last_tag() |
| 260 | + targets = {branch: branch, "main": "main", tag: None} |
| 261 | + |
| 262 | + with _suppress_log_exceptions(): |
| 263 | + for name, ref in targets.items(): |
| 264 | + # Skip if the reference is not provided |
| 265 | + if ref is None: |
| 266 | + continue |
| 267 | + git.checkout(ref) |
| 268 | + _reset_logger() |
| 269 | + _logger.info("Running benchmarks for '%s'", name) |
| 270 | + |
| 271 | + safe_name = _sanitize(name) |
| 272 | + fname = out_dir / f"benchmark-{safe_name}.json" |
| 273 | + if fname.exists(): |
| 274 | + _logger.warning( |
| 275 | + "Existing benchmarks found for %s. File will be overwritten.", fname |
| 276 | + ) |
| 277 | + |
| 278 | + # Run benchmark |
| 279 | + pytest.main( |
| 280 | + [ |
| 281 | + "-m", |
| 282 | + "benchmark", |
| 283 | + "--benchmark-save-data", |
| 284 | + f"--benchmark-json={fname}", |
| 285 | + "--benchmark-time-unit=ms", |
| 286 | + "--benchmark-warmup=on", |
| 287 | + f"{git.repo_path}/tests", |
| 288 | + ] |
| 289 | + ) |
| 290 | + |
| 291 | + |
| 292 | +def generate_report(git: Git, branch: str, out_dir: Path) -> None: |
| 293 | + """Generate markdown report from benchmarks. |
| 294 | +
|
| 295 | + Args: |
| 296 | + git: Representation of current git repository for benchmarking |
| 297 | + branch: Feature branch benchmarked |
| 298 | + out_dir: Directory benchmarks are saved to / output report to |
| 299 | +
|
| 300 | + Raises: |
| 301 | + AssertionError: if less than 2 benchmark files found. |
| 302 | + """ |
| 303 | + with _suppress_log_exceptions(): |
| 304 | + git.checkout(branch) |
| 305 | + _reset_logger() |
| 306 | + _logger.info("Generating benchmark report") |
| 307 | + |
| 308 | + files = sorted(out_dir.glob("benchmark-*.json")) |
| 309 | + if len(files) < 2: |
| 310 | + raise AssertionError( |
| 311 | + "Expected 2 or more benchmark files to perform comparisons." |
| 312 | + ) |
| 313 | + |
| 314 | + tag = git.last_tag() |
| 315 | + parsed: dict[str, dict[str, BenchmarkResult]] = {} |
| 316 | + for f in files: |
| 317 | + if not f.exists(): |
| 318 | + _logger.warning("File %s does not exist - skipping", f) |
| 319 | + continue |
| 320 | + key = f.stem.split("-")[1] |
| 321 | + if key == tag: |
| 322 | + pass # keep as tag name |
| 323 | + elif key != "main": |
| 324 | + key = branch |
| 325 | + parsed[key] = parse_file(f) |
| 326 | + |
| 327 | + if tag not in parsed: |
| 328 | + _logger.warning("Tag '%s' not found in benchmark files.", tag) |
| 329 | + |
| 330 | + report_contents = build_table( |
| 331 | + branch, |
| 332 | + parsed[branch], |
| 333 | + parsed["main"], |
| 334 | + None, # parsed.get(tag)1 |
| 335 | + ) |
| 336 | + dt = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M") |
| 337 | + report_file = out_dir / f"benchmark-{_sanitize(branch)}-{dt}.md" |
| 338 | + report_file.write_text(report_contents) |
| 339 | + _logger.info("Report written to %s", report_file) |
| 340 | + |
| 341 | + |
| 342 | +def main() -> None: |
| 343 | + args = _parser() |
| 344 | + args.output_dir.mkdir(parents=True, exist_ok=True) |
| 345 | + |
| 346 | + with Git() as git: |
| 347 | + run_benchmark(git=git, branch=args.branch, out_dir=args.output_dir) |
| 348 | + generate_report(git=git, branch=args.branch, out_dir=args.output_dir) |
| 349 | + |
| 350 | + |
| 351 | +if __name__ == "__main__": |
| 352 | + main() |
0 commit comments