Skip to content

Commit 4d19859

Browse files
committed
WIP
1 parent 8ee6ee1 commit 4d19859

File tree

1 file changed

+54
-20
lines changed

1 file changed

+54
-20
lines changed

src/conformist/prediction_dataset.py

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -344,12 +344,14 @@ def visualize_class_counts_by_dataset(self,
344344
bbox_inches='tight', format='pdf')
345345

346346
def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
347-
fig = plt.figure(figsize=(self.FIGURE_WIDTH, 10))
347+
fig = plt.figure()
348+
plt.rcParams.update({'font.size': 6})
349+
# fig = plt.figure(figsize=(self.FIGURE_WIDTH, 10))
348350

349351
# Create two subplots, we will create two separate heatmaps
350352
# Define a GridSpec with unequal column widths
351353
# Column 1 is nx as wide as column 2
352-
width_ratio = 8
354+
width_ratio = 13.5
353355
gs = gridspec.GridSpec(1, 3, width_ratios=[width_ratio, 1, 1])
354356
ax1 = fig.add_subplot(gs[0, 0])
355357
ax2 = fig.add_subplot(gs[0, 1])
@@ -444,13 +446,15 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
444446
fmt='.2f',
445447
cbar=False)
446448

447-
labelpad = 20
448449
plt.setp(hm.get_yticklabels(), rotation=0)
449450

450451
hm.set_xlabel('MEAN SOFTMAX SCORE',
451-
weight='bold', labelpad=labelpad)
452+
weight='bold', labelpad=10)
452453
hm.set_ylabel('TRUE CLASS',
453-
weight='bold', labelpad=labelpad)
454+
weight='bold', labelpad=-10)
455+
456+
hm.yaxis.set_tick_params(pad=0)
457+
hm.xaxis.set_tick_params(pad=0)
454458

455459
# Create second heatmap for mean set size
456460
hm2 = sns.heatmap(mean_fp_smx_df,
@@ -501,10 +505,12 @@ def visualize_prediction_heatmap(self, min_softmax_threshold=0.5):
501505
# Remove x ticks
502506
hm3.set_xticks([])
503507

504-
plt.tight_layout(w_pad=0.1) # Control padding
508+
plt.tight_layout(w_pad=0) # Control padding
509+
510+
fig.set_size_inches(4, 4)
505511

506512
# Save the plot to a file
507-
plt.savefig(f'{self.output_dir}/prediction_heatmap.pdf', bbox_inches='tight', format='pdf')
513+
plt.savefig(f'{self.output_dir}/prediction_heatmap.pdf', bbox_inches='tight', format='pdf', pad_inches = 0)
508514

509515
def softmax_summary(self):
510516
df = self.melt()
@@ -581,8 +587,14 @@ def softmax_summary(self):
581587
plt.savefig(f'{self.output_dir}/softmax_summary.pdf', bbox_inches='tight', format='pdf')
582588

583589
def visualize_prediction_stripplot(self,
584-
custom_color_palette=None):
585-
plt.figure()
590+
custom_color_palette=None,
591+
legend_top_padding=0.125):
592+
fig = plt.figure()
593+
594+
# Set fontsize to 6
595+
plt.rcParams.update({'font.size': 6})
596+
597+
# figsize=(self.FIGURE_WIDTH, 2.5 * num_datasets)
586598

587599
df = self.melt()
588600
cols = [col for col in df.columns if col in self.class_names()]
@@ -608,7 +620,7 @@ def visualize_prediction_stripplot(self,
608620

609621
# Increase the height of each row by adjusting the figure size
610622
num_classes = new_df['True class'].nunique()
611-
plt.figure(figsize=(self.FIGURE_WIDTH, num_classes * 1)) # Adjust the height multiplier as needed
623+
# plt.figure(figsize=(self.FIGURE_WIDTH, num_classes * 1)) # Adjust the height multiplier as needed
612624

613625
ax = plt.gca()
614626
# Add light gray background to every other row
@@ -628,7 +640,7 @@ def visualize_prediction_stripplot(self,
628640
dodge=True,
629641
palette=self._class_colors(
630642
custom_color_palette=custom_color_palette),
631-
size=4,
643+
size=2,
632644
ax=ax,
633645
order=class_names)
634646

@@ -648,16 +660,30 @@ def visualize_prediction_stripplot(self,
648660
loc='lower center',
649661
frameon=False,
650662
ncol=4,
651-
bbox_to_anchor=(0.5, -0.2), # Adjust position: (x, y)
663+
bbox_to_anchor=(0.35, -0.2),
652664
handletextpad=1, # Increase padding between legend handle and text
653-
columnspacing=8 # Increase spacing between columns
665+
columnspacing=2 # Increase spacing between columns
654666
)
667+
668+
# legend = fig.legend(legend_handles,
669+
# class_names,
670+
# title=title,
671+
# loc='lower center',
672+
# frameon=False,
673+
# ncol=len(legend_handles)/4,
674+
# bbox_to_anchor=(0.5, 0-legend_top_padding), # Adjust position: (x, y)
675+
# handletextpad=1, # Increase padding between legend handle and text
676+
# columnspacing=2 # Increase spacing between columns
677+
# )
655678
font_properties = FontProperties(weight='bold')
656679
legend.get_title().set_font_properties(font_properties)
657680

681+
# plt.tight_layout()
682+
683+
fig.set_size_inches(3.75, 6)
684+
658685
# Save the plot to a file
659-
plt.tight_layout()
660-
plt.subplots_adjust(hspace=0.5)
686+
# plt.subplots_adjust(hspace=0.5)
661687
plt.savefig(f'{self.output_dir}/prediction_stripplot.pdf',
662688
bbox_inches='tight', format='pdf')
663689

@@ -678,20 +704,28 @@ def _get_prediction_sets_by_softmax_threshold(self, min_softmax_threshold=0.5):
678704

679705

680706
def visualize_model_sets(self, min_softmax_threshold=0.5, color="black"):
681-
plt.figure()
682-
plt.figure(figsize=(self.FIGURE_WIDTH, 8))
707+
fig = plt.figure(figsize=(3.75, 5))
708+
# fig.set_size_inches(3.75, 1)
709+
710+
# Set fontsize to 6
711+
plt.rcParams.update({'font.size': 6})
712+
683713
upset_data = self._get_prediction_sets_by_softmax_threshold(
684714
min_softmax_threshold)
685715
# Set a multi-index
686716
upset_data.set_index(upset_data.columns.tolist(), inplace=True)
687717

688-
plot(upset_data,
718+
plts = plot(upset_data,
719+
fig=fig,
720+
element_size=9,
689721
sort_by="cardinality",
690722
facecolor=color,
691-
show_counts="%d",
692-
show_percentages="{:.0%}",
723+
show_counts="%d ",
724+
show_percentages=False,
693725
orientation='horizontal',
694726
min_subset_size=3)
727+
728+
695729
plt.savefig(f'{self.output_dir}/upset.pdf', bbox_inches='tight', format='pdf')
696730

697731
def prediction_sets_df(self, prediction_sets, export_to_dir=None):

0 commit comments

Comments
 (0)