Skip to content

Commit dde2bcb

Browse files
authored
Merge branch 'main' into add-old-version-util
2 parents 4d34284 + e625cf1 commit dde2bcb

File tree

8 files changed

+812
-123
lines changed

8 files changed

+812
-123
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,6 @@ fabric.properties
160160

161161
# Spyder
162162
.spyproject/*
163+
164+
# Figures
165+
*.pdf

src/nwb_benchmarks/benchmarks/time_download.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .params_remote_download import hdf5_params, lindi_remote_rfs_params, zarr_params
1313

1414

15-
class BaseDownloadDandiAPIBenchmark(BaseBenchmark):
15+
class BaseDandiAPIDownloadBenchmark(BaseBenchmark):
1616
"""
1717
Base class for timing the download of remote NWB files using the DANDI API.
1818
"""
@@ -27,7 +27,7 @@ def teardown(self, params: dict[str, str]):
2727
self.tmpdir.cleanup()
2828

2929

30-
class HDF5DownloadDandiAPIBenchmark(BaseDownloadDandiAPIBenchmark):
30+
class HDF5DandiAPIDownloadBenchmark(BaseDandiAPIDownloadBenchmark):
3131
"""
3232
Time the download of remote HDF5 NWB files using the DANDI API.
3333
"""
@@ -41,7 +41,7 @@ def time_download_hdf5_dandi_api(self, params: dict[str, str]):
4141
download(urls=params["https_url"], output_dir=self.tmpdir.name)
4242

4343

44-
class ZarrDownloadDandiAPIBenchmark(BaseDownloadDandiAPIBenchmark):
44+
class ZarrDandiAPIDownloadBenchmark(BaseDandiAPIDownloadBenchmark):
4545
"""
4646
Time the download of remote Zarr NWB files using the DANDI API.
4747
"""
@@ -55,7 +55,7 @@ def time_download_zarr_dandi_api(self, params: dict[str, str]):
5555
download(urls=params["https_url"], output_dir=self.tmpdir.name)
5656

5757

58-
class LindiDownloadDandiAPIBenchmark(BaseDownloadDandiAPIBenchmark):
58+
class LindiDandiAPIDownloadBenchmark(BaseDandiAPIDownloadBenchmark):
5959
"""
6060
Time the download of a remote LINDI JSON file.
6161
"""
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
"""Exposed imports to the `database` submodule."""

from ._models import Environment, Machine, Result, Results
from ._parquet import (
    concat_dataclasses_to_parquet,
    repackage_as_parquet,
)
from ._processing import BenchmarkDatabase
from ._visualization import BenchmarkVisualizer

# Explicit public API of the `database` submodule.
__all__ = [
    "Machine",
    "Result",
    "Results",
    "Environment",
    "BenchmarkDatabase",
    "BenchmarkVisualizer",
    "concat_dataclasses_to_parquet",
    "repackage_as_parquet",
]

src/nwb_benchmarks/database/_models.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pathlib
55
import re
66
import uuid
7+
from datetime import datetime
78

89
import packaging.version
910
import typing_extensions
@@ -13,7 +14,7 @@
1314
class Result:
1415
uuid: str
1516
version: str
16-
timestamp: str
17+
timestamp: datetime
1718
commit_hash: str
1819
environment_id: str
1920
machine_id: str
@@ -27,6 +28,43 @@ class Result:
2728
class Results:
2829
results: list[Result]
2930

@staticmethod
def normalize_time_and_network_results(benchmark_results) -> dict:
    """Convert benchmark results to a consistent dict format with list values."""

    def _with_network_metrics(raw: dict) -> dict:
        """Add additional network metrics derived from the raw counters."""
        enriched = raw.copy()
        packet_count = enriched["total_traffic_in_number_of_web_packets"]
        if packet_count != 0:
            enriched["mean_time_per_web_packet"] = enriched["total_transfer_time_in_seconds"] / packet_count
        else:
            # No packets observed; NaN marks the metric as undefined rather than zero.
            enriched["mean_time_per_web_packet"] = float("nan")
        return enriched

    # Dict input = network-style results; anything else is a bare timing value/series.
    if isinstance(benchmark_results, dict):
        value_dict = _with_network_metrics(benchmark_results)
    else:
        value_dict = {"time": benchmark_results}

    # Ensure every value is a list so downstream code can iterate uniformly.
    normalized = {}
    for key, value in value_dict.items():
        normalized[key] = value if isinstance(value, list) else [float(value)]
    return normalized
@staticmethod
def parse_parameter_case(s):
    """Parse a benchmark parameter-case string into a dict."""
    # Quote any slice(...) expression so ast.literal_eval can parse it safely.
    safe_form = re.sub(r"slice\([^)]+\)", r'"\g<0>"', s)
    parsed = ast.literal_eval(safe_form)

    if isinstance(parsed, dict):
        return parsed

    # Older benchmark results stored a bare sequence; coerce it to a dict.
    return {"https_url": parsed[0].strip("'")}
3068
@classmethod
3169
def safe_load_from_json(cls, file_path: pathlib.Path) -> typing_extensions.Self | None:
3270
with file_path.open(mode="r") as file_stream:
@@ -43,43 +81,22 @@ def safe_load_from_json(cls, file_path: pathlib.Path) -> typing_extensions.Self
4381
environment_id = data["environment_id"]
4482
machine_id = data["machine_id"]
4583

46-
def normalize_time_and_network_results(benchmark_results) -> dict:
47-
"""Convert benchmark results to a consistent dict format with list values."""
48-
if isinstance(benchmark_results, dict):
49-
value_dict = benchmark_results
50-
else:
51-
value_dict = dict(time=benchmark_results)
52-
53-
# Ensure all values are lists
54-
return {k: v if isinstance(v, list) else [float(v)] for k, v in value_dict.items()}
55-
56-
def parse_parameter_case(s):
57-
# replace any slice(...) with "slice(...)" for safe parsing
58-
modified_s = re.sub(r"slice\([^)]+\)", r'"\g<0>"', s)
59-
output = ast.literal_eval(modified_s)
60-
61-
# if the parsed string is not a dict (older benchmarks results), convert it to one
62-
if not isinstance(output, dict):
63-
output = {"https_url": output[0].strip("'")}
64-
65-
return output
66-
6784
results = [
6885
Result(
6986
uuid=str(uuid.uuid4()), # TODO: add this to each results file so it is persistent
7087
version=database_version,
71-
timestamp=timestamp,
88+
timestamp=datetime.strptime(timestamp, "%Y-%m-%d-%H-%M-%S"),
7289
commit_hash=commit_hash,
7390
environment_id=environment_id,
7491
machine_id=machine_id,
7592
benchmark_name=benchmark_name,
76-
parameter_case=parse_parameter_case(parameter_case),
93+
parameter_case=cls.parse_parameter_case(parameter_case),
7794
value=value,
7895
variable=variable_name,
7996
)
8097
for benchmark_name, parameter_cases in data["results"].items()
8198
for parameter_case, benchmark_results in parameter_cases.items()
82-
for variable_name, values in normalize_time_and_network_results(benchmark_results).items()
99+
for variable_name, values in cls.normalize_time_and_network_results(benchmark_results).items()
83100
for value in values
84101
]
85102

@@ -94,6 +111,7 @@ def to_dataframe(self) -> "polars.DataFrame":
94111
"commit_hash": [result.commit_hash for result in self.results],
95112
"environment_id": [result.environment_id for result in self.results],
96113
"machine_id": [result.machine_id for result in self.results],
114+
"timestamp": [result.timestamp for result in self.results],
97115
"benchmark_name": [result.benchmark_name for result in self.results],
98116
"parameter_case_name": [result.parameter_case.get("name") for result in self.results],
99117
"parameter_case_https_url": [result.parameter_case.get("https_url") for result in self.results],
@@ -187,7 +205,7 @@ def safe_load_from_json(cls, file_path: pathlib.Path) -> typing_extensions.Self
187205
packages = {
188206
package["name"]: f'{package["version"]} ({package["build"]})'
189207
for package in data[preamble]
190-
if len(package) == 3
208+
if len(package) >= 3
191209
}
192210

193211
if not any(packages):
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import dataclasses
2+
import pathlib
3+
4+
import packaging
5+
import polars
6+
7+
from ._models import Environment, Machine, Results
8+
9+
def concat_dataclasses_to_parquet(
    directory: pathlib.Path,
    output_directory: pathlib.Path,
    dataclass_name: str,
    dataclass: type,
    concat_how: str = "diagonal_relaxed",
    minimum_version: str = "1.0.0",
) -> None:
    """Generic function to process any data type (machines, environments, results).

    Args:
        directory (pathlib.Path): Path to the root directory containing data subdirectories.
        output_directory (pathlib.Path): Path to the output directory where the parquet file will be saved.
        dataclass_name (str): Name of the data class, used for input and output filenames.
        dataclass: The dataclass type to process (Machine, Environment, Results); must
            provide `safe_load_from_json` and `to_dataframe`.
        concat_how (str, optional): How to concatenate dataframes. Defaults to "diagonal_relaxed".
        minimum_version (str, optional): Minimum version of the database to include. Defaults to "1.0.0".
    Returns:
        None. Writes `<output_directory>/<dataclass_name>.parquet` when any data was loaded.
    """
    # `import packaging` alone does not guarantee the `version` submodule is loaded;
    # import it explicitly so `packaging.version.parse` is always available.
    import packaging.version

    # Parse the threshold once instead of once per row inside the filter lambda.
    minimum = packaging.version.parse(minimum_version)

    data_frames = []
    data_directory = directory / dataclass_name

    for file_path in data_directory.iterdir():
        # safe_load_from_json returns None for files it cannot parse; skip those.
        obj = dataclass.safe_load_from_json(file_path=file_path)

        if obj is None:
            continue

        data_frame = obj.to_dataframe()

        # filter by minimum version (before concatenation to avoid issues with different results structures)
        # TODO - should environment have a version?
        if "version" in data_frame.columns:
            data_frame = data_frame.filter(
                polars.col("version").map_elements(
                    lambda x: packaging.version.parse(x) >= minimum,
                    return_dtype=polars.Boolean,
                )
            )

        data_frames.append(data_frame)

    # Nothing to write if the directory was empty or every file failed to load.
    if data_frames:
        database = polars.concat(items=data_frames, how=concat_how)
        output_file_path = output_directory / f"{dataclass_name}.parquet"
        database.write_parquet(file=output_file_path)
def repackage_as_parquet(
    directory: pathlib.Path,
    output_directory: pathlib.Path,
    minimum_results_version: str = "1.0.0",
    minimum_machines_version: str = "1.0.0",
) -> None:
    """Repackage JSON results files as parquet databases for easier querying."""
    # (subdirectory name, dataclass, concat strategy, minimum version) per dataset.
    dataset_specs = (
        ("machines", Machine, "diagonal_relaxed", minimum_machines_version),
        ("environments", Environment, "diagonal", "1.0.0"),
        ("results", Results, "diagonal_relaxed", minimum_results_version),
    )

    for name, data_class, how, min_version in dataset_specs:
        concat_dataclasses_to_parquet(
            directory=directory,
            output_directory=output_directory,
            dataclass_name=name,
            dataclass=data_class,
            concat_how=how,
            minimum_version=min_version,
        )

0 commit comments

Comments (0)