11import textwrap
22from pathlib import Path
33from typing import Any , Dict , List , Optional
4- from packaging import version
54
65import matplotlib
76import matplotlib .pyplot as plt
87import pandas as pd
98import polars as pl
109import seaborn as sns
10+ from packaging import version
1111
1212from nwb_benchmarks .database ._processing import BenchmarkDatabase
1313
# Canonical display order for streaming-method labels, shared by all plots so
# that figures produced by different plotting methods are directly comparable.
DEFAULT_BENCHMARK_ORDER = [
    "hdf5 h5py remfile no cache",
    "hdf5 h5py fsspec https no cache",
    "hdf5 h5py fsspec s3 no cache",
    "hdf5 h5py remfile with cache",
    "hdf5 h5py fsspec https with cache",
    "hdf5 h5py fsspec s3 with cache",
    "hdf5 h5py ros3",
    "lindi h5py",
    "zarr https",
    "zarr https force no consolidated",
    "zarr s3",
    "zarr s3 force no consolidated",
]
2629
class BenchmarkVisualizer:
    """Handles plotting and visualization of benchmark results."""

    # Ordering for raw file-open benchmarks (h5py / zarr level access).
    file_open_order = DEFAULT_BENCHMARK_ORDER
    # Same methods relabeled for the pynwb-level read benchmarks: "h5py" becomes
    # "pynwb" and "zarr" becomes "zarr pynwb" in each label.
    pynwb_read_order = [
        method.replace("h5py", "pynwb").replace("zarr", "zarr pynwb") for method in DEFAULT_BENCHMARK_ORDER
    ]
    # Ordering for full-file download benchmarks.
    download_order = ["hdf5 dandi api", "zarr dandi api", "lindi dandi api"]
    # TODO - where does lindi local json value go / what should it be called
