@@ -213,7 +213,8 @@ def plot_confusion_matrix(
213213 show_annotation_colors : bool = True ,
214214 xlabel_position : Literal ["bottom" , "top" ] = "bottom" ,
215215 show_grid : bool = True ,
216- min_cells : int | None = None ,
216+ min_cells_true : int | None = None ,
217+ min_cells_pred : int | None = None ,
217218 ** kwargs ,
218219 ) -> plt .Axes :
219220 """
@@ -241,10 +242,12 @@ def plot_confusion_matrix(
241242 Position of x-axis tick labels. Either "bottom" (default) or "top".
242243 show_grid
243244 Whether to show gridlines on the heatmap. Default is True.
244- min_cells
245- Minimum number of cells required for a category to be included in the confusion matrix.
246- Categories with fewer cells in both true and predicted labels are filtered out.
247- If None, all categories are shown.
245+ min_cells_true
246+ Minimum number of cells required for a category to be included based on true labels (rows).
247+ Categories with fewer true cells are filtered out. If None, no filtering is applied.
248+ min_cells_pred
249+ Minimum number of cells required for a category to be included based on predicted labels (columns).
250+ Categories with fewer predicted cells are filtered out. If None, no filtering is applied.
248251 **kwargs
249252 Additional keyword arguments to pass to ConfusionMatrixDisplay.
250253
@@ -273,14 +276,28 @@ def plot_confusion_matrix(
273276 y_true = y_true [subset ]
274277 y_pred = y_pred [subset ]
275278
276- # Filter categories by minimum cell count
277- if min_cells is not None :
279+ # Filter categories by minimum cell count (separately for true and predicted)
280+ valid_true_categories = None
281+ valid_pred_categories = None
282+
283+ if min_cells_true is not None :
278284 true_counts = y_true .value_counts ()
285+ valid_true_categories = set (true_counts [true_counts >= min_cells_true ].index )
286+
287+ if min_cells_pred is not None :
279288 pred_counts = y_pred .value_counts ()
280- # Keep categories that have at least min_cells in either true or predicted
281- valid_categories = set (true_counts [true_counts >= min_cells ].index ) | set (
282- pred_counts [pred_counts >= min_cells ].index
283- )
289+ valid_pred_categories = set (pred_counts [pred_counts >= min_cells_pred ].index )
290+
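291+ # Combine filters: keep cells whose true AND predicted labels are both in the combined valid category set (NOTE(review): each label is not checked against only its own filter - confirm this is intended)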
292+ if valid_true_categories is not None or valid_pred_categories is not None :
293+ # Use the intersection when both filters are active; otherwise use the single active filter's set
294+ if valid_true_categories is not None and valid_pred_categories is not None :
295+ valid_categories = valid_true_categories & valid_pred_categories
296+ elif valid_true_categories is not None :
297+ valid_categories = valid_true_categories
298+ else :
299+ valid_categories = valid_pred_categories
300+
284301 mask = y_true .isin (valid_categories ) & y_pred .isin (valid_categories )
285302 y_true = y_true [mask ]
286303 y_pred = y_pred [mask ]
0 commit comments