Skip to content

Commit c7d3c96

Browse files
committed
make confusion matrix diaply more flexible
1 parent ffcb201 commit c7d3c96

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,7 @@
77
from scipy.sparse import issparse
88
from scipy.spatial.distance import jensenshannon
99
from scipy.stats import pearsonr, spearmanr
10-
from sklearn.metrics import (
11-
accuracy_score,
12-
classification_report,
13-
f1_score,
14-
precision_score,
15-
recall_score,
16-
)
10+
from sklearn.metrics import accuracy_score, classification_report, f1_score, precision_score, recall_score
1711

1812
from cellmapper._docs import d
1913
from cellmapper.logging import logger
@@ -386,6 +380,8 @@ def evaluate_label_transfer(
386380
def plot_confusion_matrix(
387381
self,
388382
label_key: str,
383+
*,
384+
true_key: str | None = None,
389385
subset: np.ndarray | pd.Series | None = None,
390386
figsize: tuple[int, int] = (10, 8),
391387
cmap: str = "viridis",
@@ -412,7 +408,11 @@ def plot_confusion_matrix(
412408
Parameters
413409
----------
414410
label_key
415-
Key in .obs storing ground-truth cell type annotations.
411+
Key in .obs storing predicted labels (from map_obs). The column
412+
``f"{label_key}{prediction_postfix}"`` is used as the x-axis (predicted).
413+
true_key
414+
Key in .obs to use for the y-axis (true labels). If None, uses ``label_key``.
415+
This allows comparing arbitrary columns, e.g., source_time vs mapped_time.
416416
subset
417417
Boolean mask to select a subset of cells for the confusion matrix.
418418
Must have the same length as query.obs or be a pandas Series indexed by obs_names.
@@ -466,7 +466,8 @@ def plot_confusion_matrix(
466466
raise ValueError("Label transfer has not been performed. Call map_obs() first.")
467467

468468
# Extract true and predicted labels
469-
y_true = self.query.obs[label_key].copy()
469+
true_col = true_key if true_key is not None else label_key
470+
y_true = self.query.obs[true_col].copy()
470471
y_pred = self.query.obs[f"{label_key}{self.prediction_postfix}"].copy()
471472

472473
# Drop NaNs

0 commit comments

Comments
 (0)