Skip to content

Commit d167760

Browse files
committed
Tweak formatting of class barchart
1 parent f6084bc commit d167760

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

src/conformist/prediction_dataset.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def visualize_class_counts_by_dataset(self):
219219

220220
# Count how many datasets and create a grid of plots
221221
num_datasets = len(ccs.index.get_level_values(0).unique())
222-
fig, axs = plt.subplots(num_datasets, 1, figsize=(10, 8 * num_datasets))
222+
fig, axs = plt.subplots(num_datasets, 1, figsize=(10, 2 * num_datasets))
223223

224224
if num_datasets == 1:
225225
axs = [axs]
@@ -250,22 +250,41 @@ def visualize_class_counts_by_dataset(self):
250250
color=bar_colors)
251251
axs[i].set_title(dataset)
252252

253+
# Padding
254+
axs[i].margins(y=0.3)
255+
253256
# Print count above each bar
254257
for j, v in enumerate(sorted_series):
255258
if np.isfinite(v):
256259
axs[i].text(j, v, str(v), ha='center', va='bottom')
257260

258261
# Set x-axis labels to class names and rotate them vertically
259-
axs[i].set_xticks(range(len(sorted_series.index)))
260-
axs[i].set_xticklabels(sorted_series.index, rotation=90)
262+
# axs[i].set_xticks(range(len(sorted_series.index)))
263+
# axs[i].set_xticklabels(sorted_series.index, rotation=90)
264+
# Remove x-axis labels
265+
axs[i].set_xticks([])
266+
# Remove xticks
267+
axs[i].tick_params(axis='x', which='both', bottom=False, top=False)
261268

262269
# Add legend
263270
# if i == 0:
264271
# axs[i].legend(bars, sorted_series.index, title="Classes")
265272

273+
# Add a custom legend
274+
legend_handles = [Patch(color=class_to_color[cls], label=cls) for cls in self.class_names()]
275+
fig.legend(legend_handles,
276+
self.class_names(),
277+
title="Classes",
278+
loc='lower center',
279+
ncol=len(legend_handles)/2,
280+
bbox_to_anchor=(0.5, -0.05), # Adjust position: (x, y)
281+
frameon=False # Remove the frame around the legend
282+
)
283+
fig.subplots_adjust(bottom=0.1)
284+
266285
# Adjust layout to prevent overlap and add margin under each panel
267286
plt.tight_layout()
268-
plt.subplots_adjust(hspace=0.5) # Adjust hspace to add margin
287+
plt.subplots_adjust(top=1) # Adjust hspace to add margin
269288

270289
# show the plot
271290
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',

0 commit comments

Comments
 (0)