Skip to content

Commit 037e394

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent dbeb1f2 commit 037e394

File tree

3 files changed

+84
-93
lines changed

3 files changed

+84
-93
lines changed

src/nwb_benchmarks/database/_processing.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from nwb_benchmarks.database._parquet import repackage_as_parquet
99

10-
PACKAGES_OF_INTEREST = ['h5py', 'fsspec', 'lindi', 'remfile', 'zarr', 'hdmf-zarr', 'hdmf', 'pynwb']
10+
PACKAGES_OF_INTEREST = ["h5py", "fsspec", "lindi", "remfile", "zarr", "hdmf-zarr", "hdmf", "pynwb"]
11+
1112

1213
class BenchmarkDatabase:
1314
"""Handles database preprocessing and loading for NWB benchmarks."""
@@ -82,10 +83,7 @@ def _preprocess_results(self, df: pl.LazyFrame) -> pl.DataFrame:
8283
return (
8384
df
8485
# Filter for specific machine early to reduce data volume
85-
.filter(
86-
pl.col("machine_id") == self.machine_id if self.machine_id is not None
87-
else pl.lit(True)
88-
)
86+
.filter(pl.col("machine_id") == self.machine_id if self.machine_id is not None else pl.lit(True))
8987
# Extract benchmark name components
9088
.with_columns(
9189
[
@@ -122,8 +120,7 @@ def _preprocess_results(self, df: pl.LazyFrame) -> pl.DataFrame:
122120
.str.extract(r"slice\(0, (\d+),", group_index=1)
123121
.cast(pl.Int64)
124122
.alias("scaling_value"),
125-
)
126-
.with_columns((pl.col("scaling_value").rank(method="dense")).over("modality").alias("slice_number"))
123+
).with_columns((pl.col("scaling_value").rank(method="dense")).over("modality").alias("slice_number"))
127124
# Create unified cleaned benchmark name
128125
.with_columns(
129126
pl.when(pl.col("benchmark_name_label") == "ContinuousSliceBenchmark")
@@ -137,23 +134,21 @@ def _preprocess_results(self, df: pl.LazyFrame) -> pl.DataFrame:
137134

138135
def _preprocess_environments(self, df: pl.LazyFrame) -> pl.LazyFrame:
139136
"""Apply all preprocessing transformations to the environments dataframe."""
140-
141-
return (df
142-
# get only relevant package columns
143-
.select(["environment_id", *self.packages_of_interest])
144-
# remove build information
145-
.with_columns([
146-
pl.col(pkg).str.extract(r"^([\d.]+)", group_index=1)
147-
for pkg in self.packages_of_interest
148-
])
149-
# unpivot packages into long format for plotting
150-
.unpivot(
151-
index="environment_id",
152-
on=self.packages_of_interest,
153-
variable_name="package_name",
154-
value_name="package_version",)
155-
.filter(pl.col("package_version").is_not_null())
156-
)
137+
138+
return (
139+
df
140+
# get only relevant package columns
141+
.select(["environment_id", *self.packages_of_interest])
142+
# remove build information
143+
.with_columns([pl.col(pkg).str.extract(r"^([\d.]+)", group_index=1) for pkg in self.packages_of_interest])
144+
# unpivot packages into long format for plotting
145+
.unpivot(
146+
index="environment_id",
147+
on=self.packages_of_interest,
148+
variable_name="package_name",
149+
value_name="package_version",
150+
).filter(pl.col("package_version").is_not_null())
151+
)
157152

158153
def get_results(self) -> pl.LazyFrame:
159154
"""
@@ -179,15 +174,15 @@ def get_environments(self) -> pl.LazyFrame:
179174
lazy_df = pl.scan_parquet(self.db_directory / "environments.parquet")
180175
self._environments_df = self._preprocess_environments(lazy_df)
181176

182-
return self._environments_df
177+
return self._environments_df
183178

184179
def join_results_with_environments(self) -> pl.LazyFrame:
185180
"""Join streaming package versions with results using the environments table."""
186181
return self.get_results().join(
187-
self.get_environments(),
188-
on="environment_id",
189-
how="left",
190-
)
182+
self.get_environments(),
183+
on="environment_id",
184+
how="left",
185+
)
191186

192187
def filter_tests(self, benchmark_type: str) -> pl.LazyFrame:
193188
"""Filter benchmark tests."""

src/nwb_benchmarks/database/_visualization.py

Lines changed: 59 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
11
import textwrap
22
from pathlib import Path
33
from typing import Any, Dict, List, Optional
4-
from packaging import version
54

65
import matplotlib
76
import matplotlib.pyplot as plt
87
import pandas as pd
98
import polars as pl
109
import seaborn as sns
10+
from packaging import version
1111

1212
from nwb_benchmarks.database._processing import BenchmarkDatabase
1313

14-
DEFAULT_BENCHMARK_ORDER = ["hdf5 h5py remfile no cache",
15-
"hdf5 h5py fsspec https no cache",
16-
"hdf5 h5py fsspec s3 no cache",
17-
"hdf5 h5py remfile with cache",
18-
"hdf5 h5py fsspec https with cache",
19-
"hdf5 h5py fsspec s3 with cache",
20-
"hdf5 h5py ros3",
21-
"lindi h5py",
22-
"zarr https",
23-
"zarr https force no consolidated",
24-
"zarr s3",
25-
"zarr s3 force no consolidated"]
14+
DEFAULT_BENCHMARK_ORDER = [
15+
"hdf5 h5py remfile no cache",
16+
"hdf5 h5py fsspec https no cache",
17+
"hdf5 h5py fsspec s3 no cache",
18+
"hdf5 h5py remfile with cache",
19+
"hdf5 h5py fsspec https with cache",
20+
"hdf5 h5py fsspec s3 with cache",
21+
"hdf5 h5py ros3",
22+
"lindi h5py",
23+
"zarr https",
24+
"zarr https force no consolidated",
25+
"zarr s3",
26+
"zarr s3 force no consolidated",
27+
]
28+
2629

2730
class BenchmarkVisualizer:
2831
"""Handles plotting and visualization of benchmark results."""
2932

3033
file_open_order = DEFAULT_BENCHMARK_ORDER
31-
pynwb_read_order = [method.replace('h5py', 'pynwb').replace('zarr', 'zarr pynwb')
32-
for method in DEFAULT_BENCHMARK_ORDER]
33-
download_order = ["hdf5 dandi api", "zarr dandi api", "lindi dandi api"]
34+
pynwb_read_order = [
35+
method.replace("h5py", "pynwb").replace("zarr", "zarr pynwb") for method in DEFAULT_BENCHMARK_ORDER
36+
]
37+
download_order = ["hdf5 dandi api", "zarr dandi api", "lindi dandi api"]
3438
# TODO - where does lindi local json value go / what should it be called
3539

3640
def __init__(self, output_directory: Optional[Path] = None):
@@ -55,7 +59,7 @@ def _format_stat_text(mean: float, std: float, count: int) -> str:
5559
if mean > 1000 or mean < 0.01:
5660
return f" {mean:.2e} ± {std:.2e}, n={int(count)}"
5761
return f" {mean:.2f} ± {std:.2f}, n={int(count)}"
58-
62+
5963
def _add_mean_sem_annotations(self, value: str, group: str, order: List[str], **kwargs):
6064
"""Add mean ± SEM annotations to plot."""
6165
stats_df = kwargs.get("data").groupby(group)[value].agg(["mean", "std", "max", "count"])
@@ -77,7 +81,7 @@ def _add_mean_sem_annotations(self, value: str, group: str, order: List[str], **
7781
def _get_filename_prefix(self, network_tracking: bool) -> str:
7882
"""Get filename prefix based on network tracking."""
7983
return "network_tracking_" if network_tracking else ""
80-
84+
8185
def _create_plot_kwargs(
8286
self, df, group: str, order: List[str], filename: Path, kind: str = "box", **extra_kwargs
8387
) -> Dict[str, Any]:
@@ -96,21 +100,17 @@ def _create_plot_kwargs(
96100
plot_kwargs.update(extra_kwargs)
97101

98102
return plot_kwargs
99-
103+
100104
@staticmethod
101105
def _set_package_version_categorical(group):
102106
# TODO - idk if this is actually sorting or not
103107
sorted_versions = sorted(
104-
group['package_version'].unique(),
108+
group["package_version"].unique(),
105109
key=lambda v: version.parse(v),
106110
)
107-
group['package_version'] = pd.Categorical(
108-
group['package_version'],
109-
categories=sorted_versions,
110-
ordered=True
111-
)
111+
group["package_version"] = pd.Categorical(group["package_version"], categories=sorted_versions, ordered=True)
112112
return group
113-
113+
114114
def _create_heatmap_df(self, df: pl.DataFrame, group: str, metric_order: List[str]) -> pd.DataFrame:
115115
"""Prepare data for heatmap visualization."""
116116
return (
@@ -121,7 +121,11 @@ def _create_heatmap_df(self, df: pl.DataFrame, group: str, metric_order: List[st
121121
)
122122

123123
def plot_benchmark_heatmap(
124-
self, df: pl.LazyFrame, metric_order: List[str], group: str = "benchmark_name_clean", ax: Optional[plt.Axes] = None
124+
self,
125+
df: pl.LazyFrame,
126+
metric_order: List[str],
127+
group: str = "benchmark_name_clean",
128+
ax: Optional[plt.Axes] = None,
125129
) -> plt.Axes:
126130
"""Create heatmap visualization of benchmark results."""
127131
heatmap_df = self._create_heatmap_df(df, group, metric_order)
@@ -139,7 +143,7 @@ def plot_benchmark_heatmap(
139143
ax.text(j + 0.5, i + 0.5, " *", fontsize=20, ha="center", va="center", color="black", weight="bold")
140144

141145
return ax
142-
146+
143147
def plot_benchmark_dist(
144148
self,
145149
df: pd.DataFrame,
@@ -316,28 +320,26 @@ def plot_download_vs_stream_benchmarks(
316320
):
317321
"""Plot download vs stream benchmark comparison."""
318322
print("Plotting download vs stream benchmark comparison...")
319-
323+
320324
# combine read + slice times
321325
slice_df_combined = (
322326
db.filter_tests("time_remote_slicing")
323327
# join slice and file read data
324328
.join(
325329
# get average remote file read time
326330
db.filter_tests("time_remote_file_reading")
327-
.group_by(['modality', 'benchmark_name_clean'])
328-
.agg(pl.col('value').mean().alias('avg_file_open_time')),
331+
.group_by(["modality", "benchmark_name_clean"])
332+
.agg(pl.col("value").mean().alias("avg_file_open_time")),
329333
# match on benchmark_name_clean + parameter_case_name
330-
on=['modality', 'benchmark_name_clean'],
331-
how='left'
334+
on=["modality", "benchmark_name_clean"],
335+
how="left",
332336
)
333337
# add average file open time to each slice time
334338
# NOTE - should the file open + slice times be added per run or is average ok?
335-
.with_columns(
336-
(pl.col('value') + pl.col('avg_file_open_time')).alias('total_time')
337-
)
339+
.with_columns((pl.col("value") + pl.col("avg_file_open_time")).alias("total_time"))
338340
)
339341

340-
# TODO - combine download + local read times
342+
# TODO - combine download + local read times
341343
# download_df = db.filter_tests("time_download")
342344

343345
# plot time vs number of slices (TODO - plot with extrapolation)
@@ -350,13 +352,11 @@ def plot_download_vs_stream_benchmarks(
350352
"row": "variable" if network_tracking else "is_preloaded",
351353
"sharex": "row" if network_tracking else True,
352354
}
353-
self.plot_benchmark_slices_vs_time(y_value="value",
354-
filename=f"{base_filename}_vs_time.pdf",
355-
**plot_kwargs)
356-
self.plot_benchmark_slices_vs_time(y_value="total_time",
357-
filename=f"{base_filename}_vs_total_time.pdf",
358-
**plot_kwargs)
359-
355+
self.plot_benchmark_slices_vs_time(y_value="value", filename=f"{base_filename}_vs_time.pdf", **plot_kwargs)
356+
self.plot_benchmark_slices_vs_time(
357+
y_value="total_time", filename=f"{base_filename}_vs_total_time.pdf", **plot_kwargs
358+
)
359+
360360
def plot_method_rankings(self, db: BenchmarkDatabase):
361361
"""Create heatmap showing method rankings across benchmarks."""
362362
print("Plotting method rankings heatmap...")
@@ -365,17 +365,11 @@ def plot_method_rankings(self, db: BenchmarkDatabase):
365365
read_df = db.filter_tests("time_remote_file_reading").collect()
366366

367367
fig, axes = plt.subplots(3, 1, figsize=(8, 16))
368-
axes[0] = self.plot_benchmark_heatmap(
369-
df=read_df, metric_order=self.file_open_order, ax=axes[0]
370-
)
368+
axes[0] = self.plot_benchmark_heatmap(df=read_df, metric_order=self.file_open_order, ax=axes[0])
371369

372-
axes[1] = self.plot_benchmark_heatmap(
373-
df=read_df, metric_order=self.pynwb_read_order, ax=axes[1]
374-
)
370+
axes[1] = self.plot_benchmark_heatmap(df=read_df, metric_order=self.pynwb_read_order, ax=axes[1])
375371

376-
axes[2] = self.plot_benchmark_heatmap(
377-
df=slice_df, metric_order=self.pynwb_read_order, ax=axes[2]
378-
)
372+
axes[2] = self.plot_benchmark_heatmap(df=slice_df, metric_order=self.pynwb_read_order, ax=axes[2])
379373

380374
axes[0].set_title("Remote File Reading")
381375
axes[1].set_title("Remote File Reading - PyNWB")
@@ -385,26 +379,27 @@ def plot_method_rankings(self, db: BenchmarkDatabase):
385379
plt.savefig(self.output_directory / "method_rankings_heatmap.pdf", dpi=300)
386380
plt.close()
387381

388-
def plot_performance_across_versions(self,
389-
db: BenchmarkDatabase,
390-
order: List[str] = None,
391-
hue: str = "benchmark_name_clean",
392-
benchmark_type: str = "time_remote_file_reading"):
382+
def plot_performance_across_versions(
383+
self,
384+
db: BenchmarkDatabase,
385+
order: List[str] = None,
386+
hue: str = "benchmark_name_clean",
387+
benchmark_type: str = "time_remote_file_reading",
388+
):
393389
"""Plot performance changes over time for a given benchmark type."""
394390
print(f"Plotting performance over time")
395-
391+
396392
# get polars dataframe and filter
397393
df = db.join_results_with_environments()
398394
df = (
399-
df
400-
.filter(pl.col("benchmark_name_type") == benchmark_type)
395+
df.filter(pl.col("benchmark_name_type") == benchmark_type)
401396
.collect()
402397
.to_pandas()
403-
.groupby('package_name')
398+
.groupby("package_name")
404399
.apply(self._set_package_version_categorical, include_groups=False)
405400
.reset_index(level=0)
406401
)
407-
402+
408403
g = sns.catplot(
409404
data=df,
410405
x="package_version",
@@ -427,7 +422,7 @@ def plot_performance_across_versions(self,
427422

428423
def plot_all(self, db: BenchmarkDatabase):
429424
"""Generate all benchmark visualization plots."""
430-
425+
431426
# 1. WHICH LIBRARY SHOULD I USE TO STREAM DATA
432427
# Remote file reading / slicing benchmarks
433428
self.plot_read_benchmarks(db, suffix="_pynwb")

src/nwb_benchmarks/scripts/generate_figures.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
LBL_MAC_MACHINE_ID = "87fee773e425b4b1d3978fbf762d57effb0e8df8"
88

9+
910
def main():
1011
# Initialize database handler
1112
db = BenchmarkDatabase(machine_id=LBL_MAC_MACHINE_ID)

0 commit comments

Comments
 (0)