@@ -136,6 +136,11 @@ def class_counts(self, translate=False):
136136 counts .index .name = None
137137 return counts
138138
139+ def class_counts_by_dataset (self ):
140+ counting_df = self .melt ()
141+ counts = counting_df .groupby ([self .DATASET_NAME_COL , self .MELTED_KNOWN_CLASS_COL ]).size ()
142+ return counts
143+
139144 def translate_class_name (self , class_name ):
140145 if self .display_classes and class_name in self .display_classes :
141146 return self .display_classes [class_name ]
@@ -158,6 +163,7 @@ def class_names(self, translate=False):
158163 def run_reports (self , base_output_dir ):
159164 self .create_output_dir (base_output_dir )
160165 self .visualize_class_counts ()
166+ self .visualize_class_counts_by_dataset ()
161167 self .visualize_prediction_heatmap ()
162168 print (f'Reports saved to { self .output_dir } ' )
163169
@@ -183,6 +189,33 @@ def visualize_class_counts(self):
183189 # show the plot
184190 plt .savefig (f'{ self .output_dir } /class_counts.png' , bbox_inches = 'tight' )
185191
192+ def visualize_class_counts_by_dataset (self ):
193+ plt .figure ()
194+
195+ # create a bar chart
196+ ccs = self .class_counts_by_dataset ()
197+
198+ # Count how many datasets and create a grid of plots
199+ num_datasets = len (ccs .index .get_level_values (0 ).unique ())
200+ fig , axs = plt .subplots (num_datasets , 1 , figsize = (10 , 8 * num_datasets ))
201+
202+ # For each dataset, create a bar chart
203+ for i , dataset in enumerate (ccs .index .get_level_values (0 ).unique ()):
204+ sorted = ccs .loc [dataset ].sort_values (ascending = False )
205+
206+ # Set fixed width for bar
207+ # axs[i].bar(sorted.index, sorted.values, width=0.5)
208+ sorted .plot .bar (ax = axs [i ])
209+ axs [i ].set_title (dataset )
210+
211+ # Print count above each bar
212+ for j , v in enumerate (sorted ):
213+ axs [i ].text (j , v , str (v ), ha = 'center' , va = 'bottom' )
214+
215+ # show the plot
216+ plt .savefig (f'{ self .output_dir } /class_counts_by_dataset.png' ,
217+ bbox_inches = 'tight' )
218+
186219 def visualize_prediction_heatmap (self ):
187220 plt .figure (figsize = (10 , 8 ))
188221
0 commit comments