Skip to content

Commit befa5c9

Browse files
committed
Add mean false positive column to heatmap
1 parent 8df91c3 commit befa5c9

File tree

1 file changed

+55
-8
lines changed

1 file changed

+55
-8
lines changed

src/conformist/prediction_dataset.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,17 +335,20 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
335335
# Define a GridSpec with unequal column widths
336336
# Column 1 is nx as wide as column 2
337337
width_ratio = 8
338-
gs = gridspec.GridSpec(1, 2, width_ratios=[width_ratio, 1])
338+
gs = gridspec.GridSpec(1, 3, width_ratios=[width_ratio, 1, 1])
339339
ax1 = fig.add_subplot(gs[0, 0])
340340
ax2 = fig.add_subplot(gs[0, 1])
341+
ax3 = fig.add_subplot(gs[0, 2])
341342

342343
SET_SIZE_COL_NAME = "mean set size"
344+
MEAN_FP_COL_NAME = "mean false positive softmax"
343345

344346
group_by_col = self.MELTED_KNOWN_CLASS_COL
345347
df = self.melt()
346348
grouped_df = df.groupby(group_by_col)
347349
pred_col_names = self.class_names()
348350
set_size_col_names = [SET_SIZE_COL_NAME]
351+
mean_fp_col_names = [MEAN_FP_COL_NAME]
349352

350353
psets = self._get_prediction_sets_by_softmax_threshold(
351354
min_softmax_threshold)
@@ -364,34 +367,58 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
364367

365368
mean_smx = []
366369
mean_set_size = []
370+
mean_fp_smx = []
367371

368372
for name, group in grouped_df:
369373
mean_smx_row = [name]
370374
mean_set_size_row = [name]
375+
mean_fp_smx_row = [name]
371376

372377
for col in pred_col_names:
378+
if col != name:
379+
mean_fp_smx_row.append(group[col].mean())
373380
mean_smx_row.append(group[col].mean())
374381
mean_set_size_row.append(psets_df[name])
375382

376383
mean_smx.append(mean_smx_row)
377384
mean_set_size.append(mean_set_size_row)
385+
mean_fp_smx.append(mean_fp_smx_row)
386+
387+
# Iterate over rows in mean_fp_smx
388+
new_mean_fp_smx = []
389+
for row in mean_fp_smx:
390+
# Pop the first element to get the class name
391+
class_name = row.pop(0)
392+
# Remove zeroes from row
393+
row = [x for x in row if x != 0]
394+
# Get mean of the row
395+
mean = np.mean(row)
396+
new_mean_fp_smx.append([class_name, mean])
397+
mean_fp_smx = new_mean_fp_smx
378398

379399
col_names_1 = ['true_class_name'] + pred_col_names
380400
col_names_2 = ['true_class_name'] + set_size_col_names
401+
col_names_3 = ['true_class_name'] + mean_fp_col_names
381402

382403
mean_smx_df = pd.DataFrame(mean_smx, columns=col_names_1)
383404
mean_smx_df.set_index('true_class_name', inplace=True)
384405

385406
mean_set_size_df = pd.DataFrame(mean_set_size, columns=col_names_2)
386407
mean_set_size_df.set_index('true_class_name', inplace=True)
387408

409+
mean_fp_smx_df = pd.DataFrame(mean_fp_smx, columns=col_names_3)
410+
mean_fp_smx_df.set_index('true_class_name', inplace=True)
411+
388412
# Sort the rows and columns
389413
mean_smx_df.sort_index(axis=0, inplace=True) # Sort rows
390414
mean_smx_df.sort_index(axis=1, inplace=True) # Sort columns
391415

392416
mean_set_size_df.sort_index(axis=0, inplace=True) # Sort rows
393417
mean_set_size_df.sort_index(axis=1, inplace=True) # Sort columns
394418

419+
mean_fp_smx_df.sort_index(axis=0, inplace=True) # Sort rows
420+
mean_fp_smx_df.sort_index(axis=1, inplace=True) # Sort columns
421+
395422
# Remove any columns where all the rows are 0
396423
mean_smx_df = mean_smx_df.loc[:, (mean_smx_df != 0).any(axis=0)]
397424

@@ -411,31 +438,51 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
411438
weight='bold', labelpad=labelpad)
412439

413440
# Create second heatmap for mean set size
414-
hm2 = sns.heatmap(mean_set_size_df,
441+
hm2 = sns.heatmap(mean_fp_smx_df,
415442
ax=ax2,
443+
cmap=sns.light_palette("forestgreen", as_cmap=True),
444+
annot=True,
445+
fmt='.2f',
446+
cbar=False)
447+
448+
hm2.set_ylabel("")
449+
450+
# Remove y ticks
451+
hm2.set_yticks([])
452+
453+
# Remove x label
454+
hm2.set_xlabel(r"$\overline{\mathrm{FP}}$",
455+
weight='bold')
456+
457+
# Remove x ticks
458+
hm2.set_xticks([])
459+
460+
# Create third heatmap for mean set size
461+
hm3 = sns.heatmap(mean_set_size_df,
462+
ax=ax3,
416463
cmap=sns.light_palette("purple", as_cmap=True),
417464
annot=True,
418465
fmt='.2f',
419466
cbar=False)
420467

421468
# Rotate x labels
422-
plt.setp(hm2.get_xticklabels(), rotation=90)
469+
plt.setp(hm3.get_xticklabels(), rotation=90)
423470

424471
# Set y label
425-
hm2.set_ylabel(f"MEAN PREDICTION SET SIZE, SOFTMAX > {min_softmax_threshold}",
472+
hm3.set_ylabel(f"MEAN PREDICTION SET SIZE, SOFTMAX > {min_softmax_threshold}",
426473
weight='bold', labelpad=labelpad)
427474

428475
# Position y label to the right of heatmap
429-
hm2.yaxis.set_label_position("right")
476+
hm3.yaxis.set_label_position("right")
430477

431478
# Remove y ticks
432-
hm2.set_yticks([])
479+
hm3.set_yticks([])
433480

434481
# Remove x label
435-
hm2.set_xlabel('')
482+
hm3.set_xlabel('')
436483

437484
# Remove x ticks
438-
hm2.set_xticks([])
485+
hm3.set_xticks([])
439486

440487
plt.tight_layout(w_pad=0.1) # Control padding
441488

0 commit comments

Comments
 (0)