3640 def __init__ (self , output_directory : Optional [Path ] = None ):
@@ -55,7 +59,7 @@ def _format_stat_text(mean: float, std: float, count: int) -> str:
5559 if mean > 1000 or mean < 0.01 :
5660 return f" { mean :.2e} ± { std :.2e} , n={ int (count )} "
5761 return f" { mean :.2f} ± { std :.2f} , n={ int (count )} "
58-
62+
5963 def _add_mean_sem_annotations (self , value : str , group : str , order : List [str ], ** kwargs ):
6064 """Add mean ± SEM annotations to plot."""
6165 stats_df = kwargs .get ("data" ).groupby (group )[value ].agg (["mean" , "std" , "max" , "count" ])
@@ -77,7 +81,7 @@ def _add_mean_sem_annotations(self, value: str, group: str, order: List[str], **
7781 def _get_filename_prefix (self , network_tracking : bool ) -> str :
7882 """Get filename prefix based on network tracking."""
7983 return "network_tracking_" if network_tracking else ""
80-
84+
8185 def _create_plot_kwargs (
8286 self , df , group : str , order : List [str ], filename : Path , kind : str = "box" , ** extra_kwargs
8387 ) -> Dict [str , Any ]:
@@ -96,21 +100,17 @@ def _create_plot_kwargs(
96100 plot_kwargs .update (extra_kwargs )
97101
98102 return plot_kwargs
99-
103+
100104 @staticmethod
101105 def _set_package_version_categorical (group ):
102106 # TODO - idk if this is actually sorting or not
103107 sorted_versions = sorted (
104- group [' package_version' ].unique (),
108+ group [" package_version" ].unique (),
105109 key = lambda v : version .parse (v ),
106110 )
107- group ['package_version' ] = pd .Categorical (
108- group ['package_version' ],
109- categories = sorted_versions ,
110- ordered = True
111- )
111+ group ["package_version" ] = pd .Categorical (group ["package_version" ], categories = sorted_versions , ordered = True )
112112 return group
113-
113+
114114 def _create_heatmap_df (self , df : pl .DataFrame , group : str , metric_order : List [str ]) -> pd .DataFrame :
115115 """Prepare data for heatmap visualization."""
116116 return (
@@ -121,7 +121,11 @@ def _create_heatmap_df(self, df: pl.DataFrame, group: str, metric_order: List[st
121121 )
122122
123123 def plot_benchmark_heatmap (
124- self , df : pl .LazyFrame , metric_order : List [str ], group : str = "benchmark_name_clean" , ax : Optional [plt .Axes ] = None
124+ self ,
125+ df : pl .LazyFrame ,
126+ metric_order : List [str ],
127+ group : str = "benchmark_name_clean" ,
128+ ax : Optional [plt .Axes ] = None ,
125129 ) -> plt .Axes :
126130 """Create heatmap visualization of benchmark results."""
127131 heatmap_df = self ._create_heatmap_df (df , group , metric_order )
@@ -139,7 +143,7 @@ def plot_benchmark_heatmap(
139143 ax .text (j + 0.5 , i + 0.5 , " *" , fontsize = 20 , ha = "center" , va = "center" , color = "black" , weight = "bold" )
140144
141145 return ax
142-
146+
143147 def plot_benchmark_dist (
144148 self ,
145149 df : pd .DataFrame ,
@@ -316,28 +320,26 @@ def plot_download_vs_stream_benchmarks(
316320 ):
317321 """Plot download vs stream benchmark comparison."""
318322 print ("Plotting download vs stream benchmark comparison..." )
319-
323+
320324 # combine read + slice times
321325 slice_df_combined = (
322326 db .filter_tests ("time_remote_slicing" )
323327 # join slice and file read data
324328 .join (
325329 # get average remote file read time
326330 db .filter_tests ("time_remote_file_reading" )
327- .group_by ([' modality' , ' benchmark_name_clean' ])
328- .agg (pl .col (' value' ).mean ().alias (' avg_file_open_time' )),
331+ .group_by ([" modality" , " benchmark_name_clean" ])
332+ .agg (pl .col (" value" ).mean ().alias (" avg_file_open_time" )),
329333 # match on benchmark_name_clean + parameter_case_name
330- on = [' modality' , ' benchmark_name_clean' ],
331- how = ' left'
334+ on = [" modality" , " benchmark_name_clean" ],
335+ how = " left" ,
332336 )
333337 # add average file open time to each slice time
334338 # NOTE - should the file open + slice times be added per run or is average ok?
335- .with_columns (
336- (pl .col ('value' ) + pl .col ('avg_file_open_time' )).alias ('total_time' )
337- )
339+ .with_columns ((pl .col ("value" ) + pl .col ("avg_file_open_time" )).alias ("total_time" ))
338340 )
339341
340- # TODO - combine download + local read times
342+ # TODO - combine download + local read times
341343 # download_df = db.filter_tests("time_download")
342344
343345 # plot time vs number of slices (TODO - plot with extrapolation)
@@ -350,13 +352,11 @@ def plot_download_vs_stream_benchmarks(
350352 "row" : "variable" if network_tracking else "is_preloaded" ,
351353 "sharex" : "row" if network_tracking else True ,
352354 }
353- self .plot_benchmark_slices_vs_time (y_value = "value" ,
354- filename = f"{ base_filename } _vs_time.pdf" ,
355- ** plot_kwargs )
356- self .plot_benchmark_slices_vs_time (y_value = "total_time" ,
357- filename = f"{ base_filename } _vs_total_time.pdf" ,
358- ** plot_kwargs )
359-
355+ self .plot_benchmark_slices_vs_time (y_value = "value" , filename = f"{ base_filename } _vs_time.pdf" , ** plot_kwargs )
356+ self .plot_benchmark_slices_vs_time (
357+ y_value = "total_time" , filename = f"{ base_filename } _vs_total_time.pdf" , ** plot_kwargs
358+ )
359+
360360 def plot_method_rankings (self , db : BenchmarkDatabase ):
361361 """Create heatmap showing method rankings across benchmarks."""
362362 print ("Plotting method rankings heatmap..." )
@@ -365,17 +365,11 @@ def plot_method_rankings(self, db: BenchmarkDatabase):
365365 read_df = db .filter_tests ("time_remote_file_reading" ).collect ()
366366
367367 fig , axes = plt .subplots (3 , 1 , figsize = (8 , 16 ))
368- axes [0 ] = self .plot_benchmark_heatmap (
369- df = read_df , metric_order = self .file_open_order , ax = axes [0 ]
370- )
368+ axes [0 ] = self .plot_benchmark_heatmap (df = read_df , metric_order = self .file_open_order , ax = axes [0 ])
371369
372- axes [1 ] = self .plot_benchmark_heatmap (
373- df = read_df , metric_order = self .pynwb_read_order , ax = axes [1 ]
374- )
370+ axes [1 ] = self .plot_benchmark_heatmap (df = read_df , metric_order = self .pynwb_read_order , ax = axes [1 ])
375371
376- axes [2 ] = self .plot_benchmark_heatmap (
377- df = slice_df , metric_order = self .pynwb_read_order , ax = axes [2 ]
378- )
372+ axes [2 ] = self .plot_benchmark_heatmap (df = slice_df , metric_order = self .pynwb_read_order , ax = axes [2 ])
379373
380374 axes [0 ].set_title ("Remote File Reading" )
381375 axes [1 ].set_title ("Remote File Reading - PyNWB" )
@@ -385,26 +379,27 @@ def plot_method_rankings(self, db: BenchmarkDatabase):
385379 plt .savefig (self .output_directory / "method_rankings_heatmap.pdf" , dpi = 300 )
386380 plt .close ()
387381
388- def plot_performance_across_versions (self ,
389- db : BenchmarkDatabase ,
390- order : List [str ] = None ,
391- hue : str = "benchmark_name_clean" ,
392- benchmark_type : str = "time_remote_file_reading" ):
382+ def plot_performance_across_versions (
383+ self ,
384+ db : BenchmarkDatabase ,
385+ order : List [str ] = None ,
386+ hue : str = "benchmark_name_clean" ,
387+ benchmark_type : str = "time_remote_file_reading" ,
388+ ):
393389 """Plot performance changes over time for a given benchmark type."""
394390 print (f"Plotting performance over time" )
395-
391+
396392 # get polars dataframe and filter
397393 df = db .join_results_with_environments ()
398394 df = (
399- df
400- .filter (pl .col ("benchmark_name_type" ) == benchmark_type )
395+ df .filter (pl .col ("benchmark_name_type" ) == benchmark_type )
401396 .collect ()
402397 .to_pandas ()
403- .groupby (' package_name' )
398+ .groupby (" package_name" )
404399 .apply (self ._set_package_version_categorical , include_groups = False )
405400 .reset_index (level = 0 )
406401 )
407-
402+
408403 g = sns .catplot (
409404 data = df ,
410405 x = "package_version" ,
@@ -427,7 +422,7 @@ def plot_performance_across_versions(self,
427422
428423 def plot_all (self , db : BenchmarkDatabase ):
429424 """Generate all benchmark visualization plots."""
430-
425+
431426 # 1. WHICH LIBRARY SHOULD I USE TO STREAM DATA
432427 # Remote file reading / slicing benchmarks
433428 self .plot_read_benchmarks (db , suffix = "_pynwb" )