32 changes: 30 additions & 2 deletions examples/model_evaluation/plot_estimator_report.py
@@ -397,14 +397,12 @@ def operational_decision_cost(y_true, y_pred, amount):
# %%
# We can normalize the confusion matrix to show proportions instead of raw counts.
# Here we normalize over the true labels (rows):
cm_display = report.metrics.confusion_matrix()
cm_display.plot(normalize="true")
plt.show()
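
# %%
# If ``normalize`` mirrors scikit-learn's ``confusion_matrix`` options
# (``"true"``, ``"pred"``, ``"all"``) -- an assumption, so treat this as a
# sketch -- we can also normalize over the predicted labels (columns):
cm_display.plot(normalize="pred")
plt.show()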

# %%
# More plotting options are available via ``heatmap_kwargs``, which is passed to
# seaborn's ``heatmap``. For example, we can customize the colormap and the
# number format:
cm_display = report.metrics.confusion_matrix()
cm_display.plot(heatmap_kwargs={"cmap": "Greens", "fmt": ".2e"})
plt.show()

@@ -414,6 +412,36 @@ def operational_decision_cost(y_true, y_pred, amount):
cm_frame = cm_display.frame()
cm_frame

# %%
# Decision threshold support
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For binary classification, the confusion matrix can be computed at different
# decision thresholds. This is useful for understanding how the model's predictions
# change as the decision threshold varies.

# %%
# First, we create a display with threshold support enabled:
cm_threshold_display = report.metrics.confusion_matrix(threshold=True)

# %%
# Now we can plot the confusion matrix at a specific threshold:
cm_threshold_display.plot(threshold_value=0.3)
plt.show()

# %%
# Since there is only a finite number of thresholds at which the predictions
# change, we plot the confusion matrix associated with the available threshold
# closest to the one provided.
#
# The `frame` method also supports threshold selection:
cm_threshold_display.frame(threshold_value=0.7)

# %%
# When `threshold_value` is set to `"all"`, we get the confusion matrices for all
# available thresholds:
cm_all_thresholds = cm_threshold_display.frame(threshold_value="all")
cm_all_thresholds.head(5)
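
# %%
# As a minimal sketch of post-processing this frame, we can look up the rows
# closest to a target threshold ourselves; the ``"threshold"`` column name is an
# assumption here, so inspect ``cm_all_thresholds.columns`` first:
closest = (cm_all_thresholds["threshold"] - 0.5).abs().idxmin()  # hypothetical column name
cm_all_thresholds.loc[closest]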

# %%
# .. seealso::
#
17 changes: 17 additions & 0 deletions skore/src/skore/_externals/_sklearn_compat.py
@@ -869,3 +869,20 @@ def parametrize_with_checks(
    check_X_y,  # noqa: F401
    validate_data,  # noqa: F401
)


########################################################################################
# Upgrading for scikit-learn 1.8
########################################################################################

if sklearn_version < parse_version("1.8"):

    def confusion_matrix_at_thresholds(y_true, y_score, pos_label):
        from sklearn.metrics._ranking import _binary_clf_curve

        # fps[i] and tps[i] count the false/true positives obtained when
        # predicting the positive class for scores >= thresholds[i]; the
        # remaining cells follow from the class totals fps[-1] and tps[-1].
        fps, tps, thresholds = _binary_clf_curve(y_true, y_score, pos_label=pos_label)
        fns = tps[-1] - tps
        tns = fps[-1] - fps
        return tns, fps, fns, tps, thresholds
else:
    from sklearn.metrics._ranking import confusion_matrix_at_thresholds  # noqa: F401
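

# A minimal usage sketch of the fallback path on a toy binary problem; the five
# returned arrays are index-aligned:
#
#     import numpy as np
#
#     y_true = np.array([0, 0, 1, 1])
#     y_score = np.array([0.1, 0.4, 0.35, 0.8])
#     tns, fps, fns, tps, thresholds = confusion_matrix_at_thresholds(
#         y_true, y_score, pos_label=1
#     )
#     # At index i, predicting the positive class when y_score >= thresholds[i]
#     # yields the confusion matrix [[tns[i], fps[i]], [fns[i], tps[i]]].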
45 changes: 42 additions & 3 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -2059,6 +2059,8 @@ def confusion_matrix(
    data_source: DataSource = "test",
    X: ArrayLike | None = None,
    y: ArrayLike | None = None,
    threshold: bool = False,
    pos_label: PositiveLabel | None = _DEFAULT,
) -> ConfusionMatrixDisplay:
"""Plot the confusion matrix.

@@ -2082,6 +2084,19 @@ def confusion_matrix(
    New target on which to compute the metric. By default, we use the target
    provided when creating the report.

threshold : bool, default=False
    Whether to enable decision threshold support for binary classification.
    When True, the display precomputes the confusion matrix at every possible
    decision threshold, letting you pass a `threshold_value` to the `.plot()`
    or `.frame()` methods. This requires the estimator to expose
    `predict_proba` or `decision_function`.

pos_label : int, float, bool, str or None, default=_DEFAULT
    The label to consider as the positive class when displaying the matrix.
    Use this parameter to override the positive class. By default, the
    positive class is set to the one provided when creating the report.

Returns
-------
display : :class:`~skore._sklearn._plot.ConfusionMatrixDisplay`
@@ -2097,16 +2112,40 @@ def confusion_matrix(
>>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
>>> classifier = LogisticRegression(max_iter=10_000)
>>> report = EstimatorReport(classifier, **split_data)
>>> report.metrics.confusion_matrix()
>>> display = report.metrics.confusion_matrix()
>>> display.plot()

With decision threshold support for binary classification:

>>> display = report.metrics.confusion_matrix(threshold=True)
>>> display.plot(threshold_value=0.7)
"""
    display_kwargs = {"display_labels": self._parent.estimator_.classes_.tolist()}
    if pos_label is _DEFAULT:
        pos_label = self._parent.pos_label

    display_kwargs = {
        "display_labels": self._parent.estimator_.classes_.tolist(),
        "pos_label": pos_label,
        "threshold": threshold,
    }

    response_method: str | list[str] | tuple[str, ...]
    if threshold:
        if self._parent._ml_task == "binary-classification":
            # Thresholding needs continuous scores rather than hard class
            # predictions: prefer probabilities, fall back to decision scores.
            response_method = ("predict_proba", "decision_function")
        else:
            raise ValueError(
                "Threshold support can only be set for binary classification."
            )
    else:
        response_method = "predict"
    display = cast(
        ConfusionMatrixDisplay,
        self._get_display(
            X=X,
            y=y,
            data_source=data_source,
            response_method="predict",
            response_method=response_method,
            display_class=ConfusionMatrixDisplay,
            display_kwargs=display_kwargs,
        ),