105 changes: 102 additions & 3 deletions skore/src/skore/sklearn/_comparison/metrics_accessor.py
@@ -12,6 +12,7 @@
from skore.sklearn._base import _BaseAccessor, _get_cached_response_values
from skore.sklearn._comparison.report import ComparisonReport
from skore.sklearn._plot.metrics import (
CalibrationCurveDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
RocCurveDisplay,
@@ -1226,12 +1227,22 @@ def _get_display(
X: Union[ArrayLike, None],
y: Union[ArrayLike, None],
data_source: DataSource,
response_method: Union[str, list[str]],
response_method: Union[str, list[str], tuple[str, ...]],
display_class: type[
Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay]
Union[
RocCurveDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
CalibrationCurveDisplay,
]
],
display_kwargs: dict[str, Any],
) -> Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay]:
) -> Union[
RocCurveDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
CalibrationCurveDisplay,
]:
"""Get the display from the cache or compute it.

Parameters
Expand Down Expand Up @@ -1573,3 +1584,91 @@ def prediction_error(
),
)
return display

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["binary-classification", "multiclass-classification"]
)
)
def calibration_curve(
self,
*,
data_source: DataSource = "test",
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
pos_label: PositiveLabel,
strategy: str = "uniform",
n_bins: int = 5,
) -> CalibrationCurveDisplay:
"""Plot the calibration curve.

Parameters
----------
data_source : {"test", "train", "X_y"}, default="test"
The data source to use.

- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.

X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric. By default, we use the validation
set provided when creating the report.

y : array-like of shape (n_samples,), default=None
New target on which to compute the metric. By default, we use the target
provided when creating the report.

pos_label : int, float, bool or str
The positive class.

strategy : {'uniform', 'quantile'}, default='uniform'
Strategy used to define the widths of the bins.

- 'uniform': All bins have identical widths.
- 'quantile': All bins have the same number of points.

n_bins : int, default=5
Number of bins to use when computing the calibration curve.

Returns
-------
CalibrationCurveDisplay
The calibration curve display.

Examples
--------
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import train_test_split
>>> from skore import ComparisonReport, EstimatorReport
>>> X, y = load_breast_cancer(return_X_y=True)
>>> split_data = train_test_split(X=X, y=y, random_state=42, as_dict=True)
>>> estimator_1 = LogisticRegression(max_iter=10000, random_state=42)
>>> estimator_report_1 = EstimatorReport(estimator_1, **split_data)
>>> estimator_2 = LogisticRegression(max_iter=10000, random_state=43)
>>> estimator_report_2 = EstimatorReport(estimator_2, **split_data)
>>> comparison_report = ComparisonReport(
... [estimator_report_1, estimator_report_2]
... )
>>> display = comparison_report.metrics.calibration_curve(pos_label=1)
>>> display.plot()
"""
response_method = ("predict_proba", "decision_function")
display_kwargs = {
"pos_label": pos_label,
"strategy": strategy,
"n_bins": n_bins,
}
display = cast(
CalibrationCurveDisplay,
self._get_display(
X=X,
y=y,
data_source=data_source,
response_method=response_method,
display_class=CalibrationCurveDisplay,
display_kwargs=display_kwargs,
),
)
return display
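For reference, the curve data that the new display exposes can be approximated with scikit-learn's public `calibration_curve` helper. The sketch below is illustrative only (it is not code from this PR) and shows how `pos_label`, `strategy` and `n_bins` map onto that computation for a single estimator.

# Illustrative sketch, not part of this diff.
from sklearn.calibration import calibration_curve
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
proba = LogisticRegression(max_iter=10000).fit(X_train, y_train).predict_proba(X_test)[:, 1]

# prob_true: fraction of positives per bin; prob_pred: mean predicted probability per bin.
# (The pos_label keyword is accepted by recent scikit-learn versions.)
prob_true, prob_pred = calibration_curve(
    y_test, proba, pos_label=1, n_bins=5, strategy="uniform"
)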
113 changes: 108 additions & 5 deletions skore/src/skore/sklearn/_estimator/metrics_accessor.py
@@ -2,7 +2,7 @@
from collections.abc import Iterable
from functools import partial
from operator import attrgetter
from typing import Any, Callable, Literal, Optional, Union, cast
from typing import Any, Callable, Literal, Optional, Protocol, Union, cast

import joblib
import numpy as np
@@ -16,6 +16,7 @@
from skore.sklearn._base import _BaseAccessor, _get_cached_response_values
from skore.sklearn._estimator.report import EstimatorReport
from skore.sklearn._plot import (
CalibrationCurveDisplay,
ConfusionMatrixDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
Expand All @@ -32,6 +33,11 @@
DataSource = Literal["test", "train", "X_y"]


class DisplayClassProtocol(Protocol):
"""Protocol for display classes exposing a `_compute_data_for_display` classmethod."""

@classmethod
def _compute_data_for_display(cls, *args, **kwargs) -> Any: ...
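A minimal sketch of what the protocol buys (hypothetical code, not part of this diff): any class exposing the classmethod is accepted through structural typing, so `_get_display` no longer has to enumerate every concrete display type in its signature.

# Hypothetical illustration only.
class _FakeDisplay:
    @classmethod
    def _compute_data_for_display(cls, *args, **kwargs):
        return cls()

def _build(display_class: type[DisplayClassProtocol], **display_kwargs) -> Any:
    # Works for _FakeDisplay as well as the real display classes.
    return display_class._compute_data_for_display(**display_kwargs)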


class _MetricsAccessor(_BaseAccessor["EstimatorReport"], DirNamesMixin):
"""Accessor for metrics-related operations.

@@ -1625,11 +1631,15 @@ def _get_display(
y: Union[ArrayLike, None],
data_source: DataSource,
response_method: Union[str, list[str], tuple[str, ...]],
display_class: type[
Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay]
],
display_class: type[DisplayClassProtocol],
display_kwargs: dict[str, Any],
) -> Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay]:
) -> Union[
RocCurveDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
CalibrationCurveDisplay,
ConfusionMatrixDisplay,
]:
"""Get the display from the cache or compute it.

Parameters
@@ -1862,6 +1872,99 @@ def precision_recall(
)
return display

@available_if(
_check_all_checks(
checks=[
_check_supported_ml_task(
supported_ml_tasks=[
"binary-classification",
"multiclass-classification",
]
),
_check_estimator_has_method(method_name="predict_proba"),
]
)
)
def calibration_curve(
self,
*,
data_source: DataSource = "test",
X: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None,
pos_label: PositiveLabel,
strategy: str = "uniform",
n_bins: int = 5,
) -> CalibrationCurveDisplay:
"""Plot the calibration curve (reliability diagram).

A calibration curve shows how well a model's predicted probabilities
match observed outcomes. It plots the mean predicted probability in each bin
against the fraction of positive samples in that bin.

Parameters
----------
data_source : {"test", "train", "X_y"}, default="test"
The data source to use.

- "test" : use the test set provided when creating the report.
- "train" : use the train set provided when creating the report.
- "X_y" : use the provided `X` and `y` to compute the metric.

X : array-like of shape (n_samples, n_features), default=None
New data on which to compute the metric. By default, we use the validation
set provided when creating the report.

y : array-like of shape (n_samples,), default=None
New target on which to compute the metric. By default, we use the target
provided when creating the report.

pos_label : int, float, bool or str
The positive class label.

strategy : {"uniform", "quantile"}, default="uniform"
Strategy used to define the widths of the bins.

- "uniform" : The bins have identical widths.
- "quantile" : The bins have the same number of samples and depend
on predicted probabilities.

n_bins : int, default=5
Number of bins to use when calculating the calibration curve.

Returns
-------
CalibrationCurveDisplay
The calibration curve display.

Examples
--------
>>> from sklearn.datasets import make_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from skore import train_test_split
>>> from skore import EstimatorReport
>>> X, y = make_classification(random_state=0)
>>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
>>> classifier = LogisticRegression()
>>> report = EstimatorReport(classifier, **split_data)
>>> display = report.metrics.calibration_curve(pos_label=1)
>>> display.plot()
"""
response_method = "predict_proba"
display_kwargs = {
"pos_label": pos_label,
"strategy": strategy,
"n_bins": n_bins,
}
display = self._get_display(
X=X,
y=y,
data_source=data_source,
response_method=response_method,
display_class=CalibrationCurveDisplay,
display_kwargs=display_kwargs,
)
return cast(CalibrationCurveDisplay, display)
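To make the `strategy` parameter concrete, here is an illustrative numpy sketch (not part of this diff) contrasting the bin edges that "uniform" and "quantile" binning would produce for the same predicted probabilities.

# Illustrative sketch only.
import numpy as np

proba = np.array([0.02, 0.1, 0.15, 0.2, 0.55, 0.6, 0.62, 0.9, 0.95, 0.99])
n_bins = 5

uniform_edges = np.linspace(0.0, 1.0, n_bins + 1)                       # equal-width bins
quantile_edges = np.percentile(proba, np.linspace(0, 100, n_bins + 1))   # equal-count bins

print(uniform_edges)   # [0.  0.2 0.4 0.6 0.8 1. ]
print(quantile_edges)  # edges chosen so each bin holds ~2 of the 10 predictions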

@available_if(
_check_supported_ml_task(
supported_ml_tasks=["regression", "multioutput-regression"]
2 changes: 2 additions & 0 deletions skore/src/skore/sklearn/_plot/__init__.py
@@ -1,11 +1,13 @@
from skore.sklearn._plot.metrics import (
CalibrationCurveDisplay,
ConfusionMatrixDisplay,
PrecisionRecallCurveDisplay,
PredictionErrorDisplay,
RocCurveDisplay,
)

__all__ = [
"CalibrationCurveDisplay",
"ConfusionMatrixDisplay",
"RocCurveDisplay",
"PrecisionRecallCurveDisplay",
2 changes: 2 additions & 0 deletions skore/src/skore/sklearn/_plot/metrics/__init__.py
@@ -1,3 +1,4 @@
from skore.sklearn._plot.metrics.calibration_curve import CalibrationCurveDisplay
from skore.sklearn._plot.metrics.confusion_matrix import ConfusionMatrixDisplay
from skore.sklearn._plot.metrics.precision_recall_curve import (
PrecisionRecallCurveDisplay,
Expand All @@ -6,6 +7,7 @@
from skore.sklearn._plot.metrics.roc_curve import RocCurveDisplay

__all__ = [
"CalibrationCurveDisplay",
"ConfusionMatrixDisplay",
"PrecisionRecallCurveDisplay",
"PredictionErrorDisplay",