@@ -137,6 +137,7 @@ def _plot_single_estimator(
137137 * ,
138138 estimator_name : str ,
139139 pr_curve_kwargs : list [dict [str , Any ]],
140+ ax : Optional [Axes ] = None ,
140141 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
141142 """Plot precision-recall curve for a single estimator.
142143
@@ -150,6 +151,9 @@ def _plot_single_estimator(
150151 binary case, we should have a single dict. In multiclass case, we should
151152 have a list of dicts, one per class.
152153
154+ ax : matplotlib.axes.Axes, default=None
155+ The axes to plot on. If None, self.ax_ is used.
156+
153157 Returns
154158 -------
155159 ax : matplotlib.axes.Axes
@@ -165,6 +169,11 @@ def _plot_single_estimator(
165169 lines : list [Line2D ] = []
166170 line_kwargs : dict [str , Any ] = {"drawstyle" : "steps-post" }
167171
172+ # Use the provided axis or self.ax_
173+ plot_ax = ax if ax is not None else self .ax_
174+ if plot_ax is None :
175+ _ , plot_ax = plt .subplots ()
176+
168177 if self .ml_task == "binary-classification" :
169178 pos_label = cast (PositiveLabel , self .pos_label )
170179
@@ -181,7 +190,7 @@ def _plot_single_estimator(
181190 f"AP = { self .average_precision [pos_label ][0 ]:0.2f} "
182191 )
183192
184- (line ,) = self . ax_ .plot (
193+ (line ,) = plot_ax .plot (
185194 self .recall [pos_label ][0 ],
186195 self .precision [pos_label ][0 ],
187196 ** line_kwargs_validated ,
@@ -219,22 +228,23 @@ def _plot_single_estimator(
219228 f"AP = { average_precision_class :0.2f} "
220229 )
221230
222- (line ,) = self . ax_ .plot (
231+ (line ,) = plot_ax .plot (
223232 recall_class , precision_class , ** line_kwargs_validated
224233 )
225234 lines .append (line )
226235
227236 info_pos_label = None # irrelevant for multiclass
228237
229- self . ax_ .legend (bbox_to_anchor = (1.02 , 1 ), title = estimator_name )
238+ plot_ax .legend (bbox_to_anchor = (1.02 , 1 ), title = estimator_name )
230239
231- return self . ax_ , lines , info_pos_label
240+ return plot_ax , lines , info_pos_label
232241
233242 def _plot_cross_validated_estimator (
234243 self ,
235244 * ,
236245 estimator_name : str ,
237246 pr_curve_kwargs : list [dict [str , Any ]],
247+ ax : Optional [Axes ] = None ,
238248 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
239249 """Plot precision-recall curve for a cross-validated estimator.
240250
@@ -248,6 +258,9 @@ def _plot_cross_validated_estimator(
248258 precision-recall curves. The length of the list should match the number of
249259 curves to plot.
250260
261+ ax : matplotlib.axes.Axes, default=None
262+ The axes to plot on. If None, self.ax_ is used.
263+
251264 Returns
252265 -------
253266 ax : matplotlib.axes.Axes
@@ -263,6 +276,11 @@ def _plot_cross_validated_estimator(
263276 lines : list [Line2D ] = []
264277 line_kwargs : dict [str , Any ] = {"drawstyle" : "steps-post" }
265278
279+ # Use the provided axis or self.ax_
280+ plot_ax = ax if ax is not None else self .ax_
281+ if plot_ax is None :
282+ _ , plot_ax = plt .subplots ()
283+
266284 if self .ml_task == "binary-classification" :
267285 pos_label = cast (PositiveLabel , self .pos_label )
268286 for split_idx in range (len (self .precision [pos_label ])):
@@ -278,7 +296,7 @@ def _plot_cross_validated_estimator(
278296 f"(AP = { average_precision_split :0.2f} )"
279297 )
280298
281- (line ,) = self . ax_ .plot (
299+ (line ,) = plot_ax .plot (
282300 recall_split , precision_split , ** line_kwargs_validated
283301 )
284302 lines .append (line )
@@ -319,7 +337,7 @@ def _plot_cross_validated_estimator(
319337 else :
320338 line_kwargs_validated ["label" ] = None
321339
322- (line ,) = self . ax_ .plot (
340+ (line ,) = plot_ax .plot (
323341 recall_split , precision_split , ** line_kwargs_validated
324342 )
325343 lines .append (line )
@@ -328,15 +346,16 @@ def _plot_cross_validated_estimator(
328346 title = f"{ estimator_name } on $\\ bf{{{ self .data_source } }}$ set"
329347 else :
330348 title = f"{ estimator_name } on $\\ bf{{external}}$ set"
331- self . ax_ .legend (bbox_to_anchor = (1.02 , 1 ), title = title )
349+ plot_ax .legend (bbox_to_anchor = (1.02 , 1 ), title = title )
332350
333- return self . ax_ , lines , info_pos_label
351+ return plot_ax , lines , info_pos_label
334352
335353 def _plot_comparison_estimator (
336354 self ,
337355 * ,
338356 estimator_names : list [str ],
339357 pr_curve_kwargs : list [dict [str , Any ]],
358+ ax : Optional [Axes ] = None ,
340359 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
341360 """Plot precision-recall curve of several estimators.
342361
@@ -350,6 +369,9 @@ def _plot_comparison_estimator(
350369 precision-recall curves. The length of the list should match the number of
351370 curves to plot.
352371
372+ ax : matplotlib.axes.Axes, default=None
373+ The axes to plot on. If None, self.ax_ is used.
374+
353375 Returns
354376 -------
355377 ax : matplotlib.axes.Axes
@@ -365,6 +387,11 @@ def _plot_comparison_estimator(
365387 lines : list [Line2D ] = []
366388 line_kwargs : dict [str , Any ] = {"drawstyle" : "steps-post" }
367389
390+ # Use the provided axis or self.ax_
391+ plot_ax = ax if ax is not None else self .ax_
392+ if plot_ax is None :
393+ _ , plot_ax = plt .subplots ()
394+
368395 if self .ml_task == "binary-classification" :
369396 pos_label = cast (PositiveLabel , self .pos_label )
370397 for est_idx , est_name in enumerate (estimator_names ):
@@ -378,7 +405,7 @@ def _plot_comparison_estimator(
378405 line_kwargs_validated ["label" ] = (
379406 f"{ est_name } (AP = { average_precision_est :0.2f} )"
380407 )
381- (line ,) = self . ax_ .plot (
408+ (line ,) = plot_ax .plot (
382409 recall_est , precision_est , ** line_kwargs_validated
383410 )
384411 lines .append (line )
@@ -414,17 +441,17 @@ def _plot_comparison_estimator(
414441 f"(AP = { average_precision_mean :0.2f} )"
415442 )
416443
417- (line ,) = self . ax_ .plot (
444+ (line ,) = plot_ax .plot (
418445 recall_est_class , precision_est_class , ** line_kwargs_validated
419446 )
420447 lines .append (line )
421448
422- self . ax_ .legend (
449+ plot_ax .legend (
423450 bbox_to_anchor = (1.02 , 1 ),
424451 title = f"{ self .ml_task .title ()} on $\\ bf{{{ self .data_source } }}$ set" ,
425452 )
426453
427- return self . ax_ , lines , info_pos_label
454+ return plot_ax , lines , info_pos_label
428455
429456 @StyleDisplayMixin .style_plot
430457 def plot (
@@ -434,6 +461,10 @@ def plot(
434461 estimator_name : Optional [str ] = None ,
435462 pr_curve_kwargs : Optional [Union [dict [str , Any ], list [dict [str , Any ]]]] = None ,
436463 despine : bool = True ,
464+ subplots : bool = False ,
465+ nrows : Optional [int ] = None ,
466+ ncols : Optional [int ] = None ,
467+ figsize : Optional [tuple [float , float ]] = None ,
437468 ) -> None :
438469 """Plot visualization.
439470
@@ -494,8 +525,16 @@ def plot(
494525 >>> report = EstimatorReport(classifier, **split_data)
495526 >>> display = report.metrics.precision_recall()
496527 >>> display.plot(pr_curve_kwargs={"color": "tab:red"})
528+
529+ With subplots:
530+
531+ >>> display.plot(subplots=True)
497532 """
498- self .figure_ , self .ax_ = (ax .figure , ax ) if ax is not None else plt .subplots ()
533+ if ax is not None and subplots :
534+ raise ValueError (
535+ "Cannot specify both 'ax' and 'subplots=True'. "
536+ "Either provide an axes object or use subplots, but not both."
537+ )
499538
500539 if pr_curve_kwargs is None :
501540 pr_curve_kwargs = self ._default_pr_curve_kwargs
@@ -725,6 +764,7 @@ def plot(
725764 else estimator_name
726765 ),
727766 pr_curve_kwargs = pr_curve_kwargs ,
767+ ax = self .ax_ ,
728768 )
729769 elif self .report_type == "cross-validation" :
730770 self .ax_ , self .lines_ , info_pos_label = (
@@ -735,12 +775,14 @@ def plot(
735775 else estimator_name
736776 ),
737777 pr_curve_kwargs = pr_curve_kwargs ,
778+ ax = self .ax_ ,
738779 )
739780 )
740781 elif self .report_type == "comparison-estimator" :
741782 self .ax_ , self .lines_ , info_pos_label = self ._plot_comparison_estimator (
742783 estimator_names = self .estimator_names ,
743784 pr_curve_kwargs = pr_curve_kwargs ,
785+ ax = self .ax_ ,
744786 )
745787 else :
746788 raise ValueError (
@@ -765,6 +807,8 @@ def plot(
765807 if despine :
766808 _despine_matplotlib_axis (self .ax_ )
767809
810+ return self .figure_
811+
768812 @classmethod
769813 def _compute_data_for_display (
770814 cls ,
0 commit comments