77from scipy .sparse import issparse
88from scipy .spatial .distance import jensenshannon
99from scipy .stats import pearsonr , spearmanr
10- from sklearn .metrics import (
11- accuracy_score ,
12- classification_report ,
13- f1_score ,
14- precision_score ,
15- recall_score ,
16- )
10+ from sklearn .metrics import accuracy_score , classification_report , f1_score , precision_score , recall_score
1711
1812from cellmapper ._docs import d
1913from cellmapper .logging import logger
@@ -386,6 +380,8 @@ def evaluate_label_transfer(
386380 def plot_confusion_matrix (
387381 self ,
388382 label_key : str ,
383+ * ,
384+ true_key : str | None = None ,
389385 subset : np .ndarray | pd .Series | None = None ,
390386 figsize : tuple [int , int ] = (10 , 8 ),
391387 cmap : str = "viridis" ,
@@ -412,7 +408,11 @@ def plot_confusion_matrix(
412408 Parameters
413409 ----------
414410 label_key
415- Key in .obs storing ground-truth cell type annotations.
411+ Key in .obs storing predicted labels (from map_obs). The column
412+ ``f"{label_key}{prediction_postfix}"`` is used as the x-axis (predicted).
413+ true_key
414+ Key in .obs to use for the y-axis (true labels). If None, uses ``label_key``.
415+ This allows comparing arbitrary columns, e.g., source_time vs mapped_time.
416416 subset
417417 Boolean mask to select a subset of cells for the confusion matrix.
418418 Must have the same length as query.obs or be a pandas Series indexed by obs_names.
@@ -466,7 +466,8 @@ def plot_confusion_matrix(
466466 raise ValueError ("Label transfer has not been performed. Call map_obs() first." )
467467
468468 # Extract true and predicted labels
469- y_true = self .query .obs [label_key ].copy ()
469+ true_col = true_key if true_key is not None else label_key
470+ y_true = self .query .obs [true_col ].copy ()
470471 y_pred = self .query .obs [f"{ label_key } { self .prediction_postfix } " ].copy ()
471472
472473 # Drop NaNs