@@ -335,17 +335,20 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
335335 # Define a GridSpec with unequal column widths
336336 # Column 1 is nx as wide as column 2
337337 width_ratio = 8
338- gs = gridspec .GridSpec (1 , 2 , width_ratios = [width_ratio , 1 ])
338+ gs = gridspec .GridSpec (1 , 3 , width_ratios = [width_ratio , 1 , 1 ])
339339 ax1 = fig .add_subplot (gs [0 , 0 ])
340340 ax2 = fig .add_subplot (gs [0 , 1 ])
341+ ax3 = fig .add_subplot (gs [0 , 2 ])
341342
342343 SET_SIZE_COL_NAME = "mean set size"
344+ MEAN_FP_COL_NAME = "mean false positive softmax"
343345
344346 group_by_col = self .MELTED_KNOWN_CLASS_COL
345347 df = self .melt ()
346348 grouped_df = df .groupby (group_by_col )
347349 pred_col_names = self .class_names ()
348350 set_size_col_names = [SET_SIZE_COL_NAME ]
351+ mean_fp_col_names = [MEAN_FP_COL_NAME ]
349352
350353 psets = self ._get_prediction_sets_by_softmax_threshold (
351354 min_softmax_threshold )
@@ -364,34 +367,58 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
364367
365368 mean_smx = []
366369 mean_set_size = []
370+ mean_fp_smx = []
367371
368372 for name , group in grouped_df :
369373 mean_smx_row = [name ]
370374 mean_set_size_row = [name ]
375+ mean_fp_smx_row = [name ]
371376
372377 for col in pred_col_names :
378+ if col != name :
379+ mean_fp_smx_row .append (group [col ].mean ())
373380 mean_smx_row .append (group [col ].mean ())
374381 mean_set_size_row .append (psets_df [name ])
375382
376383 mean_smx .append (mean_smx_row )
377384 mean_set_size .append (mean_set_size_row )
385+ mean_fp_smx .append (mean_fp_smx_row )
386+
387+ # Iterate over rows in mean_fp_smx
388+ new_mean_fp_smx = []
389+ for row in mean_fp_smx :
390+ # Pop the first element to get the class name
391+ class_name = row .pop (0 )
392+ # Remove zeroes from row
393+ row = [x for x in row if x != 0 ]
394+ # Get mean of the row
395+ mean = np .mean (row )
396+ new_mean_fp_smx .append ([class_name , mean ])
397+ mean_fp_smx = new_mean_fp_smx
378398
379399 col_names_1 = ['true_class_name' ] + pred_col_names
380400 col_names_2 = ['true_class_name' ] + set_size_col_names
401+ col_names_3 = ['true_class_name' ] + mean_fp_col_names
381402
382403 mean_smx_df = pd .DataFrame (mean_smx , columns = col_names_1 )
383404 mean_smx_df .set_index ('true_class_name' , inplace = True )
384405
385406 mean_set_size_df = pd .DataFrame (mean_set_size , columns = col_names_2 )
386407 mean_set_size_df .set_index ('true_class_name' , inplace = True )
387408
409+ mean_fp_smx_df = pd .DataFrame (mean_fp_smx , columns = col_names_3 )
410+ mean_fp_smx_df .set_index ('true_class_name' , inplace = True )
411+
388412 # Sort the rows and columns
389413 mean_smx_df .sort_index (axis = 0 , inplace = True ) # Sort rows
390414 mean_smx_df .sort_index (axis = 1 , inplace = True ) # Sort columns
391415
392416 mean_set_size_df .sort_index (axis = 0 , inplace = True ) # Sort rows
393417 mean_set_size_df .sort_index (axis = 1 , inplace = True ) # Sort columns
394418
419+ mean_fp_smx_df .sort_index (axis = 0 , inplace = True ) # Sort rows
420+ mean_fp_smx_df .sort_index (axis = 1 , inplace = True ) # Sort columns
421+
395422 # Remove any columns where all the rows are 0
396423 mean_smx_df = mean_smx_df .loc [:, (mean_smx_df != 0 ).any (axis = 0 )]
397424
@@ -411,31 +438,51 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
411438 weight = 'bold' , labelpad = labelpad )
412439
413440 # Create second heatmap for mean set size
414- hm2 = sns .heatmap (mean_set_size_df ,
441+ hm2 = sns .heatmap (mean_fp_smx_df ,
415442 ax = ax2 ,
443+ cmap = sns .light_palette ("forestgreen" , as_cmap = True ),
444+ annot = True ,
445+ fmt = '.2f' ,
446+ cbar = False )
447+
448+ hm2 .set_ylabel ("" )
449+
450+ # Remove y ticks
451+ hm2 .set_yticks ([])
452+
453+ # Remove x label
454+ hm2 .set_xlabel (r"$\overline{\mathrm{FP}}$" ,
455+ weight = 'bold' )
456+
457+ # Remove x ticks
458+ hm2 .set_xticks ([])
459+
460+ # Create third heatmap for mean set size
461+ hm3 = sns .heatmap (mean_set_size_df ,
462+ ax = ax3 ,
416463 cmap = sns .light_palette ("purple" , as_cmap = True ),
417464 annot = True ,
418465 fmt = '.2f' ,
419466 cbar = False )
420467
421468 # Rotate x labels
422- plt .setp (hm2 .get_xticklabels (), rotation = 90 )
469+ plt .setp (hm3 .get_xticklabels (), rotation = 90 )
423470
424471 # Set y label
425- hm2 .set_ylabel (f"MEAN PREDICTION SET SIZE, SOFTMAX > { min_softmax_threshold } " ,
472+ hm3 .set_ylabel (f"MEAN PREDICTION SET SIZE, SOFTMAX > { min_softmax_threshold } " ,
426473 weight = 'bold' , labelpad = labelpad )
427474
428475 # Position y label to the right of heatmap
429- hm2 .yaxis .set_label_position ("right" )
476+ hm3 .yaxis .set_label_position ("right" )
430477
431478 # Remove y ticks
432- hm2 .set_yticks ([])
479+ hm3 .set_yticks ([])
433480
434481 # Remove x label
435- hm2 .set_xlabel ('' )
482+ hm3 .set_xlabel ('' )
436483
437484 # Remove x ticks
438- hm2 .set_xticks ([])
485+ hm3 .set_xticks ([])
439486
440487 plt .tight_layout (w_pad = 0.1 ) # Control padding
441488
0 commit comments