Skip to content

Commit eb25f59

Browse files
committed
feat(CrossValidationReport): Add threshold averaging for roc plot
1 parent 248f76a commit eb25f59

File tree

6 files changed

+82
-4
lines changed

6 files changed

+82
-4
lines changed

skore/src/skore/sklearn/_cross_validation/metrics_accessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,7 @@ def _get_display(
11001100
*,
11011101
X: Optional[ArrayLike] = None,
11021102
y: Optional[ArrayLike] = None,
1103+
average: Optional[Literal["threshold"]] = None,
11031104
data_source: DataSource,
11041105
response_method: str,
11051106
display_class: type[
@@ -1208,6 +1209,7 @@ def _get_display(
12081209
display = display_class._compute_data_for_display(
12091210
y_true=y_true,
12101211
y_pred=y_pred,
1212+
average=average,
12111213
report_type="cross-validation",
12121214
estimators=[
12131215
report.estimator_ for report in self._parent.estimator_reports_
@@ -1232,6 +1234,7 @@ def roc(
12321234
data_source: DataSource = "test",
12331235
X: Optional[ArrayLike] = None,
12341236
y: Optional[ArrayLike] = None,
1237+
average: Optional[Literal["threshold"]] = None,
12351238
pos_label: Optional[PositiveLabel] = None,
12361239
) -> RocCurveDisplay:
12371240
"""Plot the ROC curve.
@@ -1280,6 +1283,7 @@ def roc(
12801283
data_source=data_source,
12811284
X=X,
12821285
y=y,
1286+
average=average,
12831287
response_method=response_method,
12841288
display_class=RocCurveDisplay,
12851289
display_kwargs=display_kwargs,

skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def _compute_data_for_display(
545545
cls,
546546
y_true: Sequence[YPlotData],
547547
y_pred: Sequence[YPlotData],
548+
average: Optional[Literal["threshold"]] = None,
548549
*,
549550
report_type: Literal["comparison-estimator", "cross-validation", "estimator"],
550551
estimators: Sequence[BaseEstimator],

skore/src/skore/sklearn/_plot/metrics/roc_curve.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ def _compute_data_for_display(
618618
cls,
619619
y_true: Sequence[YPlotData],
620620
y_pred: Sequence[YPlotData],
621+
average: Optional[Literal["threshold"]] = None,
621622
*,
622623
report_type: Literal["comparison-estimator", "cross-validation", "estimator"],
623624
estimators: Sequence[BaseEstimator],
@@ -672,21 +673,40 @@ def _compute_data_for_display(
672673

673674
fpr: dict[PositiveLabel, list[ArrayLike]] = defaultdict(list)
674675
tpr: dict[PositiveLabel, list[ArrayLike]] = defaultdict(list)
676+
thresholds: dict[PositiveLabel, list[ArrayLike]] = defaultdict(list)
675677
roc_auc: dict[PositiveLabel, list[float]] = defaultdict(list)
676678

677679
if ml_task == "binary-classification":
680+
pos_label_validated = cast(PositiveLabel, pos_label_validated)
678681
for y_true_i, y_pred_i in zip(y_true, y_pred):
679-
fpr_i, tpr_i, _ = roc_curve(
682+
fpr_i, tpr_i, thresholds_i = roc_curve(
680683
y_true_i.y,
681684
y_pred_i.y,
682685
pos_label=pos_label,
683686
drop_intermediate=drop_intermediate,
684687
)
685688
roc_auc_i = auc(fpr_i, tpr_i)
686-
pos_label_validated = cast(PositiveLabel, pos_label_validated)
687689
fpr[pos_label_validated].append(fpr_i)
688690
tpr[pos_label_validated].append(tpr_i)
691+
thresholds[pos_label_validated].append(thresholds_i)
689692
roc_auc[pos_label_validated].append(roc_auc_i)
693+
if average is not None:
694+
if average == "threshold":
695+
average_fpr, average_tpr = cls._threshold_average(
696+
fpr[pos_label_validated],
697+
tpr[pos_label_validated],
698+
thresholds[pos_label_validated],
699+
)
700+
else:
701+
raise TypeError(
702+
"'threshold' is the only supported option for `average`,"
703+
f"but got {average} instead"
704+
)
705+
average_roc_auc = auc(average_fpr, average_tpr)
706+
fpr[pos_label_validated] = [average_fpr]
707+
tpr[pos_label_validated] = [average_tpr]
708+
roc_auc[pos_label_validated] = [average_roc_auc]
709+
690710
else: # multiclass-classification
691711
# OvR fashion to collect fpr, tpr, and roc_auc
692712
for y_true_i, y_pred_i, est in zip(y_true, y_pred, estimators):

skore/src/skore/sklearn/_plot/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from matplotlib.axes import Axes
88
from matplotlib.colors import Colormap
9+
from numpy.typing import ArrayLike
910
from rich.console import Console
1011
from rich.panel import Panel
1112
from rich.tree import Tree
@@ -242,6 +243,45 @@ def _validate_from_predictions_params(
242243

243244
return pos_label
244245

246+
@staticmethod
def _threshold_average(
    xs: list[ArrayLike], ys: list[ArrayLike], thresholds: list[ArrayLike]
) -> tuple[ArrayLike, ArrayLike]:
    """Average several curves point-wise over a shared grid of thresholds.

    The grid is made of every distinct threshold value found across all the
    curves, visited from the highest to the lowest. For each grid value,
    every curve contributes the point attached to its largest threshold that
    is lower than or equal to the grid value; when all the thresholds of a
    curve are above the grid value, the point attached to its lowest
    threshold is used instead. The contributions are then averaged with
    equal weights.

    Parameters
    ----------
    xs : list of array-like of shape (n_thresholds_i,)
        x-coordinates of each curve (e.g. false positive rates).
    ys : list of array-like of shape (n_thresholds_i,)
        y-coordinates of each curve (e.g. true positive rates).
    thresholds : list of array-like of shape (n_thresholds_i,)
        Thresholds of each curve. Every array must be sorted in decreasing
        order because the lookup reverses it before a binary search.
        NOTE(review): curves whose thresholds are increasing would need to be
        flipped by the caller first -- confirm before reusing this helper
        for other displays.

    Returns
    -------
    average_x : list of float of length n_unique_thresholds
        Averaged x-coordinates, ordered by decreasing threshold.
    average_y : list of float of length n_unique_thresholds
        Averaged y-coordinates, ordered by decreasing threshold.

    Raises
    ------
    ValueError
        If `xs`, `ys` and `thresholds` do not hold the same number of
        curves, or if no curve is given.
    """
    if not (len(xs) == len(ys) == len(thresholds)):
        raise ValueError(
            "`xs`, `ys` and `thresholds` must contain the same number of curves; "
            f"got {len(xs)}, {len(ys)} and {len(thresholds)}."
        )
    if len(thresholds) == 0:
        raise ValueError("At least one curve is required to compute an average.")

    # `np.unique` returns sorted ascending values; flip them to walk the grid
    # from the highest threshold to the lowest one.
    threshold_grid = np.unique(np.concatenate(thresholds))[::-1]

    curves_x, curves_y = [], []
    for x, y, threshold in zip(xs, ys, thresholds):
        x, y = np.asarray(x), np.asarray(y)
        ascending_threshold = np.asarray(threshold)[::-1]
        # Position, in the ascending view, of the largest threshold <= each
        # grid value; clipped to 0 when the whole curve lies above it.
        ascending_idx = np.maximum(
            np.searchsorted(ascending_threshold, threshold_grid, side="right") - 1,
            0,
        )
        # Position `i` in the reversed view is position `-(i + 1)` in the
        # original (descending) arrays.
        descending_idx = -(ascending_idx + 1)
        curves_x.append(x[descending_idx])
        curves_y.append(y[descending_idx])

    # Keep returning plain lists of floats, one entry per grid value.
    average_x = list(np.mean(curves_x, axis=0))
    average_y = list(np.mean(curves_y, axis=0))
    return average_x, average_y
284+
245285

246286
def _despine_matplotlib_axis(
247287
ax: Axes,

skore/tests/unit/sklearn/cross_validation/test_cross_validation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,12 @@ def test_cross_validation_report_metrics_data_source_external(
319319
########################################################################################
320320

321321

322-
def test_cross_validation_report_plot_roc(binary_classification_data):
322+
@pytest.mark.parametrize("average", [None, "threshold"])
323+
def test_cross_validation_report_plot_roc(binary_classification_data, average):
323324
"""Check that the ROC plot method works."""
324325
estimator, X, y = binary_classification_data
325326
report = CrossValidationReport(estimator, X, y, cv_splitter=2)
326-
assert isinstance(report.metrics.roc(), RocCurveDisplay)
327+
assert isinstance(report.metrics.roc(average=average), RocCurveDisplay)
327328

328329

329330
@pytest.mark.parametrize("display", ["roc", "precision_recall"])

skore/tests/unit/sklearn/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy
22
import pandas
33
import pytest
4+
from numpy.testing import assert_array_equal
45
from sklearn.cluster import KMeans
56
from sklearn.datasets import (
67
make_classification,
@@ -10,6 +11,7 @@
1011
from sklearn.dummy import DummyClassifier, DummyRegressor
1112
from sklearn.linear_model import LinearRegression, LogisticRegression
1213
from sklearn.multioutput import MultiOutputClassifier
14+
from skore.sklearn._plot.utils import _ClassifierCurveDisplayMixin
1315
from skore.sklearn.find_ml_task import _find_ml_task
1416

1517

@@ -118,3 +120,13 @@ def test_find_ml_task_pandas():
118120

119121
def test_find_ml_task_string():
    """Check the task detected for a target made of three distinct strings."""
    string_labels = ["0", "1", "2"]
    detected_task = _find_ml_task(string_labels, None)
    assert detected_task == "multiclass-classification"
123+
124+
125+
def test_threshold_average():
    """Check the threshold average of two small curves against known values."""
    n_curves = 2
    curve_xs = [numpy.array([3, 2, 1]) for _ in range(n_curves)]
    curve_ys = [numpy.array([3, 2, 1]) for _ in range(n_curves)]
    curve_thresholds = [numpy.array(values) for values in ([4, 3, 1], [5, 3, 2])]

    mean_x, mean_y = _ClassifierCurveDisplayMixin._threshold_average(
        curve_xs, curve_ys, curve_thresholds
    )

    # Both coordinates share the same inputs, hence the same expected average.
    reference = numpy.array([3, 2.5, 2, 1, 1])
    for coordinate in (mean_x, mean_y):
        assert_array_equal(coordinate, reference)

0 commit comments

Comments
 (0)