Skip to content

Commit 6e2a343

Browse files
committed
Barchart by dataset - WIP
1 parent 538cb52 commit 6e2a343

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

src/conformist/prediction_dataset.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def class_counts(self, translate=False):
136136
counts.index.name = None
137137
return counts
138138

139+
def class_counts_by_dataset(self):
140+
counting_df = self.melt()
141+
counts = counting_df.groupby([self.DATASET_NAME_COL, self.MELTED_KNOWN_CLASS_COL]).size()
142+
return counts
143+
139144
def translate_class_name(self, class_name):
140145
if self.display_classes and class_name in self.display_classes:
141146
return self.display_classes[class_name]
@@ -158,6 +163,7 @@ def class_names(self, translate=False):
158163
def run_reports(self, base_output_dir):
159164
self.create_output_dir(base_output_dir)
160165
self.visualize_class_counts()
166+
self.visualize_class_counts_by_dataset()
161167
self.visualize_prediction_heatmap()
162168
print(f'Reports saved to {self.output_dir}')
163169

@@ -183,6 +189,33 @@ def visualize_class_counts(self):
183189
# show the plot
184190
plt.savefig(f'{self.output_dir}/class_counts.png', bbox_inches='tight')
185191

192+
def visualize_class_counts_by_dataset(self):
193+
plt.figure()
194+
195+
# create a bar chart
196+
ccs = self.class_counts_by_dataset()
197+
198+
# Count how many datasets and create a grid of plots
199+
num_datasets = len(ccs.index.get_level_values(0).unique())
200+
fig, axs = plt.subplots(num_datasets, 1, figsize=(10, 8 * num_datasets))
201+
202+
# For each dataset, create a bar chart
203+
for i, dataset in enumerate(ccs.index.get_level_values(0).unique()):
204+
sorted = ccs.loc[dataset].sort_values(ascending=False)
205+
206+
# Set fixed width for bar
207+
# axs[i].bar(sorted.index, sorted.values, width=0.5)
208+
sorted.plot.bar(ax=axs[i])
209+
axs[i].set_title(dataset)
210+
211+
# Print count above each bar
212+
for j, v in enumerate(sorted):
213+
axs[i].text(j, v, str(v), ha='center', va='bottom')
214+
215+
# show the plot
216+
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',
217+
bbox_inches='tight')
218+
186219
def visualize_prediction_heatmap(self):
187220
plt.figure(figsize=(10, 8))
188221

0 commit comments

Comments
 (0)