Skip to content

Commit 57b0a5a

Browse files
committed
Order class barchart by num of examples per class
1 parent df0a36b commit 57b0a5a

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/conformist/prediction_dataset.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -202,18 +202,31 @@ def visualize_class_counts_by_dataset(self):
202202
if num_datasets == 1:
203203
axs = [axs]
204204

205+
# Group by the first level of the index (dataset) and sum the values
206+
grouped_ccs = ccs.groupby(level=0).sum()
207+
208+
# Order datasets by number of items
209+
ordered_datasets = grouped_ccs.sort_values(ascending=False).index
210+
print(type(ordered_datasets))
211+
205212
# For each dataset, create a bar chart
206-
for i, dataset in enumerate(ccs.index.get_level_values(0).unique()):
207-
sorted = ccs.loc[dataset].sort_values(ascending=False)
213+
for i, dataset in enumerate(ordered_datasets):
214+
dataset_series = ccs.loc[dataset]
215+
216+
# Ensure dataset_series is a Series
217+
if not isinstance(dataset_series, pd.Series):
218+
raise ValueError(f"Expected ccs.loc[{dataset}] to be a Series")
219+
220+
sorted_series = dataset_series.sort_values(ascending=False)
208221

209-
# Set fixed width for bar
210-
# axs[i].bar(sorted.index, sorted.values, width=0.5)
211-
sorted.plot.bar(ax=axs[i])
222+
# Plot bar chart with fixed width
223+
axs[i].bar(sorted_series.index, sorted_series.values, width=0.5)
212224
axs[i].set_title(dataset)
213225

214226
# Print count above each bar
215-
for j, v in enumerate(sorted):
216-
axs[i].text(j, v, str(v), ha='center', va='bottom')
227+
for j, v in enumerate(sorted_series):
228+
if np.isfinite(v):
229+
axs[i].text(j, v, str(v), ha='center', va='bottom')
217230

218231
# show the plot
219232
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',

0 commit comments

Comments
 (0)