Skip to content

Commit 47cd9d4

Browse files
replace _filter_by by .query
1 parent b2265df commit 47cd9d4

File tree

7 files changed

+67
-122
lines changed

7 files changed

+67
-122
lines changed

skore/src/skore/sklearn/_plot/metrics/roc_curve.py

Lines changed: 30 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
HelpDisplayMixin,
1818
_ClassifierCurveDisplayMixin,
1919
_despine_matplotlib_axis,
20-
_filter_by,
2120
_validate_style_kwargs,
2221
sample_mpl_colormap,
2322
)
@@ -222,15 +221,11 @@ def _plot_single_estimator(
222221
)
223222

224223
for class_idx, class_label in enumerate(labels):
225-
roc_curve = _filter_by(
226-
self.roc_curve,
227-
label=class_label,
228-
)
224+
roc_curve = self.roc_curve.query(f"label == {class_label}")
229225

230-
roc_auc = _filter_by(
231-
self.roc_auc,
232-
label=class_label,
233-
)["roc_auc"].iloc[0]
226+
roc_auc = self.roc_auc.query(f"label == {class_label}")["roc_auc"].iloc[
227+
0
228+
]
234229

235230
roc_curve_kwargs_class = roc_curve_kwargs[class_idx]
236231

@@ -315,15 +310,11 @@ def _plot_cross_validated_estimator(
315310
if self.ml_task == "binary-classification":
316311
pos_label = cast(PositiveLabel, self.pos_label)
317312
for split_idx in self.roc_curve["split_index"].unique():
318-
roc_curve = _filter_by(
319-
self.roc_curve,
320-
label=pos_label,
321-
split_index=split_idx,
313+
roc_curve = self.roc_curve.query(
314+
f"label == {pos_label} & split_index == {split_idx}"
322315
)
323-
roc_auc = _filter_by(
324-
self.roc_auc,
325-
label=pos_label,
326-
split_index=split_idx,
316+
roc_auc = self.roc_auc.query(
317+
f"label == {pos_label} & split_index == {split_idx}"
327318
)["roc_auc"].iloc[0]
328319

329320
line_kwargs_validated = _validate_style_kwargs(
@@ -351,17 +342,14 @@ def _plot_cross_validated_estimator(
351342
)
352343

353344
for class_idx, class_label in enumerate(labels):
354-
roc_auc = _filter_by(
355-
self.roc_auc,
356-
label=class_label,
357-
)["roc_auc"].iloc[0]
345+
roc_auc = self.roc_auc.query(f"label == {class_label}")["roc_auc"].iloc[
346+
0
347+
]
358348
roc_curve_kwargs_class = roc_curve_kwargs[class_idx]
359349

360350
for split_idx in self.roc_curve["split_index"].unique():
361-
roc_curve_label = _filter_by(
362-
self.roc_curve,
363-
label=class_label,
364-
split_index=split_idx,
351+
roc_curve_label = self.roc_curve.query(
352+
f"label == {class_label} & split_index == {split_idx}"
365353
)
366354

367355
line_kwargs_validated = _validate_style_kwargs(
@@ -448,16 +436,12 @@ def _plot_comparison_estimator(
448436
if self.ml_task == "binary-classification":
449437
pos_label = cast(PositiveLabel, self.pos_label)
450438
for est_idx, est_name in enumerate(estimator_names):
451-
roc_curve = _filter_by(
452-
self.roc_curve,
453-
label=pos_label,
454-
estimator_name=est_name,
439+
roc_curve = self.roc_curve.query(
440+
f"label == {pos_label} & estimator_name == '{est_name}'"
455441
)
456442

457-
roc_auc = _filter_by(
458-
self.roc_auc,
459-
label=pos_label,
460-
estimator_name=est_name,
443+
roc_auc = self.roc_auc.query(
444+
f"label == {pos_label} & estimator_name == '{est_name}'"
461445
)["roc_auc"].iloc[0]
462446

463447
line_kwargs_validated = _validate_style_kwargs(
@@ -485,16 +469,12 @@ def _plot_comparison_estimator(
485469
est_color = class_colors[est_idx]
486470

487471
for class_idx, class_label in enumerate(labels):
488-
roc_curve = _filter_by(
489-
self.roc_curve,
490-
label=class_label,
491-
estimator_name=est_name,
472+
roc_curve = self.roc_curve.query(
473+
f"label == {class_label} & estimator_name == '{est_name}'"
492474
)
493475

494-
roc_auc = _filter_by(
495-
self.roc_auc,
496-
label=class_label,
497-
estimator_name=est_name,
476+
roc_auc = self.roc_auc.query(
477+
f"label == {class_label} & estimator_name == '{est_name}'"
498478
)["roc_auc"].iloc[0]
499479

500480
class_linestyle = LINESTYLE[(class_idx % len(LINESTYLE))][1]
@@ -580,16 +560,13 @@ def _plot_comparison_cross_validation(
580560
10 if len(estimator_names) < 10 else len(estimator_names),
581561
)
582562
for report_idx, estimator_name in enumerate(estimator_names):
583-
roc_curve = _filter_by(
584-
self.roc_curve,
585-
label=self.pos_label,
586-
estimator_name=estimator_name,
563+
roc_curve = self.roc_curve.query(
564+
f"label == {self.pos_label} & estimator_name == '{estimator_name}'"
587565
)
588566

589-
roc_auc = _filter_by(
590-
self.roc_auc,
591-
estimator_name=estimator_name,
592-
)["roc_auc"]
567+
roc_auc = self.roc_auc.query(f"estimator_name == '{estimator_name}'")[
568+
"roc_auc"
569+
]
593570

594571
line_kwargs_validated = _validate_style_kwargs(
595572
line_kwargs, roc_curve_kwargs[report_idx]
@@ -648,16 +625,12 @@ def _plot_comparison_cross_validation(
648625
est_color = colors[est_idx]
649626

650627
for label_idx, label in enumerate(labels):
651-
roc_curve = _filter_by(
652-
self.roc_curve,
653-
label=label,
654-
estimator_name=estimator_name,
628+
roc_curve = self.roc_curve.query(
629+
f"label == {label} & estimator_name == '{estimator_name}'"
655630
)
656631

657-
roc_auc = _filter_by(
658-
self.roc_auc,
659-
label=label,
660-
estimator_name=estimator_name,
632+
roc_auc = self.roc_auc.query(
633+
f"label == {label} & estimator_name == '{estimator_name}'"
661634
)["roc_auc"]
662635

663636
line_kwargs_validated = _validate_style_kwargs(

skore/src/skore/sklearn/_plot/utils.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -353,20 +353,3 @@ def sample_mpl_colormap(
353353
"""
354354
indices = np.linspace(0, 1, n)
355355
return [cmap(i) for i in indices]
356-
357-
358-
def _filter_by(
359-
df,
360-
label: Optional[PositiveLabel] = None,
361-
split_index: Optional[int] = None,
362-
estimator_name: Optional[str] = None,
363-
) -> DataFrame:
364-
noop_filter = df.iloc[:, 0].map(lambda _: True)
365-
label_filter = (df["label"] == label) if label is not None else True
366-
split_number_filter = (
367-
(df["split_index"] == split_index) if split_index is not None else True
368-
)
369-
estimator_name_filter = (
370-
(df["estimator_name"] == estimator_name) if estimator_name is not None else True
371-
)
372-
return df[noop_filter & label_filter & split_number_filter & estimator_name_filter]

skore/tests/unit/sklearn/plot/roc_curve/conftest.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from sklearn.datasets import make_classification
33
from sklearn.linear_model import LogisticRegression
44
from sklearn.model_selection import train_test_split
5-
from skore.sklearn._plot.utils import _filter_by
65

76

87
@pytest.fixture
@@ -31,17 +30,3 @@ def binary_classification_data_no_split():
3130
def multiclass_classification_data_no_split():
3231
X, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=42)
3332
return LogisticRegression(), X, y
34-
35-
36-
def get_roc_auc(
37-
display,
38-
label=None,
39-
split_index=None,
40-
estimator_name=None,
41-
) -> float:
42-
return _filter_by(
43-
display.roc_auc,
44-
label=label,
45-
split_index=split_index,
46-
estimator_name=estimator_name,
47-
)["roc_auc"].iloc[0]

skore/tests/unit/sklearn/plot/roc_curve/test_comparison_cross_validation.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.linear_model import LogisticRegression
1010
from skore import ComparisonReport, CrossValidationReport
1111
from skore.sklearn._plot.metrics.roc_curve import RocCurveDisplay
12-
from skore.sklearn._plot.utils import _filter_by, sample_mpl_colormap
12+
from skore.sklearn._plot.utils import sample_mpl_colormap
1313

1414

1515
def test_binary_classification(pyplot):
@@ -37,10 +37,8 @@ def test_binary_classification(pyplot):
3737
for i, estimator_name in enumerate(report.report_names_):
3838
roc_curve_mpl = display.lines_[i * n_splits]
3939
assert isinstance(roc_curve_mpl, Line2D)
40-
auc = _filter_by(
41-
display.roc_auc,
42-
label=pos_label,
43-
estimator_name=estimator_name,
40+
auc = display.roc_auc.query(
41+
f"label == {pos_label} & estimator_name == '{estimator_name}'"
4442
)["roc_auc"]
4543

4644
assert roc_curve_mpl.get_label() == (
@@ -95,10 +93,8 @@ def test_multiclass(pyplot):
9593
roc_curve_mpl = display.lines_[i * n_splits]
9694
assert isinstance(roc_curve_mpl, Line2D)
9795

98-
auc = _filter_by(
99-
display.roc_auc,
100-
label=label,
101-
estimator_name=estimator_name,
96+
auc = display.roc_auc.query(
97+
f"label == {label} & estimator_name == '{estimator_name}'"
10298
)["roc_auc"]
10399

104100
assert roc_curve_mpl.get_label() == (

skore/tests/unit/sklearn/plot/roc_curve/test_comparison_estimator.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from skore.sklearn._plot import RocCurveDisplay
66
from skore.sklearn._plot.utils import sample_mpl_colormap
77

8-
from .conftest import get_roc_auc
9-
108

119
def test_binary_classification(pyplot, binary_classification_data):
1210
"""Check the attributes and default plotting behaviour of the ROC curve plot with
@@ -63,9 +61,9 @@ def test_binary_classification(pyplot, binary_classification_data):
6361
zip(report.report_names_, display.lines_)
6462
):
6563
assert isinstance(line, mpl.lines.Line2D)
66-
roc_auc_class = get_roc_auc(
67-
display, label=display.pos_label, estimator_name=estimator_name
68-
)
64+
roc_auc_class = display.roc_auc.query(
65+
f"label == {display.pos_label} & estimator_name == '{estimator_name}'"
66+
)["roc_auc"].iloc[0]
6967
assert line.get_label() == (f"{estimator_name} (AUC = {roc_auc_class:0.2f})")
7068
assert mpl.colors.to_rgba(line.get_color()) == expected_colors[idx]
7169

@@ -144,11 +142,9 @@ def test_multiclass_classification(pyplot, multiclass_classification_data):
144142
for class_label_idx, class_label in enumerate(class_labels):
145143
roc_curve_mpl = display.lines_[idx * len(class_labels) + class_label_idx]
146144
assert isinstance(roc_curve_mpl, mpl.lines.Line2D)
147-
roc_auc_class = get_roc_auc(
148-
display,
149-
label=class_label,
150-
estimator_name=estimator_name,
151-
)
145+
roc_auc_class = display.roc_auc.query(
146+
f"label == {class_label} & estimator_name == '{estimator_name}'"
147+
)["roc_auc"].iloc[0]
152148
assert roc_curve_mpl.get_label() == (
153149
f"{estimator_name} - {str(class_label).title()} "
154150
f"(AUC = {roc_auc_class:0.2f})"

skore/tests/unit/sklearn/plot/roc_curve/test_cross_validation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from skore.sklearn._plot import RocCurveDisplay
66
from skore.sklearn._plot.utils import sample_mpl_colormap
77

8-
from .conftest import get_roc_auc
9-
108

119
@pytest.mark.parametrize("data_source", ["train", "test", "X_y"])
1210
def test_binary_classification(
@@ -60,7 +58,9 @@ def test_binary_classification(
6058
expected_colors = sample_mpl_colormap(pyplot.cm.tab10, 10)
6159
for split_idx, line in enumerate(display.lines_):
6260
assert isinstance(line, mpl.lines.Line2D)
63-
roc_auc_split = get_roc_auc(display, label=pos_label, split_index=split_idx)
61+
roc_auc_split = display.roc_auc.query(
62+
f"label == {pos_label} & split_index == {split_idx}"
63+
)["roc_auc"].iloc[0]
6464
assert line.get_label() == (
6565
f"Estimator of fold #{split_idx + 1} (AUC = {roc_auc_split:0.2f})"
6666
)
@@ -139,7 +139,9 @@ def test_multiclass_classification(
139139
roc_curve_mpl = display.lines_[class_label * cv + split_idx]
140140
assert isinstance(roc_curve_mpl, mpl.lines.Line2D)
141141
if split_idx == 0:
142-
roc_auc_class = get_roc_auc(display, label=class_label)
142+
roc_auc_class = display.roc_auc.query(f"label == {class_label}")[
143+
"roc_auc"
144+
].iloc[0]
143145
assert roc_curve_mpl.get_label() == (
144146
f"{str(class_label).title()} "
145147
f"(AUC = {np.mean(roc_auc_class):0.2f}"

skore/tests/unit/sklearn/plot/roc_curve/test_estimator.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from skore.sklearn._plot import RocCurveDisplay
66
from skore.sklearn._plot.utils import sample_mpl_colormap
77

8-
from .conftest import get_roc_auc
9-
108

119
def test_binary_classification(pyplot, binary_classification_data):
1210
"""Check the attributes and default plotting behaviour of the ROC curve plot with
@@ -48,7 +46,10 @@ def test_binary_classification(pyplot, binary_classification_data):
4846
assert len(display.lines_) == 1
4947
roc_curve_mpl = display.lines_[0]
5048
assert isinstance(roc_curve_mpl, mpl.lines.Line2D)
51-
assert roc_curve_mpl.get_label() == f"Test set (AUC = {get_roc_auc(display):0.2f})"
49+
assert (
50+
roc_curve_mpl.get_label()
51+
== f"Test set (AUC = {display.roc_auc['roc_auc'].iloc[0]:0.2f})"
52+
)
5253
assert roc_curve_mpl.get_color() == "#1f77b4" # tab:blue in hex
5354

5455
assert isinstance(display.chance_level_, mpl.lines.Line2D)
@@ -107,7 +108,9 @@ def test_multiclass_classification(pyplot, multiclass_classification_data):
107108
for class_label, expected_color in zip(estimator.classes_, default_colors):
108109
roc_curve_mpl = display.lines_[class_label]
109110
assert isinstance(roc_curve_mpl, mpl.lines.Line2D)
110-
roc_auc_class = get_roc_auc(display, label=class_label)
111+
roc_auc_class = display.roc_auc.query(f"label == {class_label}")[
112+
"roc_auc"
113+
].iloc[0]
111114
assert roc_curve_mpl.get_label() == (
112115
f"{str(class_label).title()} - test set (AUC = {roc_auc_class:0.2f})"
113116
)
@@ -139,12 +142,15 @@ def test_data_source_binary_classification(pyplot, binary_classification_data):
139142
display.plot()
140143
assert (
141144
display.lines_[0].get_label()
142-
== f"Train set (AUC = {get_roc_auc(display):0.2f})"
145+
== f"Train set (AUC = {display.roc_auc['roc_auc'].iloc[0]:0.2f})"
143146
)
144147

145148
display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
146149
display.plot()
147-
assert display.lines_[0].get_label() == f"AUC = {get_roc_auc(display):0.2f}"
150+
assert (
151+
display.lines_[0].get_label()
152+
== f"AUC = {display.roc_auc['roc_auc'].iloc[0]:0.2f}"
153+
)
148154

149155

150156
def test_data_source_multiclass_classification(pyplot, multiclass_classification_data):
@@ -156,17 +162,21 @@ def test_data_source_multiclass_classification(pyplot, multiclass_classification
156162
display = report.metrics.roc(data_source="train")
157163
display.plot()
158164
for class_label in estimator.classes_:
165+
roc_auc_class = display.roc_auc.query(f"label == {class_label}")[
166+
"roc_auc"
167+
].iloc[0]
159168
assert display.lines_[class_label].get_label() == (
160-
f"{str(class_label).title()} - train set "
161-
f"(AUC = {get_roc_auc(display, label=class_label):0.2f})"
169+
f"{str(class_label).title()} - train set (AUC = {roc_auc_class:0.2f})"
162170
)
163171

164172
display = report.metrics.roc(data_source="X_y", X=X_train, y=y_train)
165173
display.plot()
166174
for class_label in estimator.classes_:
175+
roc_auc_class = display.roc_auc.query(f"label == {class_label}")[
176+
"roc_auc"
177+
].iloc[0]
167178
assert display.lines_[class_label].get_label() == (
168-
f"{str(class_label).title()} - "
169-
f"AUC = {get_roc_auc(display, label=class_label):0.2f}"
179+
f"{str(class_label).title()} - AUC = {roc_auc_class:0.2f}"
170180
)
171181

172182

0 commit comments

Comments
 (0)