Skip to content

Commit 8882a00

Browse files
authored
Merge pull request #7 from kaitj/ci/benchmark
Ci/benchmark
2 parents 2543599 + 33c5f3a commit 8882a00

8 files changed

Lines changed: 588 additions & 78 deletions

File tree

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
#!/usr/bin/env python
2+
"""Compare benchmark results across PR, main, and tag and output a markdown table."""
3+
4+
import json
5+
import statistics
6+
from pathlib import Path
7+
from typing import Literal, NamedTuple
8+
9+
10+
class BenchmarkResult(NamedTuple):
11+
fullname: str
12+
kind: Literal["index", "query"]
13+
locality: Literal["local", "remote"] | None = None
14+
workers: int | None = None
15+
median: float = 0.0
16+
mean: float = 0.0
17+
stddev: float = 0.0
18+
19+
20+
def parse_file(path: Path) -> dict[str, BenchmarkResult]:
    """Parse a benchmark JSON report into per-benchmark summary statistics.

    For every entry in the report's ``"benchmarks"`` list, the raw samples in
    ``stats.data`` are summarised (median, mean, sample standard deviation)
    after dropping the first sample. Benchmarks whose ``fullname`` contains
    ``"query"`` become ``"query"`` results; all others become ``"index"``
    results, classified as ``"remote"`` when the name mentions ``"openneuro"``
    or ``"s3"`` and ``"local"`` otherwise.

    Args:
        path: Location of the JSON report.

    Returns:
        Mapping from each benchmark's ``fullname`` to its ``BenchmarkResult``.

    Raises:
        statistics.StatisticsError: If a benchmark has no samples at all.
        KeyError: If the report lacks one of the expected fields.
    """
    report = json.loads(path.read_text(encoding="utf-8"))
    results: dict[str, BenchmarkResult] = {}
    for benchmark in report["benchmarks"]:
        fullname: str = benchmark["fullname"]
        samples = benchmark["stats"]["data"]
        # Drop the first sample, but keep it when it is the only one so the
        # statistics below remain defined for single-round benchmarks.
        trimmed = samples[1:] or samples
        median = statistics.median(trimmed)
        mean = statistics.mean(trimmed)
        # statistics.stdev() needs at least two data points; report no spread
        # rather than crashing on very short runs.
        stddev = statistics.stdev(trimmed) if len(trimmed) > 1 else 0.0

        if "query" in fullname:
            result = BenchmarkResult(
                fullname=fullname, kind="query", median=median, mean=mean, stddev=stddev
            )
        else:
            locality: Literal["local", "remote"] = (
                "remote" if "openneuro" in fullname or "s3" in fullname else "local"
            )
            # "Unknown" is only used for display in the column label.
            workers = benchmark.get("extra_info", {}).get("workers", "Unknown")
            result = BenchmarkResult(
                fullname=fullname,
                kind="index",
                locality=locality,
                workers=workers,
                median=median,
                mean=mean,
                stddev=stddev,
            )
        results[fullname] = result
    return results
50+
51+
52+
def _scale(val: float) -> float:
53+
return val * 1000
54+
55+
56+
def _fmt(res: BenchmarkResult) -> str:
    """Render *res* as ``"<median> (<mean> ± <stddev>) ms"`` with three decimals."""
    # Scale all three statistics in one pass before formatting.
    median, mean, stddev = (_scale(v) for v in (res.median, res.mean, res.stddev))
    return f"{median:.3f} ({mean:.3f} ± {stddev:.3f}) ms"
