@@ -219,7 +219,7 @@ def visualize_class_counts_by_dataset(self):
219219
220220 # Count how many datasets and create a grid of plots
221221 num_datasets = len (ccs .index .get_level_values (0 ).unique ())
222- fig , axs = plt .subplots (num_datasets , 1 , figsize = (10 , 8 * num_datasets ))
222+ fig , axs = plt .subplots (num_datasets , 1 , figsize = (10 , 2 * num_datasets ))
223223
224224 if num_datasets == 1 :
225225 axs = [axs ]
@@ -250,22 +250,41 @@ def visualize_class_counts_by_dataset(self):
250250 color = bar_colors )
251251 axs [i ].set_title (dataset )
252252
253+ # Padding
254+ axs [i ].margins (y = 0.3 )
255+
253256 # Print count above each bar
254257 for j , v in enumerate (sorted_series ):
255258 if np .isfinite (v ):
256259 axs [i ].text (j , v , str (v ), ha = 'center' , va = 'bottom' )
257260
258261 # Set x-axis labels to class names and rotate them vertically
259- axs [i ].set_xticks (range (len (sorted_series .index )))
260- axs [i ].set_xticklabels (sorted_series .index , rotation = 90 )
262+ # axs[i].set_xticks(range(len(sorted_series.index)))
263+ # axs[i].set_xticklabels(sorted_series.index, rotation=90)
264+ # Remove x-axis labels
265+ axs [i ].set_xticks ([])
266+ # Remove xticks
267+ axs [i ].tick_params (axis = 'x' , which = 'both' , bottom = False , top = False )
261268
262269 # Add legend
263270 # if i == 0:
264271 # axs[i].legend(bars, sorted_series.index, title="Classes")
265272
273+ # Add a custom legend
274+ legend_handles = [Patch (color = class_to_color [cls ], label = cls ) for cls in self .class_names ()]
275+ fig .legend (legend_handles ,
276+ self .class_names (),
277+ title = "Classes" ,
278+ loc = 'lower center' ,
279+ ncol = len (legend_handles )/ 2 ,
280+ bbox_to_anchor = (0.5 , - 0.05 ), # Adjust position: (x, y)
281+ frameon = False # Remove the frame around the legend
282+ )
283+ fig .subplots_adjust (bottom = 0.1 )
284+
266285 # Adjust layout to prevent overlap and add margin under each panel
267286 plt .tight_layout ()
268- plt .subplots_adjust (hspace = 0.5 ) # Adjust hspace to add margin
287+ plt .subplots_adjust (top = 1 ) # Adjust hspace to add margin
269288
270289 # show the plot
271290 plt .savefig (f'{ self .output_dir } /class_counts_by_dataset.png' ,
0 commit comments