Skip to content

Commit 0ce45af

Browse files
authored
Plot comparison results (#90)
1 parent 92286e1 commit 0ce45af

File tree

1 file changed

+60
-2
lines changed

1 file changed

+60
-2
lines changed

scripts/nvbench_compare.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ def format_percentage(percentage):
9999
return "%0.2f%%" % (percentage * 100.0)
100100

101101

102-
def compare_benches(ref_benches, cmp_benches, threshold):
102+
def compare_benches(ref_benches, cmp_benches, threshold, plot):
103+
if plot:
104+
import matplotlib.pyplot as plt
105+
import seaborn as sns
106+
107+
sns.set()
108+
103109
for cmp_bench in cmp_benches:
104110
ref_bench = find_matching_bench(cmp_bench, ref_benches)
105111
if not ref_bench:
@@ -135,6 +141,8 @@ def compare_benches(ref_benches, cmp_benches, threshold):
135141
for device_id in device_ids:
136142

137143
rows = []
144+
plot_data = {'cmp': {}, 'ref': {}, 'cmp_noise': {}, 'ref_noise': {}}
145+
138146
for cmp_state in cmp_states:
139147
cmp_state_name = cmp_state["name"]
140148
ref_state = next(filter(lambda st: st["name"] == cmp_state_name,
@@ -207,6 +215,27 @@ def extract_value(summary):
207215
else:
208216
min_noise = None # Noise is inf
209217

218+
if plot:
219+
axis_name = []
220+
axis_value = "--"
221+
for aid in range(len(axis_values)):
222+
if axis_values[aid]["name"] != plot:
223+
axis_name.append("{} = {}".format(axis_values[aid]["name"], axis_values[aid]["value"]))
224+
else:
225+
axis_value = float(axis_values[aid]["value"])
226+
axis_name = ', '.join(axis_name)
227+
228+
if axis_name not in plot_data['cmp']:
229+
plot_data['cmp'][axis_name] = {}
230+
plot_data['ref'][axis_name] = {}
231+
plot_data['cmp_noise'][axis_name] = {}
232+
plot_data['ref_noise'][axis_name] = {}
233+
234+
plot_data['cmp'][axis_name][axis_value] = cmp_time
235+
plot_data['ref'][axis_name][axis_value] = ref_time
236+
plot_data['cmp_noise'][axis_name][axis_value] = cmp_noise
237+
plot_data['ref_noise'][axis_name][axis_value] = ref_noise
238+
210239
global config_count
211240
global unknown_count
212241
global pass_count
@@ -252,12 +281,41 @@ def extract_value(summary):
252281

253282
print("")
254283

284+
if plot:
285+
plt.xscale("log")
286+
plt.yscale("log")
287+
plt.xlabel(plot)
288+
plt.ylabel("time [s]")
289+
plt.title(device["name"])
290+
291+
def plot_line(key, shape, label):
292+
x = [float(x) for x in plot_data[key][axis].keys()]
293+
y = list(plot_data[key][axis].values())
294+
295+
noise = list(plot_data[key + '_noise'][axis].values())
296+
297+
top = [y[i] + y[i] * noise[i] for i in range(len(x))]
298+
bottom = [y[i] - y[i] * noise[i] for i in range(len(x))]
299+
300+
p = plt.plot(x, y, shape, marker='o', label=label)
301+
plt.fill_between(x, bottom, top, color=p[0].get_color(), alpha=0.1)
302+
303+
304+
for axis in plot_data['cmp'].keys():
305+
plot_line('cmp', '-', axis)
306+
plot_line('ref', '--', axis + ' ref')
307+
308+
plt.legend()
309+
plt.show()
310+
255311

256312
def main():
257313
help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]"
258314
parser = argparse.ArgumentParser(prog='nvbench_compare', usage=help_text)
259315
parser.add_argument('--threshold-diff', type=float, dest='threshold', default=0.0,
260316
help='only show benchmarks where percentage diff is >= THRESHOLD')
317+
parser.add_argument('--plot-along', type=str, dest='plot', default=None,
318+
help='plot results')
261319

262320
args, files_or_dirs = parser.parse_known_args()
263321
print(files_or_dirs)
@@ -294,7 +352,7 @@ def main():
294352
print("Device sections do not match.")
295353
sys.exit(1)
296354

297-
compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold)
355+
compare_benches(ref_root["benchmarks"], cmp_root["benchmarks"], args.threshold, args.plot)
298356

299357
print("# Summary\n")
300358
print("- Total Matches: %d" % config_count)

0 commit comments

Comments
 (0)