@@ -137,6 +137,8 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
137137 headers .append ("Status" )
138138 colalign .append ("center" )
139139
140+ added_batch_headers = False
141+
140142 for device_id in device_ids :
141143 rows = []
142144 plot_data = {"cmp" : {}, "ref" : {}, "cmp_noise" : {}, "ref_noise" : {}}
@@ -179,6 +181,12 @@ def lookup_summary(summaries, tag):
179181 ref_noise_summary = lookup_summary (
180182 ref_summaries , "nv/cold/time/gpu/stdev/relative"
181183 )
184+ cmp_batch_summary = lookup_summary (
185+ cmp_summaries , "nv/batch/time/gpu/mean"
186+ )
187+ ref_batch_summary = lookup_summary (
188+ ref_summaries , "nv/batch/time/gpu/mean"
189+ )
182190
183191 # TODO: Use other timings, too. Maybe multiple rows, with a
184192 # "Timing" column + values "CPU/GPU/Batch"?
@@ -192,6 +200,20 @@ def lookup_summary(summaries, tag):
192200 ):
193201 continue
194202
203+ has_batch_data = cmp_batch_summary and ref_batch_summary
204+ if not added_batch_headers :
205+ headers .append ("B Ref Time" )
206+ colalign .append ("right" )
207+ headers .append ("B Cmp Time" )
208+ colalign .append ("right" )
209+ headers .append ("B Diff" )
210+ colalign .append ("right" )
211+ headers .append ("B %Diff" )
212+ colalign .append ("right" )
213+ headers .append ("B Status" )
214+ colalign .append ("center" )
215+ added_batch_headers = True
216+
195217 def extract_value (summary ):
196218 summary_data = summary ["data" ]
197219 value_data = next (
@@ -204,6 +226,9 @@ def extract_value(summary):
204226 ref_time = extract_value (ref_time_summary )
205227 cmp_noise = extract_value (cmp_noise_summary )
206228 ref_noise = extract_value (ref_noise_summary )
229+ if has_batch_data :
230+ cmp_batch_time = extract_value (cmp_batch_summary )
231+ ref_batch_time = extract_value (ref_batch_summary )
207232
208233 # Convert string encoding to expected numerics:
209234 cmp_time = float (cmp_time )
@@ -212,6 +237,12 @@ def extract_value(summary):
212237 diff = cmp_time - ref_time
213238 frac_diff = diff / ref_time
214239
240+ if has_batch_data :
241+ cmp_batch_time = float (cmp_batch_time )
242+ ref_batch_time = float (ref_batch_time )
243+ diff_batch = cmp_batch_time - ref_batch_time
244+ frac_diff_batch = diff_batch / ref_batch_time
245+
215246 if ref_noise and cmp_noise :
216247 ref_noise = float (ref_noise )
217248 cmp_noise = float (cmp_noise )
@@ -269,6 +300,19 @@ def extract_value(summary):
269300 failure_count += 1
270301 status = Fore .RED + "SLOW" + Fore .RESET
271302
303+ if has_batch_data :
304+ if (
305+ abs (frac_diff_batch ) <= 0.01
306+ ): # TODO(bgruber): what value to use here?
307+ pass_count += 1
308+ batch_status = Fore .BLUE + "SAME" + Fore .RESET
309+ elif diff_batch < 0 :
310+ failure_count += 1
311+ batch_status = Fore .GREEN + "FAST" + Fore .RESET
312+ else :
313+ failure_count += 1
314+ batch_status = Fore .RED + "SLOW" + Fore .RESET
315+
272316 if abs (frac_diff ) >= threshold :
273317 row .append (format_duration (ref_time ))
274318 row .append (format_percentage (ref_noise ))
@@ -278,6 +322,13 @@ def extract_value(summary):
278322 row .append (format_percentage (frac_diff ))
279323 row .append (status )
280324
325+ if has_batch_data :
326+ row .append (format_duration (ref_batch_time ))
327+ row .append (format_duration (cmp_batch_time ))
328+ row .append (format_duration (diff_batch ))
329+ row .append (format_percentage (frac_diff_batch ))
330+ row .append (batch_status )
331+
281332 rows .append (row )
282333
283334 if len (rows ) == 0 :
0 commit comments