Skip to content

Commit 26a3ced

Browse files
committed
handle pos_label
1 parent cb01d7b commit 26a3ced

File tree

3 files changed

+70
-29
lines changed

3 files changed

+70
-29
lines changed

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2060,6 +2060,7 @@ def confusion_matrix(
20602060
X: ArrayLike | None = None,
20612061
y: ArrayLike | None = None,
20622062
threshold: bool = False,
2063+
pos_label: PositiveLabel | None = _DEFAULT,
20632064
) -> ConfusionMatrixDisplay:
20642065
"""Plot the confusion matrix.
20652066
@@ -2091,6 +2092,12 @@ def confusion_matrix(
20912092
requires the estimator to have `predict_proba` or `decision_function`
20922093
methods.
20932094
2095+
pos_label : int, float, bool, str or None, default=_DEFAULT
2096+
The label to consider as the positive class when computing the metric. Use
2097+
this parameter to override the positive class. By default, the positive
2098+
class is set to the one provided when creating the report. If `None`,
2099+
the metric is computed considering each class as a positive class.
2100+
20942101
Returns
20952102
-------
20962103
display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2114,9 +2121,12 @@ def confusion_matrix(
21142121
>>> display = report.metrics.confusion_matrix(threshold=True)
21152122
>>> display.plot(threshold=0.7)
21162123
"""
2124+
if pos_label is _DEFAULT:
2125+
pos_label = self._parent.pos_label
2126+
21172127
display_kwargs = {
21182128
"display_labels": self._parent.estimator_.classes_.tolist(),
2119-
"pos_label": self._parent.pos_label,
2129+
"pos_label": pos_label,
21202130
"threshold": threshold,
21212131
}
21222132

skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from skore._externals._sklearn_compat import confusion_matrix_at_thresholds
1111
from skore._sklearn._plot.base import DisplayMixin
1212
from skore._sklearn._plot.utils import _validate_style_kwargs
13-
from skore._sklearn.types import ReportType, YPlotData
13+
from skore._sklearn.types import PositiveLabel, ReportType, YPlotData
1414

1515

