Skip to content

Commit 10512ef

Browse files
committed
Add mean set size to heatmap
1 parent 1c9a269 commit 10512ef

File tree

1 file changed

+84
-13
lines changed

1 file changed

+84
-13
lines changed

src/conformist/prediction_dataset.py

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import seaborn as sns
55
from matplotlib.patches import Patch
66
from matplotlib.font_manager import FontProperties
7+
import matplotlib.gridspec as gridspec
78
from upsetplot import plot
89

910
from .output_dir import OutputDir
@@ -328,51 +329,117 @@ def visualize_class_counts_by_dataset(self,
328329
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',
329330
bbox_inches='tight')
330331

331-
def visualize_prediction_heatmap(self):
332-
plt.figure(figsize=(self.FIGURE_WIDTH, 8))
332+
def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
333+
fig = plt.figure(figsize=(self.FIGURE_WIDTH, 10))
334+
335+
# Create two subplots, we will create two separate heatmaps
336+
# Define a GridSpec with unequal column widths
337+
# Column 1 is nx as wide as column 2
338+
width_ratio = 8
339+
gs = gridspec.GridSpec(1, 2, width_ratios=[width_ratio, 1])
340+
ax1 = fig.add_subplot(gs[0, 0])
341+
ax2 = fig.add_subplot(gs[0, 1])
342+
343+
SET_SIZE_COL_NAME = "mean set size"
333344

334345
group_by_col = self.MELTED_KNOWN_CLASS_COL
335346
df = self.melt()
336-
337347
grouped_df = df.groupby(group_by_col)
338348
pred_col_names = self.class_names()
349+
set_size_col_names = [SET_SIZE_COL_NAME]
350+
351+
psets = self._get_prediction_sets_by_softmax_threshold(
352+
min_softmax_threshold)
353+
# Index psets by df.id
354+
psets.set_index(df[self.ID_COL], inplace=True)
355+
# Set new column set_size that contains the number of True values in other cols
356+
psets['set_size'] = psets.sum(axis=1)
357+
# Add melted known class to psets
358+
psets = psets.merge(df[[self.ID_COL, group_by_col]],
359+
left_index=True,
360+
right_on=self.ID_COL)
361+
362+
# Create new df with columns class and mean_set_size
363+
psets_df = psets.groupby(
364+
self.MELTED_KNOWN_CLASS_COL)['set_size'].mean()
339365

340366
mean_smx = []
367+
mean_set_size = []
341368

342369
for name, group in grouped_df:
343-
name = self.translate_class_name(name)
344370
mean_smx_row = [name]
371+
mean_set_size_row = [name]
345372

346373
for col in pred_col_names:
347374
mean_smx_row.append(group[col].mean())
375+
mean_set_size_row.append(psets_df[name])
348376

349377
mean_smx.append(mean_smx_row)
378+
mean_set_size.append(mean_set_size_row)
350379

351-
col_names = ['true_class_name'] + self.class_names(translate=True)
380+
col_names_1 = ['true_class_name'] + pred_col_names
381+
col_names_2 = ['true_class_name'] + set_size_col_names
352382

353-
mean_smx_df = pd.DataFrame(mean_smx, columns=col_names)
383+
mean_smx_df = pd.DataFrame(mean_smx, columns=col_names_1)
354384
mean_smx_df.set_index('true_class_name', inplace=True)
355385

386+
mean_set_size_df = pd.DataFrame(mean_set_size, columns=col_names_2)
387+
mean_set_size_df.set_index('true_class_name', inplace=True)
388+
356389
# Sort the rows and columns
357390
mean_smx_df.sort_index(axis=0, inplace=True) # Sort rows
358391
mean_smx_df.sort_index(axis=1, inplace=True) # Sort columns
359392

393+
mean_set_size_df.sort_index(axis=0, inplace=True) # Sort rows
394+
mean_set_size_df.sort_index(axis=1, inplace=True) # Sort columns
395+
360396
# Remove any columns where all the rows are 0
361397
mean_smx_df = mean_smx_df.loc[:, (mean_smx_df != 0).any(axis=0)]
362398

363399
hm = sns.heatmap(mean_smx_df,
400+
ax=ax1,
364401
cmap="coolwarm",
365402
annot=True,
366-
fmt='.2f')
403+
fmt='.2f',
404+
cbar=False)
367405

368406
labelpad = 20
369407
plt.setp(hm.get_yticklabels(), rotation=0)
370408

371-
hm.set_xlabel('MEAN PROBABILITY SCORE',
409+
hm.set_xlabel('MEAN SOFTMAX SCORE',
372410
weight='bold', labelpad=labelpad)
373411
hm.set_ylabel('TRUE CLASS',
374412
weight='bold', labelpad=labelpad)
375413

414+
# Create second heatmap for mean set size
415+
hm2 = sns.heatmap(mean_set_size_df,
416+
ax=ax2,
417+
cmap=sns.light_palette("purple", as_cmap=True),
418+
annot=True,
419+
fmt='.2f',
420+
cbar=False)
421+
422+
# Rotate x labels
423+
plt.setp(hm2.get_xticklabels(), rotation=90)
424+
425+
# Set y label
426+
hm2.set_ylabel(f"MEAN PREDICTION SET SIZE, SOFTMAX > {min_softmax_threshold}",
427+
weight='bold', labelpad=labelpad)
428+
429+
# Position y label to the right of heatmap
430+
hm2.yaxis.set_label_position("right")
431+
432+
# Remove y ticks
433+
hm2.set_yticks([])
434+
435+
# Remove x label
436+
hm2.set_xlabel('')
437+
438+
# Remove x ticks
439+
hm2.set_xticks([])
440+
441+
plt.tight_layout(w_pad=0.1) # Control padding
442+
376443
# Save the plot to a file
377444
plt.savefig(f'{self.output_dir}/prediction_heatmap.png', bbox_inches='tight')
378445

@@ -507,10 +574,7 @@ def visualize_prediction_stripplot(self,
507574
plt.savefig(f'{self.output_dir}/prediction_stripplot.png',
508575
bbox_inches='tight')
509576

510-
def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
511-
plt.figure()
512-
plt.figure(figsize=(self.FIGURE_WIDTH, 8))
513-
577+
def _get_prediction_sets_by_softmax_threshold(self, min_softmax_threshold=0.5):
514578
df = self.melt()
515579
cols = [col for col in df.columns if col in self.class_names()]
516580

@@ -523,7 +587,14 @@ def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
523587
new_row[col] = (row[col] >= min_softmax_threshold)
524588
rows.append(new_row)
525589

526-
upset_data = pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)
590+
return pd.concat([new_df, pd.DataFrame(rows)], ignore_index=True)
591+
592+
593+
def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
594+
plt.figure()
595+
plt.figure(figsize=(self.FIGURE_WIDTH, 8))
596+
upset_data = self._get_prediction_sets_by_softmax_threshold(
597+
min_softmax_threshold)
527598
# Set a multi-index
528599
upset_data.set_index(upset_data.columns.tolist(), inplace=True)
529600

0 commit comments

Comments
 (0)