Skip to content

Commit 7e8d4cc

Browse files
test passing kwargs in comparison[cv] case; fix bugs
1 parent 3cd06e5 commit 7e8d4cc

File tree

2 files changed

+111
-12
lines changed

2 files changed

+111
-12
lines changed

skore/src/skore/sklearn/_plot/metrics/roc_curve.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _plot_single_estimator(
153153
estimator_name: str,
154154
roc_curve_kwargs: list[dict[str, Any]],
155155
plot_chance_level: bool = True,
156-
chance_level_kwargs: Optional[dict[str, Any]] = None,
156+
chance_level_kwargs: Optional[dict[str, Any]],
157157
) -> tuple[Axes, list[Line2D], Union[str, None]]:
158158
"""Plot ROC curve for a single estimator.
159159
@@ -272,7 +272,7 @@ def _plot_cross_validated_estimator(
272272
estimator_name: str,
273273
roc_curve_kwargs: list[dict[str, Any]],
274274
plot_chance_level: bool = True,
275-
chance_level_kwargs: Optional[dict[str, Any]] = None,
275+
chance_level_kwargs: Optional[dict[str, Any]],
276276
) -> tuple[Axes, list[Line2D], Union[str, None]]:
277277
"""Plot ROC curve for a cross-validated estimator.
278278
@@ -398,7 +398,7 @@ def _plot_comparison_estimator(
398398
estimator_names: list[str],
399399
roc_curve_kwargs: list[dict[str, Any]],
400400
plot_chance_level: bool = True,
401-
chance_level_kwargs: Optional[dict[str, Any]] = None,
401+
chance_level_kwargs: Optional[dict[str, Any]],
402402
) -> tuple[Axes, list[Line2D], Union[str, None]]:
403403
"""Plot ROC curve of several estimators.
404404
@@ -518,7 +518,7 @@ def _plot_comparison_cross_validation(
518518
estimator_names: list[str],
519519
roc_curve_kwargs: list[dict[str, Any]],
520520
plot_chance_level: bool = True,
521-
chance_level_kwargs: Optional[dict[str, Any]] = None,
521+
chance_level_kwargs: Optional[dict[str, Any]],
522522
) -> tuple[Axes, list[Line2D], Union[str, None]]:
523523
"""Plot ROC curve of several cross-validations.
524524
@@ -568,11 +568,11 @@ def _plot_comparison_cross_validation(
568568
"roc_auc"
569569
]
570570

571+
line_kwargs["color"] = colors[report_idx]
572+
line_kwargs["alpha"] = 0.6
571573
line_kwargs_validated = _validate_style_kwargs(
572574
line_kwargs, roc_curve_kwargs[report_idx]
573575
)
574-
line_kwargs_validated["color"] = colors[report_idx]
575-
line_kwargs_validated["alpha"] = 0.6
576576

577577
for split_index, segment in roc_curve.groupby("split_index"):
578578
if split_index == 0:
@@ -620,6 +620,7 @@ def _plot_comparison_cross_validation(
620620
colormaps.get_cmap("tab10"),
621621
10 if len(estimator_names) < 10 else len(estimator_names),
622622
)
623+
idx = 0
623624

624625
for est_idx, estimator_name in enumerate(estimator_names):
625626
est_color = colors[est_idx]
@@ -633,12 +634,6 @@ def _plot_comparison_cross_validation(
633634
f"label == {label} & estimator_name == '{estimator_name}'"
634635
)["roc_auc"]
635636