1616
class ConfusionMatrixDisplay(DisplayMixin):
@@ -178,6 +178,9 @@ def _plot_single_estimator(
178178
heatmap_kwargs or {},
179179
)
180180
normalize_by = "normalized_by_" + normalize if normalize else "count"
181+
cm_pivot = cm.pivot(
182+
index="True label", columns="Predicted label", values=normalize_by
183+
).reindex(index=self.display_labels, columns=self.display_labels)
181184
sns.heatmap(
182185
cm.pivot(
183186
index="true_label", columns="predicted_label", values=normalize_by
@@ -205,8 +208,8 @@ def _compute_data_for_display(
205208
*,
206209
report_type: ReportType,
207210
display_labels: list[str],
208-
pos_label: str,
209211
threshold: bool = False,
212+
pos_label: PositiveLabel | None,
210213
**kwargs,
211214
) -> "ConfusionMatrixDisplay":
212215
"""Compute the confusion matrix for display.
@@ -228,8 +231,9 @@ def _compute_data_for_display(
228231
display_labels : list of str
229232
Display labels for plot.
230233
231-
pos_label : str
232-
The positive label.
234+
pos_label : int, float, bool, str or None
235+
The class considered as the positive class when computing the
236+
precision and recall metrics.
233237
234238
threshold : bool, default=False
235239
Whether to compute the confusion matrix at different thresholds.
@@ -246,6 +250,9 @@ def _compute_data_for_display(
246250
y_true_values = y_true[0].y
247251
y_pred_values = y_pred[0].y
248252
cms = []
253+
if isinstance(pos_label, str):
254+
neg_label = next(label for label in display_labels if label != pos_label)
255+
display_labels = [neg_label, pos_label]
249256

250257
if threshold:
251258
tns, fps, fns, tps, thresholds = confusion_matrix_at_thresholds(
@@ -260,7 +267,8 @@ def _compute_data_for_display(
260267
sklearn_confusion_matrix(
261268
y_true=y_true_values,
262269
y_pred=y_pred_values,
263-
normalize=None, # we will normalize later
270+
normalize=None, # we will normalize later
271+
labels=display_labels,
264272
)
265273
)
266274
thresholds = [None]
@@ -386,4 +394,4 @@ def frame(
386394
cm = cm[cm["threshold"] == threshold]
387395
return cm.pivot(
388396
index="true_label", columns="predicted_label", values=normalize_by
389-
)
397+
).reindex(index=self.display_labels, columns=self.display_labels)

skore/tests/unit/displays/confusion_matrix/test_estimator.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,11 @@ def test_not_implemented_error_for_non_estimator_report(
316316

317317

318318
def test_threshold_display_creation(
319-
pyplot, logistic_binary_classification_with_train_test
319+
pyplot, forest_binary_classification_with_train_test
320320
):
321321
"""Check that we can create a confusion matrix display with threshold support."""
322322
estimator, X_train, X_test, y_train, y_test = (
323-
logistic_binary_classification_with_train_test
323+
forest_binary_classification_with_train_test
324324
)
325325
report = EstimatorReport(
326326
estimator,
@@ -339,12 +339,12 @@ def test_threshold_display_creation(
339339

340340

341341
def test_threshold_display_without_threshold(
342-
pyplot, logistic_binary_classification_with_train_test
342+
pyplot, forest_binary_classification_with_train_test
343343
):
344344
"""Check that do_threshold is False when threshold=False and that we raise an error
345345
when frame or plot is called with threshold."""
346346
estimator, X_train, X_test, y_train, y_test = (
347-
logistic_binary_classification_with_train_test
347+
forest_binary_classification_with_train_test
348348
)
349349
report = EstimatorReport(
350350
estimator,
@@ -371,10 +371,10 @@ def test_threshold_display_without_threshold(
371371
display.plot(threshold=0.5)
372372

373373

374-
def test_plot_with_threshold(pyplot, logistic_binary_classification_with_train_test):
374+
def test_plot_with_threshold(pyplot, forest_binary_classification_with_train_test):
375375
"""Check that we can plot with a specific threshold."""
376376
estimator, X_train, X_test, y_train, y_test = (
377-
logistic_binary_classification_with_train_test
377+
forest_binary_classification_with_train_test
378378
)
379379
report = EstimatorReport(
380380
estimator,
@@ -390,11 +390,11 @@ def test_plot_with_threshold(pyplot, logistic_binary_classification_with_train_t
390390

391391

392392
def test_plot_with_default_threshold(
393-
pyplot, logistic_binary_classification_with_train_test
393+
pyplot, forest_binary_classification_with_train_test
394394
):
395395
"""Check that the default threshold (0.5) is used when not specified."""
396396
estimator, X_train, X_test, y_train, y_test = (
397-
logistic_binary_classification_with_train_test
397+
forest_binary_classification_with_train_test
398398
)
399399
report = EstimatorReport(
400400
estimator,
@@ -415,10 +415,10 @@ def test_plot_with_default_threshold(
415415
)
416416

417417

418-
def test_frame_with_threshold(logistic_binary_classification_with_train_test):
418+
def test_frame_with_threshold(forest_binary_classification_with_train_test):
419419
"""Check that we can get a frame at a specific threshold."""
420420
estimator, X_train, X_test, y_train, y_test = (
421-
logistic_binary_classification_with_train_test
421+
forest_binary_classification_with_train_test
422422
)
423423
report = EstimatorReport(
424424
estimator,
@@ -434,10 +434,10 @@ def test_frame_with_threshold(logistic_binary_classification_with_train_test):
434434
assert frame.shape == (2, 2)
435435

436436

437-
def test_frame_all_thresholds(logistic_binary_classification_with_train_test):
437+
def test_frame_all_thresholds(forest_binary_classification_with_train_test):
438438
"""Check that we get all thresholds when threshold=None."""
439439
estimator, X_train, X_test, y_train, y_test = (
440-
logistic_binary_classification_with_train_test
440+
forest_binary_classification_with_train_test
441441
)
442442
report = EstimatorReport(
443443
estimator,
@@ -454,12 +454,10 @@ def test_frame_all_thresholds(logistic_binary_classification_with_train_test):
454454
assert len(frame) == len(display.thresholds_)
455455

456456

457-
def test_threshold_normalization(
458-
pyplot, logistic_binary_classification_with_train_test
459-
):
457+
def test_threshold_normalization(pyplot, forest_binary_classification_with_train_test):
460458
"""Check that normalization works with threshold support."""
461459
estimator, X_train, X_test, y_train, y_test = (
462-
logistic_binary_classification_with_train_test
460+
forest_binary_classification_with_train_test
463461
)
464462
report = EstimatorReport(
465463
estimator,
@@ -483,12 +481,10 @@ def test_threshold_normalization(
483481
assert np.isclose(frame.sum().sum(), 1.0)
484482

485483

486-
def test_threshold_closest_match(
487-
pyplot, logistic_binary_classification_with_train_test
488-
):
484+
def test_threshold_closest_match(pyplot, forest_binary_classification_with_train_test):
489485
"""Check that the closest threshold is selected."""
490486
estimator, X_train, X_test, y_train, y_test = (
491-
logistic_binary_classification_with_train_test
487+
forest_binary_classification_with_train_test
492488
)
493489
report = EstimatorReport(
494490
estimator,
@@ -514,12 +510,12 @@ def test_threshold_closest_match(
514510

515511

516512
def test_frame_plot_coincidence_with_threshold(
517-
pyplot, logistic_binary_classification_with_train_test
513+
pyplot, forest_binary_classification_with_train_test
518514
):
519515
"""Check that the values in the frame and plot coincide when threshold is
520516
provided."""
521517
estimator, X_train, X_test, y_train, y_test = (
522-
logistic_binary_classification_with_train_test
518+
forest_binary_classification_with_train_test
523519
)
524520
report = EstimatorReport(
525521
estimator,
@@ -533,3 +529,30 @@ def test_frame_plot_coincidence_with_threshold(
533529
frame_values = frame.values.flatten()
534530
display.plot(threshold=0.5)
535531
assert np.allclose(frame_values, display.ax_.collections[0].get_array().flatten())
532+
533+
534+
def test_pos_label(pyplot, forest_binary_classification_with_train_test):
535+
"""Check that the pos_label parameter works correctly."""
536+
estimator, X_train, X_test, y_train, y_test = (
537+
forest_binary_classification_with_train_test
538+
)
539+
labels = np.array(["A", "B"], dtype=object)
540+
y_train = labels[y_train]
541+
y_test = labels[y_test]
542+
estimator.fit(X_train, y_train)
543+
report = EstimatorReport(
544+
estimator,
545+
X_train=X_train,
546+
y_train=y_train,
547+
X_test=X_test,
548+
y_test=y_test,
549+
)
550+
display = report.metrics.confusion_matrix(pos_label="A")
551+
display.plot()
552+
assert display.ax_.get_xticklabels()[1].get_text() == "A"
553+
assert display.ax_.get_yticklabels()[1].get_text() == "A"
554+
555+
display = report.metrics.confusion_matrix(pos_label="B")
556+
display.plot()
557+
assert display.ax_.get_xticklabels()[1].get_text() == "B"
558+
assert display.ax_.get_yticklabels()[1].get_text() == "B"

0 commit comments

Comments
 (0)