Skip to content

Commit 4941dab

Browse files
committed
refactor: improve subplot sharing for consistent axis scales
1 parent 587463f commit 4941dab

File tree

3 files changed

+179
-45
lines changed

3 files changed

+179
-45
lines changed

skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def _plot_single_estimator(
137137
*,
138138
estimator_name: str,
139139
pr_curve_kwargs: list[dict[str, Any]],
140+
ax: Optional[Axes] = None,
140141
) -> tuple[Axes, list[Line2D], Union[str, None]]:
141142
"""Plot precision-recall curve for a single estimator.
142143
@@ -150,6 +151,9 @@ def _plot_single_estimator(
150151
binary case, we should have a single dict. In multiclass case, we should
151152
have a list of dicts, one per class.
152153
154+
ax : matplotlib.axes.Axes, default=None
155+
The axes to plot on. If None, self.ax_ is used.
156+
153157
Returns
154158
-------
155159
ax : matplotlib.axes.Axes
@@ -165,6 +169,11 @@ def _plot_single_estimator(
165169
lines: list[Line2D] = []
166170
line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"}
167171

172+
# Use the provided axis or self.ax_
173+
plot_ax = ax if ax is not None else self.ax_
174+
if plot_ax is None:
175+
_, plot_ax = plt.subplots()
176+
168177
if self.ml_task == "binary-classification":
169178
pos_label = cast(PositiveLabel, self.pos_label)
170179

@@ -181,7 +190,7 @@ def _plot_single_estimator(
181190
f"AP = {self.average_precision[pos_label][0]:0.2f}"
182191
)
183192

184-
(line,) = self.ax_.plot(
193+
(line,) = plot_ax.plot(
185194
self.recall[pos_label][0],
186195
self.precision[pos_label][0],
187196
**line_kwargs_validated,
@@ -219,22 +228,23 @@ def _plot_single_estimator(
219228
f"AP = {average_precision_class:0.2f}"
220229
)
221230

222-
(line,) = self.ax_.plot(
231+
(line,) = plot_ax.plot(
223232
recall_class, precision_class, **line_kwargs_validated
224233
)
225234
lines.append(line)
226235

227236
info_pos_label = None # irrelevant for multiclass
228237

229-
self.ax_.legend(bbox_to_anchor=(1.02, 1), title=estimator_name)
238+
plot_ax.legend(bbox_to_anchor=(1.02, 1), title=estimator_name)
230239

231-
return self.ax_, lines, info_pos_label
240+
return plot_ax, lines, info_pos_label
232241

233242
def _plot_cross_validated_estimator(
234243
self,
235244
*,
236245
estimator_name: str,
237246
pr_curve_kwargs: list[dict[str, Any]],
247+
ax: Optional[Axes] = None,
238248
) -> tuple[Axes, list[Line2D], Union[str, None]]:
239249
"""Plot precision-recall curve for a cross-validated estimator.
240250
@@ -248,6 +258,9 @@ def _plot_cross_validated_estimator(
248258
precision-recall curves. The length of the list should match the number of
249259
curves to plot.
250260
261+
ax : matplotlib.axes.Axes, default=None
262+
The axes to plot on. If None, self.ax_ is used.
263+
251264
Returns
252265
-------
253266
ax : matplotlib.axes.Axes
@@ -263,6 +276,11 @@ def _plot_cross_validated_estimator(
263276
lines: list[Line2D] = []
264277
line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"}
265278

279+
# Use the provided axis or self.ax_
280+
plot_ax = ax if ax is not None else self.ax_
281+
if plot_ax is None:
282+
_, plot_ax = plt.subplots()
283+
266284
if self.ml_task == "binary-classification":
267285
pos_label = cast(PositiveLabel, self.pos_label)
268286
for split_idx in range(len(self.precision[pos_label])):
@@ -278,7 +296,7 @@ def _plot_cross_validated_estimator(
278296
f"(AP = {average_precision_split:0.2f})"
279297
)
280298

281-
(line,) = self.ax_.plot(
299+
(line,) = plot_ax.plot(
282300
recall_split, precision_split, **line_kwargs_validated
283301
)
284302
lines.append(line)
@@ -319,7 +337,7 @@ def _plot_cross_validated_estimator(
319337
else:
320338
line_kwargs_validated["label"] = None
321339

322-
(line,) = self.ax_.plot(
340+
(line,) = plot_ax.plot(
323341
recall_split, precision_split, **line_kwargs_validated
324342
)
325343
lines.append(line)
@@ -328,15 +346,16 @@ def _plot_cross_validated_estimator(
328346
title = f"{estimator_name} on $\\bf{{{self.data_source}}}$ set"
329347
else:
330348
title = f"{estimator_name} on $\\bf{{external}}$ set"
331-
self.ax_.legend(bbox_to_anchor=(1.02, 1), title=title)
349+
plot_ax.legend(bbox_to_anchor=(1.02, 1), title=title)
332350

333-
return self.ax_, lines, info_pos_label
351+
return plot_ax, lines, info_pos_label
334352

335353
def _plot_comparison_estimator(
336354
self,
337355
*,
338356
estimator_names: list[str],
339357
pr_curve_kwargs: list[dict[str, Any]],
358+
ax: Optional[Axes] = None,
340359
) -> tuple[Axes, list[Line2D], Union[str, None]]:
341360
"""Plot precision-recall curve of several estimators.
342361
@@ -350,6 +369,9 @@ def _plot_comparison_estimator(
350369
precision-recall curves. The length of the list should match the number of
351370
curves to plot.
352371
372+
ax : matplotlib.axes.Axes, default=None
373+
The axes to plot on. If None, self.ax_ is used.
374+
353375
Returns
354376
-------
355377
ax : matplotlib.axes.Axes
@@ -365,6 +387,11 @@ def _plot_comparison_estimator(
365387
lines: list[Line2D] = []
366388
line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"}
367389

390+
# Use the provided axis or self.ax_
391+
plot_ax = ax if ax is not None else self.ax_
392+
if plot_ax is None:
393+
_, plot_ax = plt.subplots()
394+
368395
if self.ml_task == "binary-classification":
369396
pos_label = cast(PositiveLabel, self.pos_label)
370397
for est_idx, est_name in enumerate(estimator_names):
@@ -378,7 +405,7 @@ def _plot_comparison_estimator(
378405
line_kwargs_validated["label"] = (
379406
f"{est_name} (AP = {average_precision_est:0.2f})"
380407
)
381-
(line,) = self.ax_.plot(
408+
(line,) = plot_ax.plot(
382409
recall_est, precision_est, **line_kwargs_validated
383410
)
384411
lines.append(line)
@@ -414,17 +441,17 @@ def _plot_comparison_estimator(
414441
f"(AP = {average_precision_mean:0.2f})"
415442
)
416443

417-
(line,) = self.ax_.plot(
444+
(line,) = plot_ax.plot(
418445
recall_est_class, precision_est_class, **line_kwargs_validated
419446
)
420447
lines.append(line)
421448

422-
self.ax_.legend(
449+
plot_ax.legend(
423450
bbox_to_anchor=(1.02, 1),
424451
title=f"{self.ml_task.title()} on $\\bf{{{self.data_source}}}$ set",
425452
)
426453

427-
return self.ax_, lines, info_pos_label
454+
return plot_ax, lines, info_pos_label
428455

429456
@StyleDisplayMixin.style_plot
430457
def plot(
@@ -434,6 +461,10 @@ def plot(
434461
estimator_name: Optional[str] = None,
435462
pr_curve_kwargs: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
436463
despine: bool = True,
464+
subplots: bool = False,
465+
nrows: Optional[int] = None,
466+
ncols: Optional[int] = None,
467+
figsize: Optional[tuple[float, float]] = None,
437468
) -> None:
438469
"""Plot visualization.
439470
@@ -494,8 +525,16 @@ def plot(
494525
>>> report = EstimatorReport(classifier, **split_data)
495526
>>> display = report.metrics.precision_recall()
496527
>>> display.plot(pr_curve_kwargs={"color": "tab:red"})
528+
529+
With subplots:
530+
531+
>>> display.plot(subplots=True)
497532
"""
498-
self.figure_, self.ax_ = (ax.figure, ax) if ax is not None else plt.subplots()
533+
if ax is not None and subplots:
534+
raise ValueError(
535+
"Cannot specify both 'ax' and 'subplots=True'. "
536+
"Either provide an axes object or use subplots, but not both."
537+
)
499538

500539
if pr_curve_kwargs is None:
501540
pr_curve_kwargs = self._default_pr_curve_kwargs
@@ -725,6 +764,7 @@ def plot(
725764
else estimator_name
726765
),
727766
pr_curve_kwargs=pr_curve_kwargs,
767+
ax=self.ax_,
728768
)
729769
elif self.report_type == "cross-validation":
730770
self.ax_, self.lines_, info_pos_label = (
@@ -735,12 +775,14 @@ def plot(
735775
else estimator_name
736776
),
737777
pr_curve_kwargs=pr_curve_kwargs,
778+
ax=self.ax_,
738779
)
739780
)
740781
elif self.report_type == "comparison-estimator":
741782
self.ax_, self.lines_, info_pos_label = self._plot_comparison_estimator(
742783
estimator_names=self.estimator_names,
743784
pr_curve_kwargs=pr_curve_kwargs,
785+
ax=self.ax_,
744786
)
745787
else:
746788
raise ValueError(
@@ -765,6 +807,8 @@ def plot(
765807
if despine:
766808
_despine_matplotlib_axis(self.ax_)
767809

810+
return self.figure_
811+
768812
@classmethod
769813
def _compute_data_for_display(
770814
cls,

0 commit comments

Comments
 (0)