Skip to content

Commit 58d1f8c

Browse files
Also compare batch measurements in nvbench_compare.py
Fixes: #247
1 parent 935bb0b commit 58d1f8c

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

scripts/nvbench_compare.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def compare_benches(ref_benches, cmp_benches, threshold, plot):
137137
headers.append("Status")
138138
colalign.append("center")
139139

140+
added_batch_headers = False
141+
140142
for device_id in device_ids:
141143
rows = []
142144
plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}}
@@ -179,6 +181,12 @@ def lookup_summary(summaries, tag):
179181
ref_noise_summary = lookup_summary(
180182
ref_summaries, "nv/cold/time/gpu/stdev/relative"
181183
)
184+
cmp_batch_summary = lookup_summary(
185+
cmp_summaries, "nv/batch/time/gpu/mean"
186+
)
187+
ref_batch_summary = lookup_summary(
188+
ref_summaries, "nv/batch/time/gpu/mean"
189+
)
182190

183191
# TODO: Use other timings, too. Maybe multiple rows, with a
184192
# "Timing" column + values "CPU/GPU/Batch"?
@@ -192,6 +200,20 @@ def lookup_summary(summaries, tag):
192200
):
193201
continue
194202

203+
has_batch_data = cmp_batch_summary and ref_batch_summary
204+
if not added_batch_headers:
205+
headers.append("B Ref Time")
206+
colalign.append("right")
207+
headers.append("B Cmp Time")
208+
colalign.append("right")
209+
headers.append("B Diff")
210+
colalign.append("right")
211+
headers.append("B %Diff")
212+
colalign.append("right")
213+
headers.append("B Status")
214+
colalign.append("center")
215+
added_batch_headers = True
216+
195217
def extract_value(summary):
196218
summary_data = summary["data"]
197219
value_data = next(
@@ -204,6 +226,9 @@ def extract_value(summary):
204226
ref_time = extract_value(ref_time_summary)
205227
cmp_noise = extract_value(cmp_noise_summary)
206228
ref_noise = extract_value(ref_noise_summary)
229+
if has_batch_data:
230+
cmp_batch_time = extract_value(cmp_batch_summary)
231+
ref_batch_time = extract_value(ref_batch_summary)
207232

208233
# Convert string encoding to expected numerics:
209234
cmp_time = float(cmp_time)
@@ -212,6 +237,12 @@ def extract_value(summary):
212237
diff = cmp_time - ref_time
213238
frac_diff = diff / ref_time
214239

240+
if has_batch_data:
241+
cmp_batch_time = float(cmp_batch_time)
242+
ref_batch_time = float(ref_batch_time)
243+
diff_batch = cmp_batch_time - ref_batch_time
244+
frac_diff_batch = diff_batch / ref_batch_time
245+
215246
if ref_noise and cmp_noise:
216247
ref_noise = float(ref_noise)
217248
cmp_noise = float(cmp_noise)
@@ -269,6 +300,19 @@ def extract_value(summary):
269300
failure_count += 1
270301
status = Fore.RED + "SLOW" + Fore.RESET
271302

303+
if has_batch_data:
304+
if (
305+
abs(frac_diff_batch) <= 0.01
306+
): # TODO(bgruber): what value to use here?
307+
pass_count += 1
308+
batch_status = Fore.BLUE + "SAME" + Fore.RESET
309+
elif diff_batch < 0:
310+
failure_count += 1
311+
batch_status = Fore.GREEN + "FAST" + Fore.RESET
312+
else:
313+
failure_count += 1
314+
batch_status = Fore.RED + "SLOW" + Fore.RESET
315+
272316
if abs(frac_diff) >= threshold:
273317
row.append(format_duration(ref_time))
274318
row.append(format_percentage(ref_noise))
@@ -278,6 +322,13 @@ def extract_value(summary):
278322
row.append(format_percentage(frac_diff))
279323
row.append(status)
280324

325+
if has_batch_data:
326+
row.append(format_duration(ref_batch_time))
327+
row.append(format_duration(cmp_batch_time))
328+
row.append(format_duration(diff_batch))
329+
row.append(format_percentage(frac_diff_batch))
330+
row.append(batch_status)
331+
281332
rows.append(row)
282333

283334
if len(rows) == 0:

0 commit comments

Comments
 (0)