@@ -379,7 +379,7 @@ def evaluate_label_transfer(
379379
380380 def plot_confusion_matrix (
381381 self ,
382- label_key : str ,
382+ pred_key : str ,
383383 * ,
384384 true_key : str | None = None ,
385385 subset : np .ndarray | pd .Series | None = None ,
@@ -407,11 +407,11 @@ def plot_confusion_matrix(
407407
408408 Parameters
409409 ----------
410- label_key
411- Key in .obs storing predicted labels (from map_obs). The column
412- ``f"{label_key }{prediction_postfix}"`` is used as the x-axis (predicted).
410+ pred_key
411+ Key in .obs identifying the mapped labels (from map_obs). The column
412+ ``f"{pred_key }{prediction_postfix}"`` is used as the x-axis (predicted).
413413 true_key
414- Key in .obs to use for the y-axis (true labels). If None, uses ``label_key ``.
414+ Key in .obs to use for the y-axis (true labels). If None, uses ``pred_key ``.
415415 This allows comparing arbitrary columns, e.g., source_time vs mapped_time.
416416 subset
417417 Boolean mask to select a subset of cells for the confusion matrix.
@@ -466,9 +466,9 @@ def plot_confusion_matrix(
466466 raise ValueError ("Label transfer has not been performed. Call map_obs() first." )
467467
468468 # Extract true and predicted labels
469- true_col = true_key if true_key is not None else label_key
469+ true_col = true_key if true_key is not None else pred_key
470470 y_true = self .query .obs [true_col ].copy ()
471- y_pred = self .query .obs [f"{ label_key } { self .prediction_postfix } " ].copy ()
471+ y_pred = self .query .obs [f"{ pred_key } { self .prediction_postfix } " ].copy ()
472472
473473 # Drop NaNs
474474 valid_mask = y_true .notna () & y_pred .notna ()
@@ -594,8 +594,8 @@ def plot_confusion_matrix(
594594 # Annotation color strips
595595 if show_annotation_colors :
596596 # Row colors (true labels) from query, column colors (predicted) from reference
597- row_colors = _get_category_colors (self .query , label_key , list (cm_display .index ))
598- col_colors = _get_category_colors (self .reference , label_key , list (cm_display .columns ))
597+ row_colors = _get_category_colors (self .query , true_col , list (cm_display .index ))
598+ col_colors = _get_category_colors (self .reference , pred_key , list (cm_display .columns ))
599599 _draw_annotation_strips (ax , row_colors , col_colors , xlabel_position )
600600
601601 if save :
0 commit comments