Skip to content

Commit f6bdd88

Browse files
committed
Size tweaks
1 parent 10512ef commit f6bdd88

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

src/conformist/prediction_dataset.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def visualize_class_counts_by_dataset(self,
248248
num_datasets = len(ccs.index.get_level_values(0).unique())
249249
fig, axs = plt.subplots(num_datasets,
250250
1,
251-
figsize=(self.FIGURE_WIDTH, 2 * num_datasets))
251+
figsize=(self.FIGURE_WIDTH, 2.5 * num_datasets))
252252

253253
if num_datasets == 1:
254254
axs = [axs]
@@ -312,7 +312,7 @@ def visualize_class_counts_by_dataset(self,
312312
loc='lower center',
313313
frameon=False,
314314
ncol=len(legend_handles)/4,
315-
bbox_to_anchor=(0.5, -0.15), # Adjust position: (x, y)
315+
bbox_to_anchor=(0.5, -0.125), # Adjust position: (x, y)
316316
handletextpad=1, # Increase padding between legend handle and text
317317
columnspacing=8 # Increase spacing between columns
318318
)
@@ -447,8 +447,11 @@ def softmax_summary(self):
447447
df = self.melt()
448448
softmax_cols = [col for col in df.columns if col in self.class_names()]
449449

450-
summary_df = pd.DataFrame(columns=['mean true positive softmax',
451-
'mean false positive softmax'])
450+
true_col_name = 'mean true positive softmax'
451+
false_col_name = 'mean false positive softmax'
452+
453+
summary_df = pd.DataFrame(
454+
columns=[true_col_name, false_col_name])
452455

453456
# For each col in softmax_cols, calculate the mean softmax score for the true class
454457
# and the mean softmax score for the false classes
@@ -463,7 +466,7 @@ def softmax_summary(self):
463466
summary_df.loc[col] = [mean_true_pos, mean_false_pos]
464467

465468
# Sort the DataFrame by mean true positive softmax
466-
summary_df = summary_df.sort_values(by='mean true positive softmax',
469+
summary_df = summary_df.sort_values(by=true_col_name,
467470
ascending=False)
468471

469472
# Name index "Predicted class"
@@ -483,16 +486,34 @@ def softmax_summary(self):
483486
plt.axis('off') # Hide axes
484487

485488
# Create table
486-
table = plt.table(cellText=summary_df.values, colLabels=summary_df.columns, rowLabels=summary_df.index, loc='center', cellLoc='center')
489+
table = plt.table(cellText=summary_df.values,
490+
colLabels=summary_df.columns,
491+
rowLabels=summary_df.index,
492+
loc='left',
493+
cellLoc='center',
494+
)
495+
496+
plt.title('Predicted classes', weight='bold')
497+
498+
accent_color = "#d4cbb3"
499+
# Make font color dark gray
500+
font_color = '#222222'
501+
502+
# Change border colors
503+
for key, cell in table.get_celld().items():
504+
cell.set_edgecolor(accent_color) # Set border color
505+
cell.set_linewidth(1) # Set border width
506+
cell.get_text().set_color(font_color)
507+
# cell.PAD = 0.2
487508

488509
# Make every other row gray
489510
for i in range(0, len(summary_df), 2):
490-
table.get_celld()[(i + 1, 0)].set_facecolor('#eeeeee')
491-
table.get_celld()[(i + 1, 1)].set_facecolor('#eeeeee')
511+
table.get_celld()[(i + 1, 0)].set_facecolor(accent_color)
512+
table.get_celld()[(i + 1, 1)].set_facecolor(accent_color)
492513

493514
table.auto_set_font_size(False)
494515
table.set_fontsize(self.FIGURE_FONTSIZE)
495-
table.scale(1.2, 1.2) # Scale table size
516+
table.scale(2, 2) # Scale table size
496517

497518
plt.savefig(f'{self.output_dir}/softmax_summary.png', bbox_inches='tight')
498519

0 commit comments

Comments
 (0)