Skip to content

Commit 62ec965

Browse files
refactor(roc-curve): Use _filter_by
1 parent b1fef42 commit 62ec965

File tree

4 files changed

+125
-114
lines changed

4 files changed

+125
-114
lines changed

skore/src/skore/sklearn/_plot/metrics/roc_curve.py

Lines changed: 105 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any, Literal, Optional, Union, cast
33

44
import matplotlib.pyplot as plt
5-
import numpy as np
65
from matplotlib import colormaps
76
from matplotlib.axes import Axes
87
from matplotlib.lines import Line2D
@@ -18,6 +17,7 @@
1817
HelpDisplayMixin,
1918
_ClassifierCurveDisplayMixin,
2019
_despine_matplotlib_axis,
20+
_filter_by,
2121
_validate_style_kwargs,
2222
sample_mpl_colormap,
2323
)
@@ -221,30 +221,38 @@ def _plot_single_estimator(
221221
)
222222

223223
for class_idx, class_label in enumerate(labels):
224-
roc_curve_label = self.roc_curve[self.roc_curve["label"] == class_label]
225-
fpr_class = roc_curve_label["fpr"]
226-
tpr_class = roc_curve_label["tpr"]
227-
roc_auc_class = self.roc_auc[self.roc_auc["label"] == class_label][
228-
"roc_auc"
229-
].iloc[0]
224+
roc_curve = _filter_by(
225+
self.roc_curve,
226+
label=class_label,
227+
)
228+
229+
roc_auc = _filter_by(
230+
self.roc_auc,
231+
label=class_label,
232+
)["roc_auc"].iloc[0]
233+
230234
roc_curve_kwargs_class = roc_curve_kwargs[class_idx]
231235

232236
default_line_kwargs: dict[str, Any] = {"color": class_colors[class_idx]}
233237
if self.data_source in ("train", "test"):
234238
default_line_kwargs["label"] = (
235239
f"{str(class_label).title()} - {self.data_source} "
236-
f"set (AUC = {roc_auc_class:0.2f})"
240+
f"set (AUC = {roc_auc:0.2f})"
237241
)
238242
else: # data_source in (None, "X_y")
239243
default_line_kwargs["label"] = (
240-
f"{str(class_label).title()} - AUC = {roc_auc_class:0.2f}"
244+
f"{str(class_label).title()} - AUC = {roc_auc:0.2f}"
241245
)
242246

243247
line_kwargs = _validate_style_kwargs(
244248
default_line_kwargs, roc_curve_kwargs_class
245249
)
246250

247-
(line,) = self.ax_.plot(fpr_class, tpr_class, **line_kwargs)
251+
(line,) = self.ax_.plot(
252+
roc_curve["fpr"],
253+
roc_curve["tpr"],
254+
**line_kwargs,
255+
)
248256
lines.append(line)
249257

