Skip to content

Commit d3197f0

Browse files
committed
Clean up the confusion matrix API
1 parent c7d3c96 commit d3197f0

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def evaluate_label_transfer(
379379

380380
def plot_confusion_matrix(
381381
self,
382-
label_key: str,
382+
pred_key: str,
383383
*,
384384
true_key: str | None = None,
385385
subset: np.ndarray | pd.Series | None = None,
@@ -407,11 +407,11 @@ def plot_confusion_matrix(
407407
408408
Parameters
409409
----------
410-
label_key
411-
Key in .obs storing predicted labels (from map_obs). The column
412-
``f"{label_key}{prediction_postfix}"`` is used as the x-axis (predicted).
410+
pred_key
411+
Key in .obs identifying the mapped labels (from map_obs). The column
412+
``f"{pred_key}{prediction_postfix}"`` is used as the x-axis (predicted).
413413
true_key
414-
Key in .obs to use for the y-axis (true labels). If None, uses ``label_key``.
414+
Key in .obs to use for the y-axis (true labels). If None, uses ``pred_key``.
415415
This allows comparing arbitrary columns, e.g., source_time vs mapped_time.
416416
subset
417417
Boolean mask to select a subset of cells for the confusion matrix.
@@ -466,9 +466,9 @@ def plot_confusion_matrix(
466466
raise ValueError("Label transfer has not been performed. Call map_obs() first.")
467467

468468
# Extract true and predicted labels
469-
true_col = true_key if true_key is not None else label_key
469+
true_col = true_key if true_key is not None else pred_key
470470
y_true = self.query.obs[true_col].copy()
471-
y_pred = self.query.obs[f"{label_key}{self.prediction_postfix}"].copy()
471+
y_pred = self.query.obs[f"{pred_key}{self.prediction_postfix}"].copy()
472472

473473
# Drop NaNs
474474
valid_mask = y_true.notna() & y_pred.notna()
@@ -594,8 +594,8 @@ def plot_confusion_matrix(
594594
# Annotation color strips
595595
if show_annotation_colors:
596596
# Row colors (true labels) from query, column colors (predicted) from reference
597-
row_colors = _get_category_colors(self.query, label_key, list(cm_display.index))
598-
col_colors = _get_category_colors(self.reference, label_key, list(cm_display.columns))
597+
row_colors = _get_category_colors(self.query, true_col, list(cm_display.index))
598+
col_colors = _get_category_colors(self.reference, pred_key, list(cm_display.columns))
599599
_draw_annotation_strips(ax, row_colors, col_colors, xlabel_position)
600600

601601
if save:

0 commit comments

Comments
 (0)