-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathbenchmark_concurrent_fetch.py
More file actions
248 lines (200 loc) · 8.27 KB
/
benchmark_concurrent_fetch.py
File metadata and controls
248 lines (200 loc) · 8.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""Benchmark concurrent chunk fetching vs serial reads from DANDI Archive.
Compares the old path (h5py + LindiRemfile, serial) against the new path
(zarr + LindiH5ZarrStore, concurrent) for reading external array links.
Produces a bar chart showing timings and speedup for each test case.
Usage:
python devel/benchmark_concurrent_fetch.py
python devel/benchmark_concurrent_fetch.py --dandiset 000473
python devel/benchmark_concurrent_fetch.py --dandiset 000409
python devel/benchmark_concurrent_fetch.py --output benchmark_results.png
"""
import argparse
import time
import tempfile
import numpy as np
import h5py
import zarr
import matplotlib.pyplot as plt
import lindi
from lindi.LindiRemfile.LindiRemfile import LindiRemfile
from lindi.LindiH5ZarrStore.LindiH5ZarrStore import LindiH5ZarrStore
from lindi.LindiH5ZarrStore.LindiH5ZarrStoreOpts import LindiH5ZarrStoreOpts
# Benchmark fixtures: remote NWB assets hosted on the DANDI Archive.
# Keys are dandiset IDs; each value provides:
#   "url"     - direct asset download URL on the DANDI API
#   "dataset" - HDF5 path of the array to benchmark
#   "label"   - human-readable name used in printouts and plot labels
#   "slices"  - (numpy slice, display label) pairs of increasing read size
DANDISETS = {
    "000473": {
        "url": "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/",
        "dataset": "processing/ecephys/LFP/LFP/data",
        "label": "000473 LFP",
        "slices": [
            (np.s_[:1000], "[:1000]"),
            (np.s_[:5000], "[:5000]"),
            (np.s_[:10000], "[:10000]"),
        ],
    },
    "000409": {
        "url": "https://api.dandiarchive.org/api/assets/c04f6b30-82bf-40e1-9210-34f0bcd8be24/download/",
        "dataset": "acquisition/ElectricalSeriesAp/data",
        "label": "000409 Neuropixels",
        "slices": [
            (np.s_[:500], "[:500]"),
            (np.s_[:1000], "[:1000]"),
            (np.s_[:2000], "[:2000]"),
        ],
    },
}
def read_serial(url, dataset_name, selection):
    """Old path: serial reads through h5py + LindiRemfile."""
    remote = LindiRemfile(url, verbose=False, local_cache=None)
    # h5py reads chunk-by-chunk over the remote file object, one at a time.
    with h5py.File(remote, "r") as h5:
        ds = h5[dataset_name]
        return ds[selection]
def read_concurrent(url, dataset_name, selection):
    """New path: concurrent reads through zarr + LindiH5ZarrStore."""
    # Disable the chunk-count threshold so the store exposes every chunk.
    store_opts = LindiH5ZarrStoreOpts(num_dataset_chunks_threshold=None)
    with LindiH5ZarrStore.from_file(url, opts=store_opts) as store:
        array = zarr.open_array(store=store, path=dataset_name, mode="r")
        return array[selection]
def read_lindi_file(lindi_json_path, dataset_name, selection):
    """Full pipeline: LindiH5pyFile (concurrent path for remote external links).

    Parameters
    ----------
    lindi_json_path : str
        Path to a .lindi.json file on disk.
    dataset_name : str
        HDF5 path of the dataset to read.
    selection :
        numpy-style slice applied to the dataset.

    Returns
    -------
    The selected data as a numpy array.
    """
    # Use a context manager so the file handle is released even if the read
    # raises; the original opened the LindiH5pyFile and never closed it.
    # (LindiH5pyFile supports the context-manager protocol — it is used
    # with `with` elsewhere in this script.)
    with lindi.LindiH5pyFile.from_lindi_file(lindi_json_path, mode="r") as f:
        return f[dataset_name][selection]
def benchmark_one(func, *args, **kwargs):
    """Time a single call. Returns (result, elapsed_seconds)."""
    start = time.perf_counter()
    value = func(*args, **kwargs)
    return value, time.perf_counter() - start
def run_benchmark(dandiset_key):
    """Run benchmarks for one dandiset. Returns list of result dicts."""
    info = DANDISETS[dandiset_key]
    url, dataset, label = info["url"], info["dataset"], info["label"]
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Benchmarking: {label}")
    print(f" URL: {url}")
    print(f" Dataset: {dataset}")
    print(banner)
    collected = []
    for selection, sel_label in info["slices"]:
        print(f"\n Slice: {dataset}{sel_label}")
        print(f" {'-' * 40}")
        # Old path: serial chunk reads.
        data_serial, t_serial = benchmark_one(read_serial, url, dataset, selection)
        print(f" Serial (h5py+LindiRemfile): {t_serial:.2f}s")
        # New path: concurrent chunk reads.
        data_concurrent, t_concurrent = benchmark_one(read_concurrent, url, dataset, selection)
        print(f" Concurrent (zarr+LindiH5ZarrStore): {t_concurrent:.2f}s")
        # Fail loudly if the two read paths ever disagree.
        np.testing.assert_array_equal(data_serial, data_concurrent)
        print(f" Data equivalent: shape={data_serial.shape}, dtype={data_serial.dtype}")
        speedup = float("inf") if t_concurrent <= 0 else t_serial / t_concurrent
        print(f" Speedup: {speedup:.1f}x")
        collected.append({
            "dandiset": dandiset_key,
            "label": f"{label}\n{sel_label}",
            "short_label": sel_label,
            "serial": t_serial,
            "concurrent": t_concurrent,
            "speedup": speedup,
            "shape": data_serial.shape,
        })
    return collected
def run_lindi_file_benchmark(dandiset_key):
    """Run the full LindiH5pyFile pipeline benchmark."""
    info = DANDISETS[dandiset_key]
    url, dataset = info["url"], info["dataset"]
    selection, sel_label = info["slices"][1]  # use middle slice
    divider = "=" * 60
    print(f"\n{divider}")
    print(f"LindiH5pyFile Pipeline: {info['label']}")
    print(divider)
    with tempfile.TemporaryDirectory() as tmpdir:
        # Build the .lindi.json index from the remote HDF5 file first;
        # this dominates wall time, hence the separate timing printout.
        print(" Creating .lindi.json...")
        start = time.perf_counter()
        fname = f"{tmpdir}/test.nwb.lindi.json"
        with lindi.LindiH5pyFile.from_hdf5_file(url) as f:
            f.write_lindi_file(fname)
        t_create = time.perf_counter() - start
        print(f" Created in {t_create:.2f}s")
        # Serial baseline for comparison.
        data_serial, t_serial = benchmark_one(read_serial, url, dataset, selection)
        print(f" Serial (h5py+LindiRemfile): {t_serial:.2f}s")
        # Full pipeline read (uses the concurrent path internally).
        data_lindi, t_lindi = benchmark_one(read_lindi_file, fname, dataset, selection)
        print(f" LindiH5pyFile (concurrent): {t_lindi:.2f}s")
        np.testing.assert_array_equal(data_serial, data_lindi)
        speedup = float("inf") if t_lindi <= 0 else t_serial / t_lindi
        print(f" Data equivalent: shape={data_serial.shape}")
        print(f" Speedup: {speedup:.1f}x")
        return {
            "label": f"LindiH5pyFile\n{info['label']} {sel_label}",
            "serial": t_serial,
            "concurrent": t_lindi,
            "speedup": speedup,
        }
def plot_results(all_results, output_path=None):
    """Create a bar chart comparing serial vs concurrent timings.

    Parameters
    ----------
    all_results : list of dict
        Result dicts with "label", "serial", "concurrent", and "speedup"
        keys, as produced by run_benchmark / run_lindi_file_benchmark.
    output_path : str or None
        If given, save the figure to this path; otherwise show it.
    """
    if not all_results:
        # Nothing to plot; avoid max() on an empty sequence below.
        print("No benchmark results to plot.")
        return
    labels = [r["label"] for r in all_results]
    serial_times = [r["serial"] for r in all_results]
    concurrent_times = [r["concurrent"] for r in all_results]
    speedups = [r["speedup"] for r in all_results]
    x = np.arange(len(labels))
    width = 0.35
    fig, ax = plt.subplots(figsize=(max(10, len(labels) * 2), 6))
    ax.bar(x - width / 2, serial_times, width, label="Serial (h5py + LindiRemfile)", color="#d35f5f")
    ax.bar(x + width / 2, concurrent_times, width, label="Concurrent (zarr + LindiH5ZarrStore)", color="#5f9ed3")
    # Annotate each pair of bars with the speedup factor, just above the
    # taller of the two bars.
    for i, (s_time, c_time, speedup) in enumerate(zip(serial_times, concurrent_times, speedups)):
        y = max(s_time, c_time)
        ax.annotate(
            f"{speedup:.1f}x",
            xy=(i, y),
            xytext=(0, 8),
            textcoords="offset points",
            ha="center",
            fontweight="bold",
            fontsize=11,
        )
    ax.set_ylabel("Time (seconds)")
    ax.set_title("LINDI Chunk Fetching: Serial vs Concurrent")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=9)
    ax.legend()
    # Scale the y-axis to the tallest bar of EITHER series. The original
    # used only serial_times, which clipped bars and annotations whenever
    # the concurrent path was slower (speedup < 1).
    tallest = max(max(serial_times), max(concurrent_times))
    ax.set_ylim(0, tallest * 1.3)
    fig.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=150)
        print(f"\nPlot saved to: {output_path}")
    else:
        plt.show()
def main():
    """Parse CLI options, run the selected benchmarks, and plot results."""
    cli = argparse.ArgumentParser(description="Benchmark concurrent chunk fetching from DANDI")
    cli.add_argument(
        "--dandiset",
        choices=list(DANDISETS.keys()),
        default=None,
        help="Run benchmarks for a specific dandiset only (default: run all)",
    )
    cli.add_argument(
        "--output", "-o",
        default=None,
        help="Save plot to file instead of displaying (e.g. benchmark.png)",
    )
    cli.add_argument(
        "--skip-lindi-file",
        action="store_true",
        help="Skip the full LindiH5pyFile pipeline benchmark (slow due to .lindi.json creation)",
    )
    opts = cli.parse_args()
    # No --dandiset means benchmark every configured dandiset.
    selected = [opts.dandiset] if opts.dandiset else list(DANDISETS.keys())
    all_results = []
    for key in selected:
        all_results += run_benchmark(key)
    if not opts.skip_lindi_file:
        all_results += [run_lindi_file_benchmark(key) for key in selected]
    banner = "=" * 60
    print(f"\n{banner}")
    print("All data equivalence checks passed!")
    print(banner)
    plot_results(all_results, output_path=opts.output)


if __name__ == "__main__":
    main()