@@ -248,7 +248,7 @@ def visualize_class_counts_by_dataset(self,
248248 num_datasets = len (ccs .index .get_level_values (0 ).unique ())
249249 fig , axs = plt .subplots (num_datasets ,
250250 1 ,
251- figsize = (self .FIGURE_WIDTH , 2 * num_datasets ))
251+ figsize = (self .FIGURE_WIDTH , 2.5 * num_datasets ))
252252
253253 if num_datasets == 1 :
254254 axs = [axs ]
@@ -312,7 +312,7 @@ def visualize_class_counts_by_dataset(self,
312312 loc = 'lower center' ,
313313 frameon = False ,
314314 ncol = len (legend_handles )/ 4 ,
315- bbox_to_anchor = (0.5 , - 0.15 ), # Adjust position: (x, y)
315+ bbox_to_anchor = (0.5 , - 0.125 ), # Adjust position: (x, y)
316316 handletextpad = 1 , # Increase padding between legend handle and text
317317 columnspacing = 8 # Increase spacing between columns
318318 )
@@ -447,8 +447,11 @@ def softmax_summary(self):
447447 df = self .melt ()
448448 softmax_cols = [col for col in df .columns if col in self .class_names ()]
449449
450- summary_df = pd .DataFrame (columns = ['mean true positive softmax' ,
451- 'mean false positive softmax' ])
450+ true_col_name = 'mean true positive softmax'
451+ false_col_name = 'mean false positive softmax'
452+
453+ summary_df = pd .DataFrame (
454+ columns = [true_col_name , false_col_name ])
452455
453456 # For each col in softmax_cols, calculate the mean softmax score for the true class
454457 # and the mean softmax score for the false classes
@@ -463,7 +466,7 @@ def softmax_summary(self):
463466 summary_df .loc [col ] = [mean_true_pos , mean_false_pos ]
464467
465468 # Sort the DataFrame by mean true positive softmax
466- summary_df = summary_df .sort_values (by = 'mean true positive softmax' ,
469+ summary_df = summary_df .sort_values (by = true_col_name ,
467470 ascending = False )
468471
469472 # Name index "Predicted class"
@@ -483,16 +486,34 @@ def softmax_summary(self):
483486 plt .axis ('off' ) # Hide axes
484487
485488 # Create table
486- table = plt .table (cellText = summary_df .values , colLabels = summary_df .columns , rowLabels = summary_df .index , loc = 'center' , cellLoc = 'center' )
489+ table = plt .table (cellText = summary_df .values ,
490+ colLabels = summary_df .columns ,
491+ rowLabels = summary_df .index ,
492+ loc = 'left' ,
493+ cellLoc = 'center' ,
494+ )
495+
496+ plt .title ('Predicted classes' , weight = 'bold' )
497+
498+ accent_color = "#d4cbb3"
499+ # Make font color dark gray
500+ font_color = '#222222'
501+
502+ # Change border colors
503+ for key , cell in table .get_celld ().items ():
504+ cell .set_edgecolor (accent_color ) # Set border color
505+ cell .set_linewidth (1 ) # Set border width
506+ cell .get_text ().set_color (font_color )
507+ # cell.PAD = 0.2
487508
488509 # Make every other row gray
489510 for i in range (0 , len (summary_df ), 2 ):
490- table .get_celld ()[(i + 1 , 0 )].set_facecolor ('#eeeeee' )
491- table .get_celld ()[(i + 1 , 1 )].set_facecolor ('#eeeeee' )
511+ table .get_celld ()[(i + 1 , 0 )].set_facecolor (accent_color )
512+ table .get_celld ()[(i + 1 , 1 )].set_facecolor (accent_color )
492513
493514 table .auto_set_font_size (False )
494515 table .set_fontsize (self .FIGURE_FONTSIZE )
495- table .scale (1. 2 , 1. 2 ) # Scale table size
516+ table .scale (2 , 2 ) # Scale table size
496517
497518 plt .savefig (f'{ self .output_dir } /softmax_summary.png' , bbox_inches = 'tight' )
498519
0 commit comments