@@ -213,8 +213,7 @@ def plot_confusion_matrix(
213213 show_annotation_colors : bool = True ,
214214 xlabel_position : Literal ["bottom" , "top" ] = "bottom" ,
215215 show_grid : bool = True ,
216- min_cells_true : int | None = None ,
217- min_cells_pred : int | None = None ,
216+ min_cells : int | None = None ,
218217 ** kwargs ,
219218 ) -> plt .Axes :
220219 """
@@ -242,12 +241,10 @@ def plot_confusion_matrix(
242241 Position of x-axis tick labels. Either "bottom" (default) or "top".
243242 show_grid
244243 Whether to show gridlines on the heatmap. Default is True.
245- min_cells_true
246- Minimum number of cells required for a category to be included based on true labels (rows).
247- Categories with fewer true cells are filtered out. If None, no filtering is applied.
248- min_cells_pred
249- Minimum number of cells required for a category to be included based on predicted labels (columns).
250- Categories with fewer predicted cells are filtered out. If None, no filtering is applied.
244+ min_cells
245+ Minimum number of cells required for a category to be included in the confusion matrix.
246+ Categories with fewer cells in both true and predicted labels are filtered out.
247+ If None, all categories are shown.
251248 **kwargs
252249 Additional keyword arguments to pass to ConfusionMatrixDisplay.
253250
@@ -276,28 +273,14 @@ def plot_confusion_matrix(
276273 y_true = y_true [subset ]
277274 y_pred = y_pred [subset ]
278275
279- # Filter categories by minimum cell count (separately for true and predicted)
280- valid_true_categories = None
281- valid_pred_categories = None
282-
283- if min_cells_true is not None :
276+ # Filter categories by minimum cell count
277+ if min_cells is not None :
284278 true_counts = y_true .value_counts ()
285- valid_true_categories = set (true_counts [true_counts >= min_cells_true ].index )
286-
287- if min_cells_pred is not None :
288279 pred_counts = y_pred .value_counts ()
289- valid_pred_categories = set (pred_counts [pred_counts >= min_cells_pred ].index )
290-
291- # Combine filters: keep cells where both true and pred labels pass their respective filters
292- if valid_true_categories is not None or valid_pred_categories is not None :
293- # Start with all categories valid, then intersect with filters
294- if valid_true_categories is not None and valid_pred_categories is not None :
295- valid_categories = valid_true_categories & valid_pred_categories
296- elif valid_true_categories is not None :
297- valid_categories = valid_true_categories
298- else :
299- valid_categories = valid_pred_categories
300-
280+ # Keep categories that have at least min_cells in either true or predicted
281+ valid_categories = set (true_counts [true_counts >= min_cells ].index ) | set (
282+ pred_counts [pred_counts >= min_cells ].index
283+ )
301284 mask = y_true .isin (valid_categories ) & y_pred .isin (valid_categories )
302285 y_true = y_true [mask ]
303286 y_pred = y_pred [mask ]
0 commit comments