@@ -202,18 +202,31 @@ def visualize_class_counts_by_dataset(self):
202202 if num_datasets == 1 :
203203 axs = [axs ]
204204
205+ # Group by the first level of the index (dataset) and sum the values
206+ grouped_ccs = ccs .groupby (level = 0 ).sum ()
207+
208+ # Order datasets by number of items
209+ ordered_datasets = grouped_ccs .sort_values (ascending = False ).index
210+ print (type (ordered_datasets ))
211+
205212 # For each dataset, create a bar chart
206- for i , dataset in enumerate (ccs .index .get_level_values (0 ).unique ()):
207- sorted = ccs .loc [dataset ].sort_values (ascending = False )
213+ for i , dataset in enumerate (ordered_datasets ):
214+ dataset_series = ccs .loc [dataset ]
215+
216+ # Ensure dataset_series is a Series
217+ if not isinstance (dataset_series , pd .Series ):
218+ raise ValueError (f"Expected ccs.loc[{ dataset } ] to be a Series" )
219+
220+ sorted_series = dataset_series .sort_values (ascending = False )
208221
209- # Set fixed width for bar
210- # axs[i].bar(sorted.index, sorted.values, width=0.5)
211- sorted .plot .bar (ax = axs [i ])
222+ # Plot bar chart with fixed width
223+ axs [i ].bar (sorted_series .index , sorted_series .values , width = 0.5 )
212224 axs [i ].set_title (dataset )
213225
214226 # Print count above each bar
215- for j , v in enumerate (sorted ):
216- axs [i ].text (j , v , str (v ), ha = 'center' , va = 'bottom' )
227+ for j , v in enumerate (sorted_series ):
228+ if np .isfinite (v ):
229+ axs [i ].text (j , v , str (v ), ha = 'center' , va = 'bottom' )
217230
218231 # show the plot
219232 plt .savefig (f'{ self .output_dir } /class_counts_by_dataset.png' ,
0 commit comments