Skip to content

Commit 00e12ce

Browse files
committed
Strip plot with softmax scores of all predictions
1 parent dbbf856 commit 00e12ce

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

src/conformist/prediction_dataset.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,9 @@ def _class_colors(self):
172172
colormap = plt.cm.get_cmap('tab20')
173173

174174
# Create a dictionary to map each class to a color
175+
classes = self.class_names()
175176
class_to_color = {
176-
cls: colormap(i) for i, cls in enumerate(self.class_names())}
177+
cls: colormap(i) for i, cls in enumerate(classes)}
177178

178179
return class_to_color
179180

@@ -313,6 +314,54 @@ def visualize_prediction_heatmap(self):
313314
# Save the plot to a file
314315
plt.savefig(f'{self.output_dir}/prediction_heatmap.png', bbox_inches='tight')
315316

317+
def visualize_prediction_stripplot(self):
318+
plt.figure()
319+
320+
# Set the font size for the entire figure
321+
plt.rcParams.update({'font.size': 12})
322+
323+
df = self.melt()
324+
cols = [col for col in df.columns if col in self.class_names()]
325+
326+
# Create a new df new_df
327+
# Loop through rows in df. For each row, create a new row in new_df for each class in cols
328+
# For each row in new_df, add the softmax score for the corresponding class
329+
330+
new_df = pd.DataFrame(columns=['True class', 'Predicted class', 'Softmax score'])
331+
332+
rows = []
333+
for index, row in df.iterrows():
334+
for col in cols:
335+
new_row = {
336+
'True class': row[self.MELTED_KNOWN_CLASS_COL],
337+
'Predicted class': col,
338+
'Softmax score': row[col]
339+
}
340+
rows.append(new_row)
341+
342+
new_df = pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)
343+
344+
# Increase the height of each row by adjusting the figure size
345+
num_classes = new_df['True class'].nunique()
346+
plt.figure(figsize=(10, num_classes * 1)) # Adjust the height multiplier as needed
347+
348+
sns.stripplot(data=new_df,
349+
x='Softmax score',
350+
y='True class',
351+
hue='Predicted class',
352+
jitter=0.2,
353+
alpha=0.5,
354+
dodge=True,
355+
palette=self._class_colors(),
356+
size=5)
357+
358+
# Save the plot to a file
359+
plt.tight_layout()
360+
plt.subplots_adjust(hspace=0.5)
361+
plt.savefig(f'{self.output_dir}/prediction_stripplot.png', bbox_inches='tight')
362+
363+
364+
316365
def prediction_sets_df(self, prediction_sets, export_to_dir=None):
317366
# Make a copy of the DataFrame
318367
df = self.df.copy()

0 commit comments

Comments
 (0)