61+
62+
63+
def _delta(pr: BenchmarkResult, ref: BenchmarkResult) -> str:
64+
if ref == 0:
65+
return "N/A"
66+
diff = _scale(pr.median - ref.median)
67+
pct = (pr.median / ref.median - 1) * 100
68+
icon = "🔴" if pct > 5 else "🟢" if pct < -5 else "⚪"
69+
return f"{icon} {diff:+.3f} ms ({pct:+.1f}%)"
70+
71+
72+
def _label(result: BenchmarkResult) -> str:
73+
if result.kind == "query":
74+
return (
75+
result.fullname.split("::")[-1]
76+
.replace("test_", "")
77+
.replace("_", " ")
78+
.capitalize()
79+
)
80+
return f"{result.locality.capitalize()} index ({result.workers} workers)"
81+
82+
83+
def build_table(
    pr: dict[str, BenchmarkResult],
    main: dict[str, BenchmarkResult],
    tag: dict[str, BenchmarkResult],
    tag_name: str,
) -> str:
    """Render the three result sets as a Markdown comparison table.

    The table has one column per benchmark found in any of the inputs. It
    contains a row each for the PR, main and the tag, an empty separator row,
    and two delta rows comparing the PR against main and against the tag.
    Benchmarks missing from a result set are shown as ``"—"``.

    Args:
        pr: Results measured on the pull request, keyed by benchmark fullname.
        main: Results measured on main, keyed the same way.
        tag: Results measured on the tag, keyed the same way.
        tag_name: Name shown for the tag row and in the tag delta row.

    Returns:
        The Markdown text, including a heading and a legend.
    """
    # Sort the union of keys: iterating a bare set of strings follows hash
    # order, which changes between interpreter runs and would shuffle the
    # columns from one report to the next.
    all_keys = sorted(set(pr) | set(main) | set(tag))
    labels = [_label(pr.get(k) or main.get(k) or tag.get(k)) for k in all_keys]

    col_sep = " | "
    header = "| |" + col_sep.join(f" **{label}** " for label in labels) + " |"
    divider = "|-|" + "|".join("---" for _ in all_keys) + "|"

    def row(name: str, results: dict[str, BenchmarkResult]) -> str:
        cells = [_fmt(results[k]) if k in results else "—" for k in all_keys]
        return "| **" + name + "** |" + col_sep.join(f" {c} " for c in cells) + " |"

    def delta_row(label: str, ref: dict[str, BenchmarkResult]) -> str:
        # A delta needs the benchmark on both sides; otherwise show a dash.
        cells = [
            _delta(pr[k], ref[k]) if k in pr and k in ref else "—" for k in all_keys
        ]
        return "| *" + label + "* |" + col_sep.join(f" {c} " for c in cells) + " |"

    lines = [
        "## Benchmark Results",
        "",
        header,
        divider,
        row("PR", pr),
        row("main", main),
        row(tag_name, tag),
        # Divider with its dashes removed: an empty row between the
        # measurements and the deltas.
        divider.replace("-", ""),
        delta_row("PR vs main", main),
        delta_row(f"PR vs {tag_name}", tag),
        "",
        "> `median (mean ± std)`",
        "> ",
        "🔴 >5% slower &nbsp; ⚪ within 5% &nbsp; 🟢 >5% faster",
    ]
    return "\n".join(lines)
123+
124+
125+
def main():
    """Command-line entry point: compare three benchmark reports.

    Globs the working directory for benchmark JSON files, expects exactly
    three of them (PR, main and one tag), prints the Markdown comparison
    table, and also writes it to ``--output`` when that option is given.

    Raises:
        ValueError: If the number of matched files is not three, if no tag
            report is found, or if the PR or main report is missing.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pattern",
        default="benchmark-*.json",
        help="Glob pattern for benchmark JSON files",
    )
    parser.add_argument(
        "-o",
        "--output",
        # Parse into a Path so write_text() is available below.
        type=Path,
        help="Output markdown filepath containing benchmark comparisons",
    )
    args = parser.parse_args()

    files = sorted(Path(".").glob(args.pattern))
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if len(files) != 3:
        raise ValueError(f"Expected 3 files, found {len(files)}: {files}")

    # Infer pr/main/tag from the directory name, e.g. "benchmark-pr". Files
    # matched directly in the working directory have an empty parent name, so
    # fall back to the file stem in that case.
    parsed: dict[str, dict[str, BenchmarkResult]] = {}
    tag = None
    for f in files:
        stem = f.parent.name or f.stem
        if stem in ("pr", "main"):
            key = stem
        elif stem.endswith(("-pr", "-main")):
            key = stem.rpartition("-")[2]
        else:
            # Everything after the first hyphen, so tags that themselves
            # contain hyphens (e.g. "v1.0-rc1") are kept whole.
            key = stem.partition("-")[2] or stem
            tag = key
        parsed[key] = parse_file(f)
    if tag is None:
        raise ValueError("Unknown tag")
    missing = [name for name in ("pr", "main") if name not in parsed]
    if missing:
        raise ValueError(f"Missing benchmark results for: {', '.join(missing)}")

    table = build_table(parsed["pr"], parsed["main"], parsed[tag], tag_name=tag)
    if args.output is not None:
        # The table contains emoji; do not depend on the locale encoding.
        args.output.write_text(table, encoding="utf-8")
    print(table)


if __name__ == "__main__":
    main()

.github/scripts/run_benchmarks.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/usr/bin/env python
2+
"""Perform benchmarks across PR commit, main, and previous tag."""
3+
4+
import argparse
5+
6+
import pytest
7+
8+
9+
def main():
    """Run the benchmark-marked, non-cloud tests and save the results as JSON.

    Returns:
        The exit status reported by ``pytest.main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output", required=True, help="Output JSON file path")
    args = parser.parse_args()

    return pytest.main(
        [
            "-m",
            "benchmark and not cloud",
            "--benchmark-save-data",
            f"--benchmark-json={args.output}",
            "--benchmark-time-unit=ms",
            "--benchmark-warmup=on",
        ]
    )


