44import seaborn as sns
55from matplotlib .patches import Patch
66from matplotlib .font_manager import FontProperties
7+ import matplotlib .gridspec as gridspec
78from upsetplot import plot
89
910from .output_dir import OutputDir
@@ -328,51 +329,117 @@ def visualize_class_counts_by_dataset(self,
328329 plt .savefig (f'{ self .output_dir } /class_counts_by_dataset.png' ,
329330 bbox_inches = 'tight' )
330331
331- def visualize_prediction_heatmap (self ):
332- plt .figure (figsize = (self .FIGURE_WIDTH , 8 ))
332+ def visualize_prediction_heatmap (self , min_softmax_threshold = 0.5 ):
333+ fig = plt .figure (figsize = (self .FIGURE_WIDTH , 10 ))
334+
335+ # Create two subplots, one for each of the two heatmaps.
336+ # Define a GridSpec with unequal column widths:
337+ # column 1 is width_ratio times as wide as column 2
338+ width_ratio = 8
339+ gs = gridspec .GridSpec (1 , 2 , width_ratios = [width_ratio , 1 ])
340+ ax1 = fig .add_subplot (gs [0 , 0 ])
341+ ax2 = fig .add_subplot (gs [0 , 1 ])
342+
343+ SET_SIZE_COL_NAME = "mean set size"
333344
334345 group_by_col = self .MELTED_KNOWN_CLASS_COL
335346 df = self .melt ()
336-
337347 grouped_df = df .groupby (group_by_col )
338348 pred_col_names = self .class_names ()
349+ set_size_col_names = [SET_SIZE_COL_NAME ]
350+
351+ psets = self ._get_prediction_sets_by_softmax_threshold (
352+ min_softmax_threshold )
353+ # Index psets by the ID column of df
354+ psets .set_index (df [self .ID_COL ], inplace = True )
355+ # Add a set_size column holding, for each row, the number of True values in the other columns
356+ psets ['set_size' ] = psets .sum (axis = 1 )
357+ # Add melted known class to psets
358+ psets = psets .merge (df [[self .ID_COL , group_by_col ]],
359+ left_index = True ,
360+ right_on = self .ID_COL )
361+
362+ # Compute the mean set_size per known class (a Series indexed by class)
363+ psets_df = psets .groupby (
364+ self .MELTED_KNOWN_CLASS_COL )['set_size' ].mean ()
339365
340366 mean_smx = []
367+ mean_set_size = []
341368
342369 for name , group in grouped_df :
343- name = self .translate_class_name (name )
344370 mean_smx_row = [name ]
371+ mean_set_size_row = [name ]
345372
346373 for col in pred_col_names :
347374 mean_smx_row .append (group [col ].mean ())
375+ mean_set_size_row .append (psets_df [name ])
348376
349377 mean_smx .append (mean_smx_row )
378+ mean_set_size .append (mean_set_size_row )
350379
351- col_names = ['true_class_name' ] + self .class_names (translate = True )
380+ col_names_1 = ['true_class_name' ] + pred_col_names
381+ col_names_2 = ['true_class_name' ] + set_size_col_names
352382
353- mean_smx_df = pd .DataFrame (mean_smx , columns = col_names )
383+ mean_smx_df = pd .DataFrame (mean_smx , columns = col_names_1 )
354384 mean_smx_df .set_index ('true_class_name' , inplace = True )
355385
386+ mean_set_size_df = pd .DataFrame (mean_set_size , columns = col_names_2 )
387+ mean_set_size_df .set_index ('true_class_name' , inplace = True )
388+
356389 # Sort the rows and columns
357390 mean_smx_df .sort_index (axis = 0 , inplace = True ) # Sort rows
358391 mean_smx_df .sort_index (axis = 1 , inplace = True ) # Sort columns
359392
393+ mean_set_size_df .sort_index (axis = 0 , inplace = True ) # Sort rows
394+ mean_set_size_df .sort_index (axis = 1 , inplace = True ) # Sort columns
395+
360396 # Remove any columns where all the rows are 0
361397 mean_smx_df = mean_smx_df .loc [:, (mean_smx_df != 0 ).any (axis = 0 )]
362398
363399 hm = sns .heatmap (mean_smx_df ,
400+ ax = ax1 ,
364401 cmap = "coolwarm" ,
365402 annot = True ,
366- fmt = '.2f' )
403+ fmt = '.2f' ,
404+ cbar = False )
367405
368406 labelpad = 20
369407 plt .setp (hm .get_yticklabels (), rotation = 0 )
370408
371- hm .set_xlabel ('MEAN PROBABILITY SCORE' ,
409+ hm .set_xlabel ('MEAN SOFTMAX SCORE' ,
372410 weight = 'bold' , labelpad = labelpad )
373411 hm .set_ylabel ('TRUE CLASS' ,
374412 weight = 'bold' , labelpad = labelpad )
375413
414+ # Create second heatmap for mean set size
415+ hm2 = sns .heatmap (mean_set_size_df ,
416+ ax = ax2 ,
417+ cmap = sns .light_palette ("purple" , as_cmap = True ),
418+ annot = True ,
419+ fmt = '.2f' ,
420+ cbar = False )
421+
422+ # Rotate x labels
423+ plt .setp (hm2 .get_xticklabels (), rotation = 90 )
424+
425+ # Set y label
426+ hm2 .set_ylabel (f"MEAN PREDICTION SET SIZE, SOFTMAX > { min_softmax_threshold } " ,
427+ weight = 'bold' , labelpad = labelpad )
428+
429+ # Position y label to the right of heatmap
430+ hm2 .yaxis .set_label_position ("right" )
431+
432+ # Remove y ticks
433+ hm2 .set_yticks ([])
434+
435+ # Remove x label
436+ hm2 .set_xlabel ('' )
437+
438+ # Remove x ticks
439+ hm2 .set_xticks ([])
440+
441+ plt .tight_layout (w_pad = 0.1 ) # Control padding
442+
376443 # Save the plot to a file
377444 plt .savefig (f'{ self .output_dir } /prediction_heatmap.png' , bbox_inches = 'tight' )
378445
@@ -507,10 +574,7 @@ def visualize_prediction_stripplot(self,
507574 plt .savefig (f'{ self .output_dir } /prediction_stripplot.png' ,
508575 bbox_inches = 'tight' )
509576
510- def visualize_model_sets (self , min_softmax_threshold = 0.5 , color = "black" ):
511- plt .figure ()
512- plt .figure (figsize = (self .FIGURE_WIDTH , 8 ))
513-
577+ def _get_prediction_sets_by_softmax_threshold (self , min_softmax_threshold = 0.5 ):
514578 df = self .melt ()
515579 cols = [col for col in df .columns if col in self .class_names ()]
516580
@@ -523,7 +587,14 @@ def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
523587 new_row [col ] = (row [col ] >= min_softmax_threshold )
524588 rows .append (new_row )
525589
526- upset_data = pd .concat ([new_df , pd .DataFrame (rows )], ignore_index = True )
590+ return pd .concat ([new_df , pd .DataFrame (rows )], ignore_index = True )
591+
592+
593+ def visualize_model_sets (self , min_softmax_threshold = 0.5 , color = "black" ):
594+ plt .figure ()
595+ plt .figure (figsize = (self .FIGURE_WIDTH , 8 ))
596+ upset_data = self ._get_prediction_sets_by_softmax_threshold (
597+ min_softmax_threshold )
527598 # Set a multi-index
528599 upset_data .set_index (upset_data .columns .tolist (), inplace = True )
529600
0 commit comments