636-
line_kwargs_validated = _validate_style_kwargs(
637-
line_kwargs, roc_curve_kwargs[est_idx]
638-
)
639-
line_kwargs_validated["color"] = est_color
640-
line_kwargs_validated["alpha"] = 0.6
641-
642637
for split_index, segment in roc_curve.groupby("split_index"):
643638
if split_index == 0:
644639
label_kwargs = {
@@ -651,13 +646,21 @@ def _plot_comparison_cross_validation(
651646
else:
652647
label_kwargs = {}
653648

649+
line_kwargs["color"] = est_color
650+
line_kwargs["alpha"] = 0.6
651+
line_kwargs_validated = _validate_style_kwargs(
652+
line_kwargs, roc_curve_kwargs[idx]
653+
)
654+
654655
(line,) = self.ax_[label_idx].plot(
655656
segment["fpr"],
656657
segment["tpr"],
657658
**(line_kwargs_validated | label_kwargs),
658659
)
659660
lines.append(line)
660661

662+
idx = idx + 1
663+
661664
info_pos_label = f"\n(Positive label: {label})"
662665
_set_axis_labels(self.ax_[label_idx], info_pos_label)
663666

@@ -784,6 +787,8 @@ def plot(
784787
self._plot_comparison_cross_validation(
785788
estimator_names=self.roc_auc["estimator_name"].unique(),
786789
roc_curve_kwargs=roc_curve_kwargs,
790+
plot_chance_level=plot_chance_level,
791+
chance_level_kwargs=chance_level_kwargs,
787792
)
788793
)
789794
else:

skore/tests/unit/sklearn/plot/roc_curve/test_comparison_cross_validation.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,97 @@ def test_multiclass_classification(pyplot, multiclass_classification_report):
137137
assert ax.get_adjustable() == "box"
138138
assert ax.get_aspect() in ("equal", 1.0)
139139
assert ax.get_xlim() == ax.get_ylim() == (-0.01, 1.01)
140+
141+
142+
def test_binary_classification_wrong_kwargs(pyplot, binary_classification_report):
143+
"""Check that we raise a proper error message when passing an inappropriate
144+
value for the `roc_curve_kwargs` argument."""
145+
report = binary_classification_report
146+
display = report.metrics.roc()
147+
err_msg = (
148+
"You intend to plot multiple curves. We expect `roc_curve_kwargs` to be a "
149+
"list of dictionaries with the same length as the number of curves. "
150+
"Got 2 instead of 10."
151+
)
152+
with pytest.raises(ValueError, match=err_msg):
153+
display.plot(roc_curve_kwargs=[{}, {}])
154+
155+
156+
@pytest.mark.parametrize("roc_curve_kwargs", [[{"color": "red"}] * 10])
157+
def test_binary_classification_kwargs(
158+
pyplot, binary_classification_report, roc_curve_kwargs
159+
):
160+
"""Check that we can pass keyword arguments to the ROC curve plot."""
161+
report = binary_classification_report
162+
display = report.metrics.roc()
163+
display.plot(
164+
roc_curve_kwargs=roc_curve_kwargs, chance_level_kwargs={"color": "blue"}
165+
)
166+
assert display.lines_[0].get_color() == "red"
167+
assert display.chance_level_.get_color() == "blue"
168+
169+
# check the `.style` display setter
170+
display.plot() # default style
171+
assert display.lines_[0].get_color() == (
172+
np.float64(0.12156862745098039),
173+
np.float64(0.4666666666666667),
174+
np.float64(0.7058823529411765),
175+
np.float64(1.0),
176+
)
177+
assert display.chance_level_.get_color() == "k"
178+
179+
display.set_style(
180+
roc_curve_kwargs=roc_curve_kwargs, chance_level_kwargs={"color": "blue"}
181+
)
182+
display.plot()
183+
assert display.lines_[0].get_color() == "red"
184+
assert display.chance_level_.get_color() == "blue"
185+
186+
# overwrite the style that was set above
187+
display.plot(
188+
roc_curve_kwargs=[{"color": "#1f77b4"}] * 10,
189+
chance_level_kwargs={"color": "red"},
190+
)
191+
assert display.lines_[0].get_color() == "#1f77b4"
192+
assert display.chance_level_.get_color() == "red"
193+
194+
195+
def test_multiclass_classification_wrong_kwargs(
196+
pyplot, multiclass_classification_report
197+
):
198+
"""Check that we raise a proper error message when passing an inappropriate
199+
value for the `roc_curve_kwargs` argument."""
200+
report = multiclass_classification_report
201+
display = report.metrics.roc()
202+
err_msg = "You intend to plot multiple curves."
203+
with pytest.raises(ValueError, match=err_msg):
204+
display.plot(roc_curve_kwargs=[{}, {}])
205+
206+
with pytest.raises(ValueError, match=err_msg):
207+
display.plot(roc_curve_kwargs={})
208+
209+
210+
def test_multiclass_classification_kwargs(pyplot, multiclass_classification_report):
211+
"""Check that we can pass keyword arguments to the ROC curve plot for
212+
multiclass classification."""
213+
report = multiclass_classification_report
214+
display = report.metrics.roc()
215+
display.plot(
216+
roc_curve_kwargs=(
217+
[{"color": "red"}] * 10
218+
+ [{"color": "blue"}] * 10
219+
+ [{"color": "green"}] * 10
220+
),
221+
chance_level_kwargs={"color": "blue"},
222+
)
223+
assert display.lines_[0].get_color() == "red"
224+
assert display.lines_[10].get_color() == "blue"
225+
assert display.lines_[20].get_color() == "green"
226+
assert display.chance_level_[0].get_color() == "blue"
227+
228+
display.plot(plot_chance_level=False)
229+
assert display.chance_level_ is None
230+
231+
display.plot(despine=False)
232+
assert display.ax_[0].spines["top"].get_visible()
233+
assert display.ax_[0].spines["right"].get_visible()

0 commit comments

Comments
 (0)