if __name__ == "__main__":
    # Propagate pytest's exit status so a failed benchmark run fails the
    # calling CI step instead of silently exiting with 0.
    raise SystemExit(main())

.github/workflows/benchmark.yaml

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
name: Benchmark

on:
  pull_request:
    branches: [ "main" ]

jobs:
  # Resolve the most recent tag so it can be benchmarked next to the PR and main.
  get-tag:
    runs-on: ubuntu-latest
    outputs:
      tag: ${{ steps.last_tag.outputs.tag }}
    steps:
      - uses: actions/checkout@v6
        with:
          fetch-tags: true
          fetch-depth: 0
      - id: last_tag
        # Outputs are "name=value" lines appended to $GITHUB_OUTPUT.
        run: echo "tag=$(git describe --tags --abbrev=0)" >> "$GITHUB_OUTPUT"

  benchmark:
    needs: get-tag
    runs-on: ubuntu-latest
    strategy:
      matrix:
        target:
          - name: pr
            ref: ${{ github.sha }}
          - name: main
            ref: main
          # The key under `needs` must match the job id exactly ("get-tag").
          - name: ${{ needs.get-tag.outputs.tag }}
            ref: ${{ needs.get-tag.outputs.tag }}
    steps:
      - uses: actions/checkout@v6
        with:
          ref: ${{ matrix.target.ref }}
          submodules: true
      - uses: astral-sh/setup-uv@v8.1.0
      - run: uv sync --extra "cloud"
      - name: Run benchmarks
        run: |
          uv run .github/scripts/run_benchmarks.py \
            --output benchmark-${{ matrix.target.name }}.json
      - uses: actions/upload-artifact@v7
        with:
          name: benchmark-${{ matrix.target.name }}
          path: benchmark-${{ matrix.target.name }}.json

  report:
    needs: [ get-tag, benchmark ]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v6
      - uses: astral-sh/setup-uv@v8.1.0
      - uses: actions/download-artifact@v8
        with:
          pattern: benchmark-*
      - name: Generate report
        # The pattern is quoted so the script, not the shell, expands it.
        # NOTE(review): assumes each artifact is extracted into a directory
        # named after it (benchmark-<name>/benchmark-<name>.json), which is
        # what the script's directory-name inference relies on; confirm
        # against the download-artifact version in use.
        run: |
          uv run .github/scripts/compare_benchmarks.py \
            --output benchmarks.md \
            --pattern "benchmark-*/benchmark-*.json"

.github/workflows/ci.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,19 @@ jobs:
4242
with:
4343
python-version: ${{ matrix.python-version }}
4444

45-
- name: Run tests without cloudpathlib
45+
- name: Run non-cloud tests
4646
run: |
4747
uv run pytest \
48-
-m "not cloud" \
48+
-m "not cloud and not benchmark" \
4949
--junitxml=pytest-cloudless.xml \
5050
--cov-report=xml:coverage.xml \
51-
--cov bids2table \
51+
--cov=bids2table \
5252
tests
5353
54-
- name: Run tests with cloudpathlib
54+
- name: Run cloud tests
5555
run: |
5656
uv run --extra cloud pytest \
57-
-m "cloud" \
57+
-m "cloud and not benchmark" \
5858
--junitxml=pytest-cloud.xml \
5959
--cov-report=xml:coverage.xml \
6060
--cov=bids2table \

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ s3 = [
3535

3636
[dependency-groups]
3737
dev = [
38-
"pandas==3.0.2",
3938
"pdoc>=16.0.0",
4039
"pre-commit>=4.6.0",
4140
"pytest>=9.0.3",
41+
"pytest-benchmark>=5.2.3",
4242
"pytest-cov>=7.1.0",
4343
"ruff>=0.15.12",
44+
"polars>=1.40.1",
4445
]
4546

4647
[project.urls]
@@ -68,4 +69,7 @@ lint.extend-select = ["I"]
6869
[tool.pytest.ini_options]
6970
log_cli = true
7071
log_cli_level = "INFO"
71-
markers = ["cloud: Tests requiring cloud group dependencies"]
72+
markers = [
73+
"benchmark: Tests used for benchmarking",
74+
"cloud: Tests requiring cloud group dependencies",
75+
]

0 commit comments

Comments
 (0)