From bca9b5458d9158af36800fd88cbe77e8be34d5bd Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 13:27:37 +0530 Subject: [PATCH 1/8] feat: Add CalibrationCurveDisplay class for binary classification --- .../_plot/metrics/calibration_curve.py | 372 ++++++++++++++++++ 1 file changed, 372 insertions(+) create mode 100644 skore/src/skore/sklearn/_plot/metrics/calibration_curve.py diff --git a/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py b/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py new file mode 100644 index 0000000000..401ef44ee4 --- /dev/null +++ b/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py @@ -0,0 +1,372 @@ +from typing import Any, Literal, Optional, Union + +import matplotlib.pyplot as plt +from numpy.typing import NDArray + +from skore.sklearn._plot.style import StyleDisplayMixin +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) +from skore.sklearn.types import MLTask, PositiveLabel + + +class CalibrationCurveDisplay(StyleDisplayMixin, HelpDisplayMixin): + """Visualization of the calibration curve for a classifier. + + A calibration curve (also known as a reliability diagram) plots the calibration + of a classifier, showing how well the predicted probabilities match observed + outcomes. It plots the mean predicted probability in each bin against the + fraction of positive samples in that bin. + + An instance of this class is created by + `EstimatorReport.metrics.calibration_curve()`. You should not create an instance + of this class directly. + + Parameters + ---------- + prob_true : dict[Any, list[NDArray]] + Dictionary mapping positive labels to lists of true probabilities. + + prob_pred : dict[Any, list[NDArray]] + Dictionary mapping positive labels to lists of predicted probabilities. + + y_prob : list[NDArray] + List of predicted probabilities. + + estimator_names : list[str] + List of estimator names. + + pos_label : PositiveLabel + The positive label. + + data_source : {"train", "test", "X_y"} + The source of the data. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + report_type : {"cross-validation", "estimator", "comparison-estimator"} + The type of report. + + n_bins : int + Number of bins used for the calibration curve. + + strategy : {"uniform", "quantile"} + Strategy used to define the widths of the bins. + + Attributes + ---------- + line_ : matplotlib Artist + Calibration curve lines. + + ax_ : matplotlib Axes + Axes with the calibration curve. + + hist_ax_ : matplotlib Axes + Axes with the histogram of predicted probabilities. + + figure_ : matplotlib Figure + Figure containing the curve and histogram. 
+ """ + + _default_line_kwargs: Union[dict[str, Any], None] = None + _default_ref_line_kwargs: Union[dict[str, Any], None] = None + _default_hist_kwargs: Union[dict[str, Any], None] = None + + def __init__( + self, + *, + prob_true: dict[Any, list[NDArray]], + prob_pred: dict[Any, list[NDArray]], + y_prob: list[NDArray], + estimator_names: list[str], + pos_label: PositiveLabel, + data_source: Literal["train", "test", "X_y"], + ml_task: MLTask, + report_type: Literal["cross-validation", "estimator", "comparison-estimator"], + n_bins: int, + strategy: str, + ) -> None: + self.prob_true = prob_true + self.prob_pred = prob_pred + self.y_prob = y_prob + self.estimator_names = estimator_names + self.pos_label = pos_label + self.data_source = data_source + self.ml_task = ml_task + self.report_type = report_type + self.n_bins = n_bins + self.strategy = strategy + + def _plot_single_estimator( + self, + *, + line_kwargs: dict[str, Any], + hist_kwargs: dict[str, Any], + ) -> list[Any]: + """Plot calibration curve for a single estimator. + + Parameters + ---------- + line_kwargs : dict[str, Any] + Keyword arguments for the line plots. + + hist_kwargs : dict[str, Any] + Keyword arguments for the histogram. + + Returns + ------- + lines : list[matplotlib Artist] + The plotted lines. + """ + lines = [] + if self.data_source in ("train", "test"): + line_label = f"{self.data_source.title()} set" + else: # data_source == "X_y" + line_label = "Data set" + + # Plot calibration curve + cal_line = self.ax_.plot( + self.prob_pred[self.pos_label][0], + self.prob_true[self.pos_label][0], + label=line_label, + **line_kwargs, + )[0] + lines.append(cal_line) + + # Plot histogram of predicted probabilities + self.hist_ax_.hist( + self.y_prob[0], + range=(0, 1), + bins=self.n_bins, + **hist_kwargs, + ) + + return lines + + def _plot_cross_validated_estimator( + self, + *, + line_kwargs: dict[str, Any], + hist_kwargs: dict[str, Any], + ) -> list[Any]: + """Plot calibration curve for a cross-validated estimator. + + Parameters + ---------- + line_kwargs : dict[str, Any] + Keyword arguments for the line plots. + + hist_kwargs : dict[str, Any] + Keyword arguments for the histogram. + + Returns + ------- + lines : list[matplotlib Artist] + The plotted lines. + """ + lines = [] + + # Plot calibration curves for each fold + for split_idx in range(len(self.y_prob)): + label = f"Estimator of fold #{split_idx + 1}" + + cal_line = self.ax_.plot( + self.prob_pred[self.pos_label][split_idx], + self.prob_true[self.pos_label][split_idx], + label=label, + **line_kwargs, + )[0] + lines.append(cal_line) + + # Add to histogram + self.hist_ax_.hist( + self.y_prob[split_idx], + range=(0, 1), + bins=self.n_bins, + **hist_kwargs, + ) + + return lines + + def _plot_comparison_estimator( + self, + *, + line_kwargs: dict[str, Any], + hist_kwargs: dict[str, Any], + ) -> list[Any]: + """Plot calibration curves for multiple estimators. + + Parameters + ---------- + line_kwargs : dict[str, Any] + Keyword arguments for the line plots. + + hist_kwargs : dict[str, Any] + Keyword arguments for the histogram. + + Returns + ------- + lines : list[matplotlib Artist] + The plotted lines. 
+ """ + lines = [] + + # Plot calibration curves for each estimator + for estimator_idx in range(len(self.y_prob)): + cal_line = self.ax_.plot( + self.prob_pred[self.pos_label][estimator_idx], + self.prob_true[self.pos_label][estimator_idx], + label=self.estimator_names[estimator_idx], + **line_kwargs, + )[0] + lines.append(cal_line) + + # Add to histogram + self.hist_ax_.hist( + self.y_prob[estimator_idx], + range=(0, 1), + bins=self.n_bins, + **hist_kwargs, + ) + + return lines + + @StyleDisplayMixin.style_plot + def plot( + self, + ax: Optional[plt.Axes] = None, + hist_ax: Optional[plt.Axes] = None, + *, + line_kwargs: Optional[dict[str, Any]] = None, + ref_line_kwargs: Optional[dict[str, Any]] = None, + hist_kwargs: Optional[dict[str, Any]] = None, + despine: bool = True, + ) -> None: + """Plot calibration curve (line plot + histogram). + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot the calibration curve on. If `None`, a new figure and + axes is created. + + hist_ax : matplotlib axes, default=None + Axes object to plot the histogram on. If `None`, a new figure and axes is + created. + + line_kwargs : dict[str, Any], default=None + Dictionary with keywords passed to the `matplotlib.pyplot.plot` call + for the calibration curves. + + ref_line_kwargs : dict[str, Any], default=None + Dictionary with keywords passed to the `matplotlib.pyplot.plot` call + for the reference line (perfectly calibrated). + + hist_kwargs : dict[str, Any], default=None + Dictionary with keywords passed to the `matplotlib.pyplot.hist` call + for the histogram. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.linear_model import LogisticRegression + >>> from skore import train_test_split + >>> from skore import EstimatorReport + >>> X, y = make_classification(random_state=0) + >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True) + >>> classifier = LogisticRegression() + >>> report = EstimatorReport(classifier, **split_data) + >>> display = report.metrics.calibration_curve(pos_label=1) + >>> display.plot() + """ + # Create figure and axes if not provided + if ax is None or hist_ax is None: + fig, (self.ax_, self.hist_ax_) = plt.subplots( + nrows=2, figsize=(8, 8), height_ratios=[2, 1], sharex=True + ) + self.figure_ = fig + else: + self.ax_ = ax + self.hist_ax_ = hist_ax + self.figure_ = ax.figure + + # Set default kwargs + default_line_kwargs = {"alpha": 0.8} + default_ref_line_kwargs = { + "color": "black", + "linestyle": "--", + "alpha": 0.8, + "label": "Perfectly calibrated", + } + default_hist_kwargs = {"alpha": 0.5, "color": "gray"} + + # Update with user-provided kwargs + line_kwargs_validated = _validate_style_kwargs( + default_line_kwargs, + line_kwargs or self._default_line_kwargs or {}, + ) + ref_line_kwargs_validated = _validate_style_kwargs( + default_ref_line_kwargs, + ref_line_kwargs or self._default_ref_line_kwargs or {}, + ) + hist_kwargs_validated = _validate_style_kwargs( + default_hist_kwargs, + hist_kwargs or self._default_hist_kwargs or {}, + ) + + # Create plot based on report type + if self.report_type == "estimator": + self.line_ = self._plot_single_estimator( + line_kwargs=line_kwargs_validated, + hist_kwargs=hist_kwargs_validated, + ) + elif self.report_type == "cross-validation": + self.line_ = self._plot_cross_validated_estimator( + line_kwargs=line_kwargs_validated, + 
hist_kwargs=hist_kwargs_validated, + ) + elif self.report_type == "comparison-estimator": + self.line_ = self._plot_comparison_estimator( + line_kwargs=line_kwargs_validated, + hist_kwargs=hist_kwargs_validated, + ) + else: + raise ValueError( + f"`report_type` should be one of 'estimator', 'cross-validation', " + f"or 'comparison-estimator'. Got '{self.report_type}' instead." + ) + + # Plot reference line + self.ax_.plot([0, 1], [0, 1], **ref_line_kwargs_validated) + + # Set labels and titles + self.ax_.set_ylabel("Fraction of positives") + self.ax_.set_ylim([0, 1]) + self.ax_.legend(loc="upper left") + + self.hist_ax_.set_xlabel("Mean predicted probability") + self.hist_ax_.set_ylabel("Count") + self.hist_ax_.set_xlim([0, 1]) + + if self.data_source in ("train", "test"): + self.ax_.set_title( + "Calibration curve (reliability diagram) - " + f"{self.data_source.title()} set" + ) + else: + self.ax_.set_title("Calibration curve (reliability diagram)") + + plt.tight_layout() + + # Despine if requested + if despine: + _despine_matplotlib_axis(self.ax_, x_range=(0, 1), y_range=(0, 1)) + _despine_matplotlib_axis( + self.hist_ax_, x_range=(0, 1), y_range=self.hist_ax_.get_ylim() + ) From 095669d294b87d0f3b6ad252d30cb7f62df2ba83 Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 15:08:29 +0530 Subject: [PATCH 2/8] feat: Update imports for Calibration curve metrics --- skore/src/skore/sklearn/_plot/__init__.py | 2 ++ skore/src/skore/sklearn/_plot/metrics/__init__.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/skore/src/skore/sklearn/_plot/__init__.py b/skore/src/skore/sklearn/_plot/__init__.py index c53a218fdd..ebbd979297 100644 --- a/skore/src/skore/sklearn/_plot/__init__.py +++ b/skore/src/skore/sklearn/_plot/__init__.py @@ -1,4 +1,5 @@ from skore.sklearn._plot.metrics import ( + CalibrationCurveDisplay, ConfusionMatrixDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay, @@ -6,6 +7,7 @@ ) __all__ = [ + "CalibrationCurveDisplay", "ConfusionMatrixDisplay", "RocCurveDisplay", "PrecisionRecallCurveDisplay", diff --git a/skore/src/skore/sklearn/_plot/metrics/__init__.py b/skore/src/skore/sklearn/_plot/metrics/__init__.py index cdd5b05e71..40bdf357db 100644 --- a/skore/src/skore/sklearn/_plot/metrics/__init__.py +++ b/skore/src/skore/sklearn/_plot/metrics/__init__.py @@ -1,3 +1,4 @@ +from skore.sklearn._plot.metrics.calibration_curve import CalibrationCurveDisplay from skore.sklearn._plot.metrics.confusion_matrix import ConfusionMatrixDisplay from skore.sklearn._plot.metrics.precision_recall_curve import ( PrecisionRecallCurveDisplay, @@ -6,6 +7,7 @@ from skore.sklearn._plot.metrics.roc_curve import RocCurveDisplay __all__ = [ + "CalibrationCurveDisplay", "ConfusionMatrixDisplay", "PrecisionRecallCurveDisplay", "PredictionErrorDisplay", From f9a553a3df6cca20a86a283fa9218f04b2e54f22 Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 13:31:10 +0530 Subject: [PATCH 3/8] feat: Register calibration_curve method in metrics accessor --- .../sklearn/_estimator/metrics_accessor.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py index 04583a09f4..d2d5709a61 100644 --- a/skore/src/skore/sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -16,6 +16,7 @@ from skore.sklearn._base import _BaseAccessor, _get_cached_response_values from skore.sklearn._estimator.report import 
EstimatorReport from skore.sklearn._plot import ( + CalibrationCurveDisplay, ConfusionMatrixDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay, @@ -1626,10 +1627,22 @@ def _get_display( data_source: DataSource, response_method: Union[str, list[str], tuple[str, ...]], display_class: type[ - Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay] + Union[ + RocCurveDisplay, + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + CalibrationCurveDisplay, + ConfusionMatrixDisplay, + ] ], display_kwargs: dict[str, Any], - ) -> Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay]: + ) -> Union[ + RocCurveDisplay, + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + CalibrationCurveDisplay, + ConfusionMatrixDisplay, + ]: """Get the display from the cache or compute it. Parameters From 0d55d6661038b424907bc6bad48e97d591a83326 Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 14:51:41 +0530 Subject: [PATCH 4/8] feat: Add support for multiclass calibration --- .../_plot/metrics/calibration_curve.py | 134 +++++++++++++++++- 1 file changed, 133 insertions(+), 1 deletion(-) diff --git a/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py b/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py index 401ef44ee4..7921536a94 100644 --- a/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py +++ b/skore/src/skore/sklearn/_plot/metrics/calibration_curve.py @@ -1,7 +1,9 @@ from typing import Any, Literal, Optional, Union import matplotlib.pyplot as plt +import numpy as np from numpy.typing import NDArray +from sklearn.calibration import calibration_curve from skore.sklearn._plot.style import StyleDisplayMixin from skore.sklearn._plot.utils import ( @@ -9,7 +11,7 @@ _despine_matplotlib_axis, _validate_style_kwargs, ) -from skore.sklearn.types import MLTask, PositiveLabel +from skore.sklearn.types import MLTask, PositiveLabel, YPlotData class CalibrationCurveDisplay(StyleDisplayMixin, HelpDisplayMixin): @@ -370,3 +372,133 @@ def plot( _despine_matplotlib_axis( self.hist_ax_, x_range=(0, 1), y_range=self.hist_ax_.get_ylim() ) + + @classmethod + def _compute_data_for_display( + cls, + y_true: list[YPlotData], + y_pred: list[YPlotData], + *, + report_type: Literal["cross-validation", "estimator", "comparison-estimator"], + estimator_names: list[str], + ml_task: MLTask, + data_source: Literal["train", "test", "X_y"], + pos_label: PositiveLabel, + strategy: str = "uniform", + n_bins: int = 5, + **kwargs, + ) -> "CalibrationCurveDisplay": + """Compute the calibration curve data. + + Parameters + ---------- + y_true : list[YPlotData] + True target values. + + y_pred : list[YPlotData] + Predicted probabilities. + + report_type : {"cross-validation", "estimator", "comparison-estimator"} + The type of report. + + estimator_names : list[str] + Names of the estimators. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"} + The data source used to compute the calibration curve. + + pos_label : PositiveLabel + The positive class label. + + strategy : str, default="uniform" + Strategy used to define the widths of the bins: 'uniform' or 'quantile'. + + n_bins : int, default=5 + Number of bins to use when calculating the calibration curve. + + **kwargs : Dict + Additional keyword arguments to be compatible with other metrics. + + Returns + ------- + display : CalibrationCurveDisplay + The display object with computed calibration data. 
+ """ + # Support both binary and multiclass classification + supported_tasks = ["binary-classification", "multiclass-classification"] + if ml_task not in supported_tasks: + raise ValueError( + f"The machine learning task must be one of {supported_tasks}. " + f"Got {ml_task} instead." + ) + + allowed_strategies = ["uniform", "quantile"] + if strategy not in allowed_strategies: + raise ValueError( + f"strategy must be one of {allowed_strategies}. Got {strategy} instead." + ) + + prob_true: dict[Any, list[NDArray]] = {pos_label: []} + prob_pred: dict[Any, list[NDArray]] = {pos_label: []} + y_prob: list[NDArray] = [] + + # Compute calibration curve for each estimator + for y_true_i, y_pred_i in zip(y_true, y_pred): + # Get binary target values + y_true_binary = (np.array(y_true_i.y) == pos_label).astype(int) + + # Get probabilities - handle both direct probabilities or 2D arrays + y_pred_array = np.array(y_pred_i.y) + + # If y_pred is a 2D array with multiple columns (probability for each class) + if len(y_pred_array.shape) == 2 and y_pred_array.shape[1] >= 2: + # For binary classification with standard sklearn format + if y_pred_array.shape[1] == 2: + # Use second column (index 1) for positive class probability + # Standard convention in sklearn + y_pred_proba = y_pred_array[:, 1] + else: + # For multi-class, try to find the column for pos_label + pos_idx = 1 # Default to second column + if hasattr(y_pred_i, "classes") and hasattr( + y_pred_i.classes, "__iter__" + ): + # If classes are available, find the index of pos_label + try: + classes = np.array(y_pred_i.classes) + pos_idx = np.where(classes == pos_label)[0][0] + except (IndexError, AttributeError): + pass + y_pred_proba = y_pred_array[:, pos_idx] + else: + y_pred_proba = y_pred_array + + # Store probabilities for histogram + y_prob.append(y_pred_proba) + + # Compute calibration curve + prob_true_i, prob_pred_i = calibration_curve( + y_true_binary, + y_pred_proba, + n_bins=n_bins, + strategy=strategy, + ) + + prob_true[pos_label].append(prob_true_i) + prob_pred[pos_label].append(prob_pred_i) + + return cls( + prob_true=prob_true, + prob_pred=prob_pred, + y_prob=y_prob, + estimator_names=estimator_names, + pos_label=pos_label, + data_source=data_source, + ml_task=ml_task, + report_type=report_type, + n_bins=n_bins, + strategy=strategy, + ) From 59acb7b97d40296d1b37802e93174f39cb095504 Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 15:05:06 +0530 Subject: [PATCH 5/8] feat: Add calibration_curve func to estimator metrics for multiclass classification --- .../sklearn/_estimator/metrics_accessor.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py index d2d5709a61..cb82cdc69a 100644 --- a/skore/src/skore/sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -1875,6 +1875,99 @@ def precision_recall( ) return display + @available_if( + _check_all_checks( + checks=[ + _check_supported_ml_task( + supported_ml_tasks=[ + "binary-classification", + "multiclass-classification", + ] + ), + _check_estimator_has_method(method_name="predict_proba"), + ] + ) + ) + def calibration_curve( + self, + *, + data_source: DataSource = "test", + X: Optional[ArrayLike] = None, + y: Optional[ArrayLike] = None, + pos_label: PositiveLabel, + strategy: str = "uniform", + n_bins: int = 5, + ) -> CalibrationCurveDisplay: + """Plot the calibration curve (reliability diagram). 
+ + A calibration curve shows how well a model's predicted probabilities + match observed outcomes. It plots the mean predicted probability in each bin + against the fraction of positive samples in that bin. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the report. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the report. + + pos_label : int, float, bool or str + The positive class label. + + strategy : {"uniform", "quantile"}, default="uniform" + Strategy used to define the widths of the bins. + + - "uniform" : The bins have identical widths. + - "quantile" : The bins have the same number of samples and depend + on predicted probabilities. + + n_bins : int, default=5 + Number of bins to use when calculating the calibration curve. + + Returns + ------- + CalibrationCurveDisplay + The calibration curve display. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.linear_model import LogisticRegression + >>> from skore import train_test_split + >>> from skore import EstimatorReport + >>> X, y = make_classification(random_state=0) + >>> split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True) + >>> classifier = LogisticRegression() + >>> report = EstimatorReport(classifier, **split_data) + >>> display = report.metrics.calibration_curve(pos_label=1) + >>> display.plot() + """ + response_method = "predict_proba" + display_kwargs = { + "pos_label": pos_label, + "strategy": strategy, + "n_bins": n_bins, + } + display = self._get_display( + X=X, + y=y, + data_source=data_source, + response_method=response_method, + display_class=CalibrationCurveDisplay, + display_kwargs=display_kwargs, + ) + return cast(CalibrationCurveDisplay, display) + @available_if( _check_supported_ml_task( supported_ml_tasks=["regression", "multioutput-regression"] From 5299a8d573597e99c4ffc6841ec23a2a87356775 Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 15:07:02 +0530 Subject: [PATCH 6/8] feat: Add calibration_curve func to comparison metrics for multiclass classification --- .../sklearn/_comparison/metrics_accessor.py | 105 +++++++++++++++++- 1 file changed, 102 insertions(+), 3 deletions(-) diff --git a/skore/src/skore/sklearn/_comparison/metrics_accessor.py b/skore/src/skore/sklearn/_comparison/metrics_accessor.py index cd7ca8d992..ae06a86af3 100644 --- a/skore/src/skore/sklearn/_comparison/metrics_accessor.py +++ b/skore/src/skore/sklearn/_comparison/metrics_accessor.py @@ -12,6 +12,7 @@ from skore.sklearn._base import _BaseAccessor, _get_cached_response_values from skore.sklearn._comparison.report import ComparisonReport from skore.sklearn._plot.metrics import ( + CalibrationCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay, RocCurveDisplay, @@ -1226,12 +1227,22 @@ def _get_display( X: Union[ArrayLike, None], y: Union[ArrayLike, None], data_source: DataSource, - response_method: Union[str, list[str]], + response_method: Union[str, list[str], tuple[str, ...]], display_class: type[ - 
Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay] + Union[ + RocCurveDisplay, + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + CalibrationCurveDisplay, + ] ], display_kwargs: dict[str, Any], - ) -> Union[RocCurveDisplay, PrecisionRecallCurveDisplay, PredictionErrorDisplay]: + ) -> Union[ + RocCurveDisplay, + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + CalibrationCurveDisplay, + ]: """Get the display from the cache or compute it. Parameters @@ -1573,3 +1584,91 @@ def prediction_error( ), ) return display + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def calibration_curve( + self, + *, + data_source: DataSource = "test", + X: Optional[ArrayLike] = None, + y: Optional[ArrayLike] = None, + pos_label: PositiveLabel, + strategy: str = "uniform", + n_bins: int = 5, + ) -> CalibrationCurveDisplay: + """Plot the calibration curve. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the report. + - "train" : use the train set provided when creating the report. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the report. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the report. + + pos_label : int, float, bool or str + The positive class. + + strategy : {'uniform', 'quantile'}, default='uniform' + Strategy used to define the widths of the bins. + + - 'uniform': All bins have identical widths. + - 'quantile': All bins have the same number of points. + + n_bins : int, default=5 + Number of bins to use when calculating the histogram. + + Returns + ------- + CalibrationCurveDisplay + The calibration curve display. + + Examples + -------- + >>> from sklearn.datasets import load_breast_cancer + >>> from sklearn.linear_model import LogisticRegression + >>> from skore import train_test_split + >>> from skore import ComparisonReport, EstimatorReport + >>> X, y = load_breast_cancer(return_X_y=True) + >>> split_data = train_test_split(X=X, y=y, random_state=42, as_dict=True) + >>> estimator_1 = LogisticRegression(max_iter=10000, random_state=42) + >>> estimator_report_1 = EstimatorReport(estimator_1, **split_data) + >>> estimator_2 = LogisticRegression(max_iter=10000, random_state=43) + >>> estimator_report_2 = EstimatorReport(estimator_2, **split_data) + >>> comparison_report = ComparisonReport( + ... [estimator_report_1, estimator_report_2] + ... 
) + >>> display = comparison_report.metrics.calibration_curve(pos_label=1) + >>> display.plot() + """ + response_method = ("predict_proba", "decision_function") + display_kwargs = { + "pos_label": pos_label, + "strategy": strategy, + "n_bins": n_bins, + } + display = cast( + CalibrationCurveDisplay, + self._get_display( + X=X, + y=y, + data_source=data_source, + response_method=response_method, + display_class=CalibrationCurveDisplay, + display_kwargs=display_kwargs, + ), + ) + return display From 5c81f69649165ee635de48d368a3a5892425e2f8 Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 15:13:06 +0530 Subject: [PATCH 7/8] test: Add tests for calibration curve --- .../sklearn/plot/test_calibration_curve.py | 397 ++++++++++++++++++ 1 file changed, 397 insertions(+) create mode 100644 skore/tests/unit/sklearn/plot/test_calibration_curve.py diff --git a/skore/tests/unit/sklearn/plot/test_calibration_curve.py b/skore/tests/unit/sklearn/plot/test_calibration_curve.py new file mode 100644 index 0000000000..3e50d889dc --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_calibration_curve.py @@ -0,0 +1,397 @@ +"""Tests for the calibration curve display.""" + +import matplotlib.pyplot as plt +import numpy as np +import pytest +from matplotlib.gridspec import GridSpec +from sklearn.calibration import calibration_curve +from sklearn.datasets import make_classification +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LogisticRegression +from sklearn.naive_bayes import GaussianNB +from sklearn.svm import SVC +from skore import ComparisonReport, EstimatorReport, train_test_split +from skore.sklearn._plot.metrics.calibration_curve import CalibrationCurveDisplay +from skore.sklearn.types import YPlotData + + +@pytest.fixture +def binary_classification_data(): + """Create binary classification data for testing.""" + X, y = make_classification( + n_samples=1000, n_classes=2, n_informative=4, random_state=42 + ) + return train_test_split(X=X, y=y, random_state=42, as_dict=True) + + +@pytest.fixture +def binary_classification_model(): + """Create a classification model for testing.""" + return LogisticRegression(random_state=42) + + +def test_calibration_curve_display_init(): + """Test that CalibrationCurveDisplay can be initialized with proper attributes.""" + # Create dummy data + prob_true = {1: [np.array([0.1, 0.3, 0.5, 0.7, 0.9])]} + prob_pred = {1: [np.array([0.2, 0.4, 0.6, 0.8, 1.0])]} + y_prob = [np.array([0.2, 0.4, 0.6, 0.8, 1.0])] + + # Initialize the display + display = CalibrationCurveDisplay( + prob_true=prob_true, + prob_pred=prob_pred, + y_prob=y_prob, + estimator_names=["TestEstimator"], + pos_label=1, + data_source="test", + ml_task="binary-classification", + report_type="estimator", + n_bins=5, + strategy="uniform", + ) + + # Check that the attributes are set correctly + assert display.prob_true == prob_true + assert display.prob_pred == prob_pred + assert display.y_prob == y_prob + assert display.estimator_names == ["TestEstimator"] + assert display.pos_label == 1 + assert display.data_source == "test" + assert display.ml_task == "binary-classification" + assert display.report_type == "estimator" + assert display.n_bins == 5 + assert display.strategy == "uniform" + + +def test_calibration_curve_from_report( + binary_classification_data, binary_classification_model +): + """Test that the calibration curve can be created from an EstimatorReport.""" + # Create a report + report = EstimatorReport(binary_classification_model, 
**binary_classification_data) + + # Get the calibration curve display + display = report.metrics.calibration_curve(pos_label=1) + + # Check that the display is of the right type + assert isinstance(display, CalibrationCurveDisplay) + + # Check basic attributes + assert display.pos_label == 1 + assert display.data_source == "test" + assert display.ml_task == "binary-classification" + assert display.report_type == "estimator" + + +def test_calibration_curve_plotting(): + """Test that the calibration curve plotting works correctly.""" + # Create synthetic data + X, y = make_classification(n_samples=1000, n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + # Fit a model + model = LogisticRegression(random_state=42) + model.fit(X_train, y_train) + y_prob = model.predict_proba(X_test)[:, 1] + + # Calculate calibration curve manually + prob_true, prob_pred = calibration_curve( + y_test, y_prob, n_bins=5, strategy="uniform" + ) + + # Create YPlotData objects for the display + y_true_data = [ + YPlotData(estimator_name="LogisticRegression", split_index=None, y=y_test) + ] + y_pred_data = [ + YPlotData( + estimator_name="LogisticRegression", + split_index=None, + y=model.predict_proba(X_test), + ) + ] + + # Create the display directly + display = CalibrationCurveDisplay._compute_data_for_display( + y_true=y_true_data, + y_pred=y_pred_data, + report_type="estimator", + estimator_names=["LogisticRegression"], + ml_task="binary-classification", + data_source="test", + pos_label=1, + n_bins=5, + strategy="uniform", + ) + + # Check that the computed values match + np.testing.assert_allclose(display.prob_true[1][0], prob_true) + np.testing.assert_allclose(display.prob_pred[1][0], prob_pred) + + # Test that the plot method works without error + fig, (ax, hist_ax) = plt.subplots(nrows=2, figsize=(8, 8), height_ratios=[2, 1]) + display.plot(ax=ax, hist_ax=hist_ax) + plt.close(fig) + + +def test_multiple_models_calibration(): + """Test calibration curves with multiple models.""" + # Create data + X, y = make_classification(n_samples=1000, n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + # Create and fit models + models = { + "Logistic Regression": LogisticRegression(random_state=42), + "Naive Bayes": GaussianNB(), + "SVC": SVC(probability=True, random_state=42), + "Random Forest": RandomForestClassifier(random_state=42), + } + + # Fit all models + for _name, model in models.items(): + model.fit(X_train, y_train) + + # Create YPlotData objects for each model + y_true_data = [] + y_pred_data = [] + + for name, model in models.items(): + y_true_data.append(YPlotData(estimator_name=name, split_index=None, y=y_test)) + y_pred_proba = model.predict_proba(X_test) + y_pred_data.append( + YPlotData(estimator_name=name, split_index=None, y=y_pred_proba) + ) + + # Create the display + display = CalibrationCurveDisplay._compute_data_for_display( + y_true=y_true_data, + y_pred=y_pred_data, + report_type="comparison-estimator", + estimator_names=list(models.keys()), + ml_task="binary-classification", + data_source="test", + pos_label=1, + n_bins=10, + strategy="uniform", + ) + + # Check that we have data for each model + assert len(display.prob_true[1]) == len(models) + assert len(display.prob_pred[1]) == len(models) + assert len(display.y_prob) == len(models) + + # Create a plot + fig, (ax, hist_ax) = plt.subplots(nrows=2, figsize=(10, 10), height_ratios=[2, 1]) + display.plot(ax=ax, hist_ax=hist_ax) + 
plt.close(fig) + + +def test_calibration_curve_parameters(): + """Test that the calibration curve parameters are correctly handled.""" + # Create synthetic data and use keyword arguments for train_test_split + X, y = make_classification(n_samples=1000, random_state=42) + split_data = train_test_split(X=X, y=y, random_state=42, as_dict=True) + + # Create the report + report = EstimatorReport(LogisticRegression(random_state=42), **split_data) + + # Test default parameters + display = report.metrics.calibration_curve(pos_label=1) + assert display.n_bins == 5 + assert display.strategy == "uniform" + + # Test with custom parameters + display = report.metrics.calibration_curve( + pos_label=1, n_bins=10, strategy="quantile" + ) + assert display.n_bins == 10 + assert display.strategy == "quantile" + + # Test invalid strategy + with pytest.raises(ValueError): + CalibrationCurveDisplay._compute_data_for_display( + y_true=[ + YPlotData(estimator_name="test", split_index=None, y=np.array([0, 1])) + ], + y_pred=[ + YPlotData( + estimator_name="test", + split_index=None, + y=np.array([[0.1, 0.9], [0.2, 0.8]]), + ) + ], + report_type="estimator", + estimator_names=["test"], + ml_task="binary-classification", + data_source="test", + pos_label=1, + strategy="invalid", + ) + + +def test_different_strategies_and_bins(): + """Test with different strategies and bin counts.""" + # Create data + X, y = make_classification(n_samples=500, n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + # Create and fit model + model = LogisticRegression(random_state=42) + model.fit(X_train, y_train) + + # Create YPlotData objects + y_true_data = [ + YPlotData(estimator_name="LogisticRegression", split_index=None, y=y_test) + ] + y_pred_proba = model.predict_proba(X_test) + y_pred_data = [ + YPlotData(estimator_name="LogisticRegression", split_index=None, y=y_pred_proba) + ] + + # Test with different strategies + for strategy in ["uniform", "quantile"]: + for n_bins in [5, 10, 20]: + # Create the display + display = CalibrationCurveDisplay._compute_data_for_display( + y_true=y_true_data, + y_pred=y_pred_data, + report_type="estimator", + estimator_names=["LogisticRegression"], + ml_task="binary-classification", + data_source="test", + pos_label=1, + n_bins=n_bins, + strategy=strategy, + ) + + # Check that the display has the right attributes + assert display.strategy == strategy + assert display.n_bins == n_bins + + # Check that we have the expected number of bins + # The actual number of bins may be fewer than requested if some bins + # are empty. 
This is expected behavior from scikit-learn's + # calibration_curve function + assert len(display.prob_true[1][0]) >= n_bins - 5 + assert len(display.prob_true[1][0]) <= n_bins + + +def test_multiclass_calibration(): + """Test calibration curves with multiclass classification.""" + # Create multiclass data + X, y = make_classification( + n_samples=1000, n_classes=3, n_informative=6, random_state=42 + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + # Create and fit model + model = LogisticRegression(random_state=42) + model.fit(X_train, y_train) + + # Create report + report = EstimatorReport( + model, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + # Check each class + for pos_label in range(3): + # Get the calibration curve display for this class + display = report.metrics.calibration_curve(pos_label=pos_label) + + # Check that the display is of the right type + assert isinstance(display, CalibrationCurveDisplay) + + # Verify attributes + assert display.pos_label == pos_label + assert display.data_source == "test" + assert display.ml_task == "multiclass-classification" + assert display.report_type == "estimator" + + # Test that we can plot without error + fig, (ax, hist_ax) = plt.subplots(nrows=2, figsize=(8, 8), height_ratios=[2, 1]) + display.plot(ax=ax, hist_ax=hist_ax) + plt.close(fig) + + # Test plotting multiple classes on the same plot + fig = plt.figure(figsize=(10, 10)) + gs = GridSpec(2, 1, height_ratios=[2, 1]) + ax = fig.add_subplot(gs[0]) + ax.set_title("Multiclass Calibration Curves") + ax.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated") + + # Plot each class + for pos_label in range(3): + display = report.metrics.calibration_curve(pos_label=pos_label) + ax.plot( + display.prob_pred[pos_label][0], + display.prob_true[pos_label][0], + marker="o", + label=f"Class {pos_label}", + ) + + ax.legend() + plt.close(fig) + + +def test_comparison_report_integration(): + """Test that the calibration curve works with ComparisonReport.""" + # Create data + X, y = make_classification(n_samples=1000, n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + # Create and fit models + models = { + "Logistic Regression": LogisticRegression(random_state=42), + "Random Forest": RandomForestClassifier(random_state=42), + } + + # Fit models + for _name, model in models.items(): + model.fit(X_train, y_train) + + # Create reports for each model + reports = {} + for name, model in models.items(): + reports[name] = EstimatorReport( + model, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + # Create comparison report + comparison = ComparisonReport(reports) + + # Get calibration curve from comparison report + display = comparison.metrics.calibration_curve(pos_label=1) + + assert isinstance(display, CalibrationCurveDisplay) + + assert display.pos_label == 1 + assert display.report_type == "comparison-estimator" + assert display.ml_task == "binary-classification" + assert len(display.estimator_names) == len(models) + assert all(name in display.estimator_names for name in models) + + # Test plotting with GridSpec + fig = plt.figure(figsize=(10, 10)) + gs = GridSpec(3, 2, height_ratios=[2, 1, 1]) + + # Plot all curves on one axis + ax_curve = fig.add_subplot(gs[0, :]) + ax_curve.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated") + + # Plot each model's curve + for i, name in enumerate(models.keys()): + ax_curve.plot( + display.prob_pred[1][i], 
display.prob_true[1][i], marker="o", label=name + ) + + ax_curve.legend() + + # Plot histograms + for i, name in enumerate(models.keys()): + row, col = divmod(i, 2) + ax_hist = fig.add_subplot(gs[row + 1, col]) + ax_hist.hist(display.y_prob[i], range=(0, 1), bins=10) + ax_hist.set_title(name) + plt.close(fig) From 438eef19b82332dc4e77680719bbd078dac735ad Mon Sep 17 00:00:00 2001 From: waridrox Date: Sat, 24 May 2025 17:29:41 +0530 Subject: [PATCH 8/8] feat: Add DisplayClassProtocol for type-safe display classes --- .../sklearn/_estimator/metrics_accessor.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py index cb82cdc69a..28ce1ff30d 100644 --- a/skore/src/skore/sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from functools import partial from operator import attrgetter -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import Any, Callable, Literal, Optional, Protocol, Union, cast import joblib import numpy as np @@ -33,6 +33,11 @@ DataSource = Literal["test", "train", "X_y"] +class DisplayClassProtocol(Protocol): + @classmethod + def _compute_data_for_display(cls, *args, **kwargs) -> Any: ... + + class _MetricsAccessor(_BaseAccessor["EstimatorReport"], DirNamesMixin): """Accessor for metrics-related operations. @@ -1626,15 +1631,7 @@ def _get_display( y: Union[ArrayLike, None], data_source: DataSource, response_method: Union[str, list[str], tuple[str, ...]], - display_class: type[ - Union[ - RocCurveDisplay, - PrecisionRecallCurveDisplay, - PredictionErrorDisplay, - CalibrationCurveDisplay, - ConfusionMatrixDisplay, - ] - ], + display_class: type[DisplayClassProtocol], display_kwargs: dict[str, Any], ) -> Union[ RocCurveDisplay,
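Usage sketch (a reviewer note, not part of the diffs above): the snippet below strings together the docstring examples introduced in this series to exercise the new API end to end, building an EstimatorReport, requesting the calibration_curve display, and re-plotting it with the line_kwargs/hist_kwargs options added to plot(). It assumes the series is applied on a skore checkout; the n_bins/strategy values and the styling dictionaries are illustrative choices, not recommendations.

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    from skore import EstimatorReport, train_test_split

    # Build a small binary-classification report, mirroring the docstring examples.
    X, y = make_classification(random_state=0)
    split_data = train_test_split(X=X, y=y, random_state=0, as_dict=True)
    report = EstimatorReport(LogisticRegression(), **split_data)

    # pos_label is keyword-only and required; n_bins and strategy are optional.
    display = report.metrics.calibration_curve(
        pos_label=1, n_bins=10, strategy="quantile"
    )

    # Default rendering: calibration curve on top, probability histogram below.
    display.plot()

    # Re-plot with custom styling; these kwargs are forwarded to matplotlib.
    display.plot(
        line_kwargs={"marker": "o", "alpha": 0.9},
        hist_kwargs={"color": "tab:blue", "alpha": 0.3},
        despine=False,
    )

Because no axes are passed, each plot() call creates a fresh figure with the two stacked axes (figure_, ax_, hist_ax_), matching the behaviour implemented in CalibrationCurveDisplay.plot(); passing ax and hist_ax instead reuses existing axes, as the tests in PATCH 7/8 do.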