@@ -344,12 +344,14 @@ def visualize_class_counts_by_dataset(self,
344344 bbox_inches = 'tight' , format = 'pdf' )
345345
346346 def visualize_prediction_heatmap (self , min_softmax_threshold = 0.5 ):
347- fig = plt .figure (figsize = (self .FIGURE_WIDTH , 10 ))
347+ fig = plt .figure ()
348+ plt .rcParams .update ({'font.size' : 6 })
349+ # fig = plt.figure(figsize=(self.FIGURE_WIDTH, 10))
348350
349351 # Create two subplots, we will create two separate heatmaps
350352 # Define a GridSpec with unequal column widths
351353 # Column 1 is nx as wide as column 2
352- width_ratio = 8
354+ width_ratio = 13.5
353355 gs = gridspec .GridSpec (1 , 3 , width_ratios = [width_ratio , 1 , 1 ])
354356 ax1 = fig .add_subplot (gs [0 , 0 ])
355357 ax2 = fig .add_subplot (gs [0 , 1 ])
@@ -444,13 +446,15 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
444446 fmt = '.2f' ,
445447 cbar = False )
446448
447- labelpad = 20
448449 plt .setp (hm .get_yticklabels (), rotation = 0 )
449450
450451 hm .set_xlabel ('MEAN SOFTMAX SCORE' ,
451- weight = 'bold' , labelpad = labelpad )
452+ weight = 'bold' , labelpad = 10 )
452453 hm .set_ylabel ('TRUE CLASS' ,
453- weight = 'bold' , labelpad = labelpad )
454+ weight = 'bold' , labelpad = - 10 )
455+
456+ hm .yaxis .set_tick_params (pad = 0 )
457+ hm .xaxis .set_tick_params (pad = 0 )
454458
455459 # Create second heatmap for mean set size
456460 hm2 = sns .heatmap (mean_fp_smx_df ,
@@ -501,10 +505,12 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
501505 # Remove x ticks
502506 hm3 .set_xticks ([])
503507
504- plt .tight_layout (w_pad = 0.1 ) # Control padding
508+ plt .tight_layout (w_pad = 0 ) # Control padding
509+
510+ fig .set_size_inches (4 , 4 )
505511
506512 # Save the plot to a file
507- plt .savefig (f'{ self .output_dir } /prediction_heatmap.pdf' , bbox_inches = 'tight' , format = 'pdf' )
513+ plt .savefig (f'{ self .output_dir } /prediction_heatmap.pdf' , bbox_inches = 'tight' , format = 'pdf' , pad_inches = 0 )
508514
509515 def softmax_summary (self ):
510516 df = self .melt ()
@@ -581,8 +587,14 @@ def softmax_summary(self):
581587 plt .savefig (f'{ self .output_dir } /softmax_summary.pdf' , bbox_inches = 'tight' , format = 'pdf' )
582588
583589 def visualize_prediction_stripplot (self ,
584- custom_color_palette = None ):
585- plt .figure ()
590+ custom_color_palette = None ,
591+ legend_top_padding = 0.125 ):
592+ fig = plt .figure ()
593+
594+ # Set fontsize to 6
595+ plt .rcParams .update ({'font.size' : 6 })
596+
597+ # figsize=(self.FIGURE_WIDTH, 2.5 * num_datasets)
586598
587599 df = self .melt ()
588600 cols = [col for col in df .columns if col in self .class_names ()]
@@ -608,7 +620,7 @@ def visualize_prediction_stripplot(self,
608620
609621 # Increase the height of each row by adjusting the figure size
610622 num_classes = new_df ['True class' ].nunique ()
611- plt .figure (figsize = (self .FIGURE_WIDTH , num_classes * 1 )) # Adjust the height multiplier as needed
623+ # plt.figure(figsize=(self.FIGURE_WIDTH, num_classes * 1)) # Adjust the height multiplier as needed
612624
613625 ax = plt .gca ()
614626 # Add light gray background to every other row
@@ -628,7 +640,7 @@ def visualize_prediction_stripplot(self,
628640 dodge = True ,
629641 palette = self ._class_colors (
630642 custom_color_palette = custom_color_palette ),
631- size = 4 ,
643+ size = 2 ,
632644 ax = ax ,
633645 order = class_names )
634646
@@ -648,16 +660,30 @@ def visualize_prediction_stripplot(self,
648660 loc = 'lower center' ,
649661 frameon = False ,
650662 ncol = 4 ,
651- bbox_to_anchor = (0.5 , - 0.2 ), # Adjust position: (x, y)
663+ bbox_to_anchor = (0.35 , - 0.2 ),
652664 handletextpad = 1 , # Increase padding between legend handle and text
653- columnspacing = 8 # Increase spacing between columns
665+ columnspacing = 2 # Increase spacing between columns
654666 )
667+
668+ # legend = fig.legend(legend_handles,
669+ # class_names,
670+ # title=title,
671+ # loc='lower center',
672+ # frameon=False,
673+ # ncol=len(legend_handles)/4,
674+ # bbox_to_anchor=(0.5, 0-legend_top_padding), # Adjust position: (x, y)
675+ # handletextpad=1, # Increase padding between legend handle and text
676+ # columnspacing=2 # Increase spacing between columns
677+ # )
655678 font_properties = FontProperties (weight = 'bold' )
656679 legend .get_title ().set_font_properties (font_properties )
657680
681+ # plt.tight_layout()
682+
683+ fig .set_size_inches (3.75 , 6 )
684+
658685 # Save the plot to a file
659- plt .tight_layout ()
660- plt .subplots_adjust (hspace = 0.5 )
686+ # plt.subplots_adjust(hspace=0.5)
661687 plt .savefig (f'{ self .output_dir } /prediction_stripplot.pdf' ,
662688 bbox_inches = 'tight' , format = 'pdf' )
663689
@@ -678,20 +704,28 @@ def _get_prediction_sets_by_softmax_threshold(self, min_softmax_threshold=0.5):
678704
679705
680706 def visualize_model_sets (self , min_softmax_threshold = 0.5 , color = "black" ):
681- plt .figure ()
682- plt .figure (figsize = (self .FIGURE_WIDTH , 8 ))
707+ fig = plt .figure (figsize = (3.75 , 5 ))
708+ # fig.set_size_inches(3.75, 1)
709+
710+ # Set fontsize to 6
711+ plt .rcParams .update ({'font.size' : 6 })
712+
683713 upset_data = self ._get_prediction_sets_by_softmax_threshold (
684714 min_softmax_threshold )
685715 # Set a multi-index
686716 upset_data .set_index (upset_data .columns .tolist (), inplace = True )
687717
688- plot (upset_data ,
718+ plts = plot (upset_data ,
719+ fig = fig ,
720+ element_size = 9 ,
689721 sort_by = "cardinality" ,
690722 facecolor = color ,
691- show_counts = "%d" ,
692- show_percentages = "{:.0%}" ,
723+ show_counts = "%d " ,
724+ show_percentages = False ,
693725 orientation = 'horizontal' ,
694726 min_subset_size = 3 )
727+
728+
695729 plt .savefig (f'{ self .output_dir } /upset.pdf' , bbox_inches = 'tight' , format = 'pdf' )
696730
697731 def prediction_sets_df (self , prediction_sets , export_to_dir = None ):
0 commit comments