250258
info_pos_label = None # irrelevant for multiclass
@@ -306,27 +314,29 @@ def _plot_cross_validated_estimator(
306314
if self.ml_task == "binary-classification":
307315
pos_label = cast(PositiveLabel, self.pos_label)
308316
for split_idx in self.roc_curve["split_index"].unique():
309-
fpr_split = self.roc_curve[
310-
(self.roc_curve["label"] == pos_label)
311-
& (self.roc_curve["split_index"] == split_idx)
312-
]["fpr"]
313-
tpr_split = self.roc_curve[
314-
(self.roc_curve["label"] == pos_label)
315-
& (self.roc_curve["split_index"] == split_idx)
316-
]["tpr"]
317-
roc_auc_split = self.roc_auc[
318-
(self.roc_auc["label"] == pos_label)
319-
& (self.roc_auc["split_index"] == split_idx)
320-
]["roc_auc"].iloc[0]
317+
roc_curve = _filter_by(
318+
self.roc_curve,
319+
label=pos_label,
320+
split_index=split_idx,
321+
)
322+
roc_auc = _filter_by(
323+
self.roc_auc,
324+
label=pos_label,
325+
split_index=split_idx,
326+
)["roc_auc"].iloc[0]
321327

322328
line_kwargs_validated = _validate_style_kwargs(
323329
line_kwargs, roc_curve_kwargs[split_idx]
324330
)
325331
line_kwargs_validated["label"] = (
326-
f"Estimator of fold #{split_idx + 1} (AUC = {roc_auc_split:0.2f})"
332+
f"Estimator of fold #{split_idx + 1} (AUC = {roc_auc:0.2f})"
327333
)
328334

329-
(line,) = self.ax_.plot(fpr_split, tpr_split, **line_kwargs_validated)
335+
(line,) = self.ax_.plot(
336+
roc_curve["fpr"],
337+
roc_curve["tpr"],
338+
**line_kwargs_validated,
339+
)
330340
lines.append(line)
331341

332342
info_pos_label = (
@@ -340,20 +350,18 @@ def _plot_cross_validated_estimator(
340350
)
341351

342352
for class_idx, class_label in enumerate(labels):
343-
roc_auc_class = self.roc_auc[self.roc_auc["label"] == class_label][
344-
"roc_auc"
345-
].iloc[0]
353+
roc_auc = _filter_by(
354+
self.roc_auc,
355+
label=class_label,
356+
)["roc_auc"].iloc[0]
346357
roc_curve_kwargs_class = roc_curve_kwargs[class_idx]
347358

348359
for split_idx in self.roc_curve["split_index"].unique():
349-
roc_curve_label = self.roc_curve[
350-
(self.roc_curve["label"] == class_label)
351-
& (self.roc_curve["split_index"] == split_idx)
352-
]
353-
fpr_split = roc_curve_label["fpr"]
354-
tpr_split = roc_curve_label["tpr"]
355-
roc_auc_mean = np.mean(roc_auc_class)
356-
roc_auc_std = np.std(roc_auc_class)
360+
roc_curve_label = _filter_by(
361+
self.roc_curve,
362+
label=class_label,
363+
split_index=split_idx,
364+
)
357365

358366
line_kwargs_validated = _validate_style_kwargs(
359367
{
@@ -365,14 +373,16 @@ def _plot_cross_validated_estimator(
365373
if split_idx == 0:
366374
line_kwargs_validated["label"] = (
367375
f"{str(class_label).title()} "
368-
f"(AUC = {roc_auc_mean:0.2f} +/- "
369-
f"{roc_auc_std:0.2f})"
376+
f"(AUC = {roc_auc.mean():0.2f} +/- "
377+
f"{roc_auc.std():0.2f})"
370378
)
371379
else:
372380
line_kwargs_validated["label"] = None
373381

374382
(line,) = self.ax_.plot(
375-
fpr_split, tpr_split, **line_kwargs_validated
383+
roc_curve_label["fpr"],
384+
roc_curve_label["tpr"],
385+
**line_kwargs_validated,
376386
)
377387
lines.append(line)
378388

@@ -437,24 +447,27 @@ def _plot_comparison_estimator(
437447
if self.ml_task == "binary-classification":
438448
pos_label = cast(PositiveLabel, self.pos_label)
439449
for est_idx, est_name in enumerate(estimator_names):
440-
roc_curve_estimator = self.roc_curve[
441-
(self.roc_curve["label"] == pos_label)
442-
& (self.roc_curve["estimator_name"] == est_name)
443-
]
444-
fpr_est = roc_curve_estimator["fpr"]
445-
tpr_est = roc_curve_estimator["tpr"]
446-
roc_auc_est = self.roc_auc[
447-
(self.roc_auc["label"] == pos_label)
448-
& (self.roc_auc["estimator_name"] == est_name)
449-
]["roc_auc"].iloc[0]
450+
roc_curve = _filter_by(
451+
self.roc_curve,
452+
label=pos_label,
453+
estimator_name=est_name,
454+
)
455+
456+
roc_auc = _filter_by(
457+
self.roc_auc,
458+
label=pos_label,
459+
estimator_name=est_name,
460+
)["roc_auc"].iloc[0]
450461

451462
line_kwargs_validated = _validate_style_kwargs(
452463
line_kwargs, roc_curve_kwargs[est_idx]
453464
)
454-
line_kwargs_validated["label"] = (
455-
f"{est_name} (AUC = {roc_auc_est:0.2f})"
465+
line_kwargs_validated["label"] = f"{est_name} (AUC = {roc_auc:0.2f})"
466+
(line,) = self.ax_.plot(
467+
roc_curve["fpr"],
468+
roc_curve["tpr"],
469+
**line_kwargs_validated,
456470
)
457-
(line,) = self.ax_.plot(fpr_est, tpr_est, **line_kwargs_validated)
458471
lines.append(line)
459472

460473
info_pos_label = (
@@ -471,16 +484,18 @@ def _plot_comparison_estimator(
471484
est_color = class_colors[est_idx]
472485

473486
for class_idx, class_label in enumerate(labels):
474-
roc_curve_estimator = self.roc_curve[
475-
(self.roc_curve["label"] == class_label)
476-
& (self.roc_curve["estimator_name"] == est_name)
477-
]
478-
fpr_est_class = roc_curve_estimator["fpr"]
479-
tpr_est_class = roc_curve_estimator["tpr"]
480-
roc_auc_mean = self.roc_auc[
481-
(self.roc_auc["label"] == class_label)
482-
& (self.roc_auc["estimator_name"] == est_name)
483-
]["roc_auc"].iloc[0]
487+
roc_curve = _filter_by(
488+
self.roc_curve,
489+
label=class_label,
490+
estimator_name=est_name,
491+
)
492+
493+
roc_auc = _filter_by(
494+
self.roc_auc,
495+
label=class_label,
496+
estimator_name=est_name,
497+
)["roc_auc"].iloc[0]
498+
484499
class_linestyle = LINESTYLE[(class_idx % len(LINESTYLE))][1]
485500

486501
line_kwargs["color"] = est_color
@@ -492,11 +507,11 @@ def _plot_comparison_estimator(
492507
)
493508
line_kwargs_validated["label"] = (
494509
f"{est_name} - {str(class_label).title()} "
495-
f"(AUC = {roc_auc_mean:0.2f})"
510+
f"(AUC = {roc_auc:0.2f})"
496511
)
497512

498513
(line,) = self.ax_.plot(
499-
fpr_est_class, tpr_est_class, **line_kwargs_validated
514+
roc_curve["fpr"], roc_curve["tpr"], **line_kwargs_validated
500515
)
501516
lines.append(line)
502517

@@ -564,28 +579,30 @@ def _plot_comparison_cross_validation(
564579
10 if len(estimator_names) < 10 else len(estimator_names),
565580
)
566581
for report_idx, estimator_name in enumerate(estimator_names):
567-
roc_auc_estimator = self.roc_auc[
568-
self.roc_auc["estimator_name"] == estimator_name
569-
]["roc_auc"]
582+
roc_curve = _filter_by(
583+
self.roc_curve,
584+
label=self.pos_label,
585+
estimator_name=estimator_name,
586+
)
587+
588+
roc_auc = _filter_by(
589+
self.roc_auc,
590+
estimator_name=estimator_name,
591+
)["roc_auc"]
570592

571593
line_kwargs_validated = _validate_style_kwargs(
572594
line_kwargs, roc_curve_kwargs[report_idx]
573595
)
574596
line_kwargs_validated["color"] = colors[report_idx]
575597
line_kwargs_validated["alpha"] = 0.6
576598

577-
roc_curve_estimator = self.roc_curve[
578-
(self.roc_curve["label"] == self.pos_label)
579-
& (self.roc_curve["estimator_name"] == estimator_name)
580-
]
581-
582-
for split_index, segment in roc_curve_estimator.groupby("split_index"):
599+
for split_index, segment in roc_curve.groupby("split_index"):
583600
if split_index == 0:
584601
label_kwargs = {
585602
"label": (
586603
f"{estimator_name} "
587-
f"(AUC = {roc_auc_estimator.mean():0.2f} +/- "
588-
f"{roc_auc_estimator.std():0.2f})"
604+
f"(AUC = {roc_auc.mean():0.2f} +/- "
605+
f"{roc_auc.std():0.2f})"
589606
)
590607
}
591608
else:
@@ -630,31 +647,31 @@ def _plot_comparison_cross_validation(
630647
est_color = colors[est_idx]
631648

632649
for label_idx, label in enumerate(labels):
633-
roc_auc_estimator = self.roc_auc[
634-
(self.roc_auc["label"] == label)
635-
& (self.roc_auc["estimator_name"] == estimator_name)
636-
]["roc_auc"]
650+
roc_curve = _filter_by(
651+
self.roc_curve,
652+
label=label,
653+
estimator_name=estimator_name,
654+
)
655+
656+
roc_auc = _filter_by(
657+
self.roc_auc,
658+
label=label,
659+
estimator_name=estimator_name,
660+
)["roc_auc"]
637661

638662
line_kwargs_validated = _validate_style_kwargs(
639663
line_kwargs, roc_curve_kwargs[est_idx]
640664
)
641665
line_kwargs_validated["color"] = est_color
642666
line_kwargs_validated["alpha"] = 0.6
643667

644-
roc_curve_estimator = self.roc_curve[
645-
(self.roc_curve["label"] == label)
646-
& (self.roc_curve["estimator_name"] == estimator_name)
647-
]
648-
649-
for split_index, segment in roc_curve_estimator.groupby(
650-
"split_index"
651-
):
668+
for split_index, segment in roc_curve.groupby("split_index"):
652669
if split_index == 0:
653670
label_kwargs = {
654671
"label": (
655672
f"{estimator_name} "
656-
f"(AUC = {roc_auc_estimator.mean():0.2f} +/- "
657-
f"{roc_auc_estimator.std():0.2f})"
673+
f"(AUC = {roc_auc.mean():0.2f} +/- "
674+
f"{roc_auc.std():0.2f})"
658675
)
659676
}
660677
else:

skore/tests/unit/sklearn/plot/roc_curve/conftest.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from sklearn.datasets import make_classification
33
from sklearn.linear_model import LogisticRegression
44
from sklearn.model_selection import train_test_split
5+
from skore.sklearn._plot.utils import _filter_by
56

67

78
@pytest.fixture
@@ -35,21 +36,12 @@ def multiclass_classification_data_no_split():
3536
def get_roc_auc(
3637
display,
3738
label=None,
38-
split_number=None,
39+
split_index=None,
3940
estimator_name=None,
4041
) -> float:
41-
noop_filter = display.roc_auc["roc_auc"].map(lambda x: True)
42-
label_filter = (display.roc_auc["label"] == label) if label is not None else True
43-
split_number_filter = (
44-
(display.roc_auc["split_index"] == split_number)
45-
if split_number is not None
46-
else True
47-
)
48-
estimator_name_filter = (
49-
(display.roc_auc["estimator_name"] == estimator_name)
50-
if estimator_name is not None
51-
else True
52-
)
53-
return display.roc_auc[
54-
noop_filter & label_filter & split_number_filter & estimator_name_filter
55-
]["roc_auc"].iloc[0]
42+
return _filter_by(
43+
display.roc_auc,
44+
label=label,
45+
split_index=split_index,
46+
estimator_name=estimator_name,
47+
)["roc_auc"].iloc[0]

0 commit comments

Comments
 (0)