Skip to content

Commit 30f905a

Browse files
committed
refactor: revert to single min_cells parameter for category filtering
1 parent 3031461 commit 30f905a

File tree

1 file changed

+11
-28
lines changed

1 file changed

+11
-28
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,7 @@ def plot_confusion_matrix(
213213
show_annotation_colors: bool = True,
214214
xlabel_position: Literal["bottom", "top"] = "bottom",
215215
show_grid: bool = True,
216-
min_cells_true: int | None = None,
217-
min_cells_pred: int | None = None,
216+
min_cells: int | None = None,
218217
**kwargs,
219218
) -> plt.Axes:
220219
"""
@@ -242,12 +241,10 @@ def plot_confusion_matrix(
242241
Position of x-axis tick labels. Either "bottom" (default) or "top".
243242
show_grid
244243
Whether to show gridlines on the heatmap. Default is True.
245-
min_cells_true
246-
Minimum number of cells required for a category to be included based on true labels (rows).
247-
Categories with fewer true cells are filtered out. If None, no filtering is applied.
248-
min_cells_pred
249-
Minimum number of cells required for a category to be included based on predicted labels (columns).
250-
Categories with fewer predicted cells are filtered out. If None, no filtering is applied.
244+
min_cells
245+
Minimum number of cells required for a category to be included in the confusion matrix.
246+
Categories with fewer cells in both true and predicted labels are filtered out.
247+
If None, all categories are shown.
251248
**kwargs
252249
Additional keyword arguments to pass to ConfusionMatrixDisplay.
253250
@@ -276,28 +273,14 @@ def plot_confusion_matrix(
276273
y_true = y_true[subset]
277274
y_pred = y_pred[subset]
278275

279-
# Filter categories by minimum cell count (separately for true and predicted)
280-
valid_true_categories = None
281-
valid_pred_categories = None
282-
283-
if min_cells_true is not None:
276+
# Filter categories by minimum cell count
277+
if min_cells is not None:
284278
true_counts = y_true.value_counts()
285-
valid_true_categories = set(true_counts[true_counts >= min_cells_true].index)
286-
287-
if min_cells_pred is not None:
288279
pred_counts = y_pred.value_counts()
289-
valid_pred_categories = set(pred_counts[pred_counts >= min_cells_pred].index)
290-
291-
# Combine filters: keep cells where both true and pred labels pass their respective filters
292-
if valid_true_categories is not None or valid_pred_categories is not None:
293-
# Start with all categories valid, then intersect with filters
294-
if valid_true_categories is not None and valid_pred_categories is not None:
295-
valid_categories = valid_true_categories & valid_pred_categories
296-
elif valid_true_categories is not None:
297-
valid_categories = valid_true_categories
298-
else:
299-
valid_categories = valid_pred_categories
300-
280+
# Keep categories that have at least min_cells in either true or predicted
281+
valid_categories = set(true_counts[true_counts >= min_cells].index) | set(
282+
pred_counts[pred_counts >= min_cells].index
283+
)
301284
mask = y_true.isin(valid_categories) & y_pred.isin(valid_categories)
302285
y_true = y_true[mask]
303286
y_pred = y_pred[mask]

0 commit comments

Comments
 (0)