@@ -99,7 +99,13 @@ def format_percentage(percentage):
9999 return "%0.2f%%" % (percentage * 100.0 )
100100
101101
102- def compare_benches (ref_benches , cmp_benches , threshold ):
102+ def compare_benches (ref_benches , cmp_benches , threshold , plot ):
103+ if plot :
104+ import matplotlib .pyplot as plt
105+ import seaborn as sns
106+
107+ sns .set ()
108+
103109 for cmp_bench in cmp_benches :
104110 ref_bench = find_matching_bench (cmp_bench , ref_benches )
105111 if not ref_bench :
@@ -135,6 +141,8 @@ def compare_benches(ref_benches, cmp_benches, threshold):
135141 for device_id in device_ids :
136142
137143 rows = []
144+ plot_data = {'cmp' : {}, 'ref' : {}, 'cmp_noise' : {}, 'ref_noise' : {}}
145+
138146 for cmp_state in cmp_states :
139147 cmp_state_name = cmp_state ["name" ]
140148 ref_state = next (filter (lambda st : st ["name" ] == cmp_state_name ,
@@ -207,6 +215,27 @@ def extract_value(summary):
207215 else :
208216 min_noise = None # Noise is inf
209217
218+ if plot :
219+ axis_name = []
220+ axis_value = "--"
221+ for aid in range (len (axis_values )):
222+ if axis_values [aid ]["name" ] != plot :
223+ axis_name .append ("{} = {}" .format (axis_values [aid ]["name" ], axis_values [aid ]["value" ]))
224+ else :
225+ axis_value = float (axis_values [aid ]["value" ])
226+ axis_name = ', ' .join (axis_name )
227+
228+ if axis_name not in plot_data ['cmp' ]:
229+ plot_data ['cmp' ][axis_name ] = {}
230+ plot_data ['ref' ][axis_name ] = {}
231+ plot_data ['cmp_noise' ][axis_name ] = {}
232+ plot_data ['ref_noise' ][axis_name ] = {}
233+
234+ plot_data ['cmp' ][axis_name ][axis_value ] = cmp_time
235+ plot_data ['ref' ][axis_name ][axis_value ] = ref_time
236+ plot_data ['cmp_noise' ][axis_name ][axis_value ] = cmp_noise
237+ plot_data ['ref_noise' ][axis_name ][axis_value ] = ref_noise
238+
210239 global config_count
211240 global unknown_count
212241 global pass_count
@@ -252,12 +281,41 @@ def extract_value(summary):
252281
253282 print ("" )
254283
284+ if plot :
285+ plt .xscale ("log" )
286+ plt .yscale ("log" )
287+ plt .xlabel (plot )
288+ plt .ylabel ("time [s]" )
289+ plt .title (device ["name" ])
290+
291+ def plot_line (key , shape , label ):
292+ x = [float (x ) for x in plot_data [key ][axis ].keys ()]
293+ y = list (plot_data [key ][axis ].values ())
294+
295+ noise = list (plot_data [key + '_noise' ][axis ].values ())
296+
297+ top = [y [i ] + y [i ] * noise [i ] for i in range (len (x ))]
298+ bottom = [y [i ] - y [i ] * noise [i ] for i in range (len (x ))]
299+
300+ p = plt .plot (x , y , shape , marker = 'o' , label = label )
301+ plt .fill_between (x , bottom , top , color = p [0 ].get_color (), alpha = 0.1 )
302+
303+
304+ for axis in plot_data ['cmp' ].keys ():
305+ plot_line ('cmp' , '-' , axis )
306+ plot_line ('ref' , '--' , axis + ' ref' )
307+
308+ plt .legend ()
309+ plt .show ()
310+
255311
256312def main ():
257313 help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
258314 parser = argparse .ArgumentParser (prog = 'nvbench_compare' , usage = help_text )
259315 parser .add_argument ('--threshold-diff' , type = float , dest = 'threshold' , default = 0.0 ,
260316 help = 'only show benchmarks where percentage diff is >= THRESHOLD' )
317+ parser .add_argument ('--plot-along' , type = str , dest = 'plot' , default = None ,
318+ help = 'plot results' )
261319
262320 args , files_or_dirs = parser .parse_known_args ()
263321 print (files_or_dirs )
@@ -294,7 +352,7 @@ def main():
294352 print ("Device sections do not match." )
295353 sys .exit (1 )
296354
297- compare_benches (ref_root ["benchmarks" ], cmp_root ["benchmarks" ], args .threshold )
355+ compare_benches (ref_root ["benchmarks" ], cmp_root ["benchmarks" ], args .threshold , args . plot )
298356
299357 print ("# Summary\n " )
300358 print ("- Total Matches: %d" % config_count )
0 commit comments