@@ -172,8 +172,9 @@ def _class_colors(self):
172172 colormap = plt .cm .get_cmap ('tab20' )
173173
174174 # Create a dictionary to map each class to a color
175+ classes = self .class_names ()
175176 class_to_color = {
176- cls : colormap (i ) for i , cls in enumerate (self . class_names () )}
177+ cls : colormap (i ) for i , cls in enumerate (classes )}
177178
178179 return class_to_color
179180
@@ -313,6 +314,54 @@ def visualize_prediction_heatmap(self):
313314 # Save the plot to a file
314315 plt .savefig (f'{ self .output_dir } /prediction_heatmap.png' , bbox_inches = 'tight' )
315316
317+ def visualize_prediction_stripplot (self ):
318+ plt .figure ()
319+
320+ # Set the font size for the entire figure
321+ plt .rcParams .update ({'font.size' : 12 })
322+
323+ df = self .melt ()
324+ cols = [col for col in df .columns if col in self .class_names ()]
325+
326+ # Create a new df new_df
327+ # Loop through rows in df. For each row, create a new row in new_df for each class in cols
328+ # For each row in new_df, add the softmax score for the corresponding class
329+
330+ new_df = pd .DataFrame (columns = ['True class' , 'Predicted class' , 'Softmax score' ])
331+
332+ rows = []
333+ for index , row in df .iterrows ():
334+ for col in cols :
335+ new_row = {
336+ 'True class' : row [self .MELTED_KNOWN_CLASS_COL ],
337+ 'Predicted class' : col ,
338+ 'Softmax score' : row [col ]
339+ }
340+ rows .append (new_row )
341+
342+ new_df = pd .concat ([new_df , pd .DataFrame (rows )], ignore_index = True )
343+
344+ # Increase the height of each row by adjusting the figure size
345+ num_classes = new_df ['True class' ].nunique ()
346+ plt .figure (figsize = (10 , num_classes * 1 )) # Adjust the height multiplier as needed
347+
348+ sns .stripplot (data = new_df ,
349+ x = 'Softmax score' ,
350+ y = 'True class' ,
351+ hue = 'Predicted class' ,
352+ jitter = 0.2 ,
353+ alpha = 0.5 ,
354+ dodge = True ,
355+ palette = self ._class_colors (),
356+ size = 5 )
357+
358+ # Save the plot to a file
359+ plt .tight_layout ()
360+ plt .subplots_adjust (hspace = 0.5 )
361+ plt .savefig (f'{ self .output_dir } /prediction_stripplot.png' , bbox_inches = 'tight' )
362+
363+
364+
316365 def prediction_sets_df (self , prediction_sets , export_to_dir = None ):
317366 # Make a copy of the DataFrame
318367 df = self .df .copy ()
0 commit comments