Skip to content

Commit 3031461

Browse files
committed
feat: split min_cells into min_cells_true and min_cells_pred for independent row/column filtering
1 parent e8325a9 commit 3031461

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def plot_confusion_matrix(
213213
show_annotation_colors: bool = True,
214214
xlabel_position: Literal["bottom", "top"] = "bottom",
215215
show_grid: bool = True,
216-
min_cells: int | None = None,
216+
min_cells_true: int | None = None,
217+
min_cells_pred: int | None = None,
217218
**kwargs,
218219
) -> plt.Axes:
219220
"""
@@ -241,10 +242,12 @@ def plot_confusion_matrix(
241242
Position of x-axis tick labels. Either "bottom" (default) or "top".
242243
show_grid
243244
Whether to show gridlines on the heatmap. Default is True.
244-
min_cells
245-
Minimum number of cells required for a category to be included in the confusion matrix.
246-
Categories with fewer cells in both true and predicted labels are filtered out.
247-
If None, all categories are shown.
245+
min_cells_true
246+
Minimum number of cells required for a category to be included based on true labels (rows).
247+
Categories with fewer true cells are filtered out. If None, no filtering is applied.
248+
min_cells_pred
249+
Minimum number of cells required for a category to be included based on predicted labels (columns).
250+
Categories with fewer predicted cells are filtered out. If None, no filtering is applied.
248251
**kwargs
249252
Additional keyword arguments to pass to ConfusionMatrixDisplay.
250253
@@ -273,14 +276,28 @@ def plot_confusion_matrix(
273276
y_true = y_true[subset]
274277
y_pred = y_pred[subset]
275278

276-
# Filter categories by minimum cell count
277-
if min_cells is not None:
279+
# Filter categories by minimum cell count (separately for true and predicted)
280+
valid_true_categories = None
281+
valid_pred_categories = None
282+
283+
if min_cells_true is not None:
278284
true_counts = y_true.value_counts()
285+
valid_true_categories = set(true_counts[true_counts >= min_cells_true].index)
286+
287+
if min_cells_pred is not None:
279288
pred_counts = y_pred.value_counts()
280-
# Keep categories that have at least min_cells in either true or predicted
281-
valid_categories = set(true_counts[true_counts >= min_cells].index) | set(
282-
pred_counts[pred_counts >= min_cells].index
283-
)
289+
valid_pred_categories = set(pred_counts[pred_counts >= min_cells_pred].index)
290+
291+
# Combine filters: keep cells where both true and pred labels pass their respective filters
292+
if valid_true_categories is not None or valid_pred_categories is not None:
293+
# Start with all categories valid, then intersect with filters
294+
if valid_true_categories is not None and valid_pred_categories is not None:
295+
valid_categories = valid_true_categories & valid_pred_categories
296+
elif valid_true_categories is not None:
297+
valid_categories = valid_true_categories
298+
else:
299+
valid_categories = valid_pred_categories
300+
284301
mask = y_true.isin(valid_categories) & y_pred.isin(valid_categories)
285302
y_true = y_true[mask]
286303
y_pred = y_pred[mask]

0 commit comments

Comments
 (0)