22from typing import Any , Literal , Optional , Union , cast
33
44import matplotlib .pyplot as plt
5- import numpy as np
65from matplotlib import colormaps
76from matplotlib .axes import Axes
87from matplotlib .lines import Line2D
1817 HelpDisplayMixin ,
1918 _ClassifierCurveDisplayMixin ,
2019 _despine_matplotlib_axis ,
20+ _filter_by ,
2121 _validate_style_kwargs ,
2222 sample_mpl_colormap ,
2323)
@@ -221,30 +221,38 @@ def _plot_single_estimator(
221221 )
222222
223223 for class_idx , class_label in enumerate (labels ):
224- roc_curve_label = self .roc_curve [self .roc_curve ["label" ] == class_label ]
225- fpr_class = roc_curve_label ["fpr" ]
226- tpr_class = roc_curve_label ["tpr" ]
227- roc_auc_class = self .roc_auc [self .roc_auc ["label" ] == class_label ][
228- "roc_auc"
229- ].iloc [0 ]
224+ roc_curve = _filter_by (
225+ self .roc_curve ,
226+ label = class_label ,
227+ )
228+
229+ roc_auc = _filter_by (
230+ self .roc_auc ,
231+ label = class_label ,
232+ )["roc_auc" ].iloc [0 ]
233+
230234 roc_curve_kwargs_class = roc_curve_kwargs [class_idx ]
231235
232236 default_line_kwargs : dict [str , Any ] = {"color" : class_colors [class_idx ]}
233237 if self .data_source in ("train" , "test" ):
234238 default_line_kwargs ["label" ] = (
235239 f"{ str (class_label ).title ()} - { self .data_source } "
236- f"set (AUC = { roc_auc_class :0.2f} )"
240+ f"set (AUC = { roc_auc :0.2f} )"
237241 )
238242 else : # data_source in (None, "X_y")
239243 default_line_kwargs ["label" ] = (
240- f"{ str (class_label ).title ()} - AUC = { roc_auc_class :0.2f} "
244+ f"{ str (class_label ).title ()} - AUC = { roc_auc :0.2f} "
241245 )
242246
243247 line_kwargs = _validate_style_kwargs (
244248 default_line_kwargs , roc_curve_kwargs_class
245249 )
246250
247- (line ,) = self .ax_ .plot (fpr_class , tpr_class , ** line_kwargs )
251+ (line ,) = self .ax_ .plot (
252+ roc_curve ["fpr" ],
253+ roc_curve ["tpr" ],
254+ ** line_kwargs ,
255+ )
248256 lines .append (line )
249257
250258 info_pos_label = None # irrelevant for multiclass
@@ -306,27 +314,29 @@ def _plot_cross_validated_estimator(
306314 if self .ml_task == "binary-classification" :
307315 pos_label = cast (PositiveLabel , self .pos_label )
308316 for split_idx in self .roc_curve ["split_index" ].unique ():
309- fpr_split = self .roc_curve [
310- (self .roc_curve ["label" ] == pos_label )
311- & (self .roc_curve ["split_index" ] == split_idx )
312- ]["fpr" ]
313- tpr_split = self .roc_curve [
314- (self .roc_curve ["label" ] == pos_label )
315- & (self .roc_curve ["split_index" ] == split_idx )
316- ]["tpr" ]
317- roc_auc_split = self .roc_auc [
318- (self .roc_auc ["label" ] == pos_label )
319- & (self .roc_auc ["split_index" ] == split_idx )
320- ]["roc_auc" ].iloc [0 ]
317+ roc_curve = _filter_by (
318+ self .roc_curve ,
319+ label = pos_label ,
320+ split_index = split_idx ,
321+ )
322+ roc_auc = _filter_by (
323+ self .roc_auc ,
324+ label = pos_label ,
325+ split_index = split_idx ,
326+ )["roc_auc" ].iloc [0 ]
321327
322328 line_kwargs_validated = _validate_style_kwargs (
323329 line_kwargs , roc_curve_kwargs [split_idx ]
324330 )
325331 line_kwargs_validated ["label" ] = (
326- f"Estimator of fold #{ split_idx + 1 } (AUC = { roc_auc_split :0.2f} )"
332+ f"Estimator of fold #{ split_idx + 1 } (AUC = { roc_auc :0.2f} )"
327333 )
328334
329- (line ,) = self .ax_ .plot (fpr_split , tpr_split , ** line_kwargs_validated )
335+ (line ,) = self .ax_ .plot (
336+ roc_curve ["fpr" ],
337+ roc_curve ["tpr" ],
338+ ** line_kwargs_validated ,
339+ )
330340 lines .append (line )
331341
332342 info_pos_label = (
@@ -340,20 +350,18 @@ def _plot_cross_validated_estimator(
340350 )
341351
342352 for class_idx , class_label in enumerate (labels ):
343- roc_auc_class = self .roc_auc [self .roc_auc ["label" ] == class_label ][
344- "roc_auc"
345- ].iloc [0 ]
353+ roc_auc = _filter_by (
354+ self .roc_auc ,
355+ label = class_label ,
356+ )["roc_auc" ].iloc [0 ]
346357 roc_curve_kwargs_class = roc_curve_kwargs [class_idx ]
347358
348359 for split_idx in self .roc_curve ["split_index" ].unique ():
349- roc_curve_label = self .roc_curve [
350- (self .roc_curve ["label" ] == class_label )
351- & (self .roc_curve ["split_index" ] == split_idx )
352- ]
353- fpr_split = roc_curve_label ["fpr" ]
354- tpr_split = roc_curve_label ["tpr" ]
355- roc_auc_mean = np .mean (roc_auc_class )
356- roc_auc_std = np .std (roc_auc_class )
360+ roc_curve_label = _filter_by (
361+ self .roc_curve ,
362+ label = class_label ,
363+ split_index = split_idx ,
364+ )
357365
358366 line_kwargs_validated = _validate_style_kwargs (
359367 {
@@ -365,14 +373,16 @@ def _plot_cross_validated_estimator(
365373 if split_idx == 0 :
366374 line_kwargs_validated ["label" ] = (
367375 f"{ str (class_label ).title ()} "
368- f"(AUC = { roc_auc_mean :0.2f} +/- "
369- f"{ roc_auc_std :0.2f} )"
376+ f"(AUC = { roc_auc . mean () :0.2f} +/- "
377+ f"{ roc_auc . std () :0.2f} )"
370378 )
371379 else :
372380 line_kwargs_validated ["label" ] = None
373381
374382 (line ,) = self .ax_ .plot (
375- fpr_split , tpr_split , ** line_kwargs_validated
383+ roc_curve_label ["fpr" ],
384+ roc_curve_label ["tpr" ],
385+ ** line_kwargs_validated ,
376386 )
377387 lines .append (line )
378388
@@ -437,24 +447,27 @@ def _plot_comparison_estimator(
437447 if self .ml_task == "binary-classification" :
438448 pos_label = cast (PositiveLabel , self .pos_label )
439449 for est_idx , est_name in enumerate (estimator_names ):
440- roc_curve_estimator = self .roc_curve [
441- (self .roc_curve ["label" ] == pos_label )
442- & (self .roc_curve ["estimator_name" ] == est_name )
443- ]
444- fpr_est = roc_curve_estimator ["fpr" ]
445- tpr_est = roc_curve_estimator ["tpr" ]
446- roc_auc_est = self .roc_auc [
447- (self .roc_auc ["label" ] == pos_label )
448- & (self .roc_auc ["estimator_name" ] == est_name )
449- ]["roc_auc" ].iloc [0 ]
450+ roc_curve = _filter_by (
451+ self .roc_curve ,
452+ label = pos_label ,
453+ estimator_name = est_name ,
454+ )
455+
456+ roc_auc = _filter_by (
457+ self .roc_auc ,
458+ label = pos_label ,
459+ estimator_name = est_name ,
460+ )["roc_auc" ].iloc [0 ]
450461
451462 line_kwargs_validated = _validate_style_kwargs (
452463 line_kwargs , roc_curve_kwargs [est_idx ]
453464 )
454- line_kwargs_validated ["label" ] = (
455- f"{ est_name } (AUC = { roc_auc_est :0.2f} )"
465+ line_kwargs_validated ["label" ] = f"{ est_name } (AUC = { roc_auc :0.2f} )"
466+ (line ,) = self .ax_ .plot (
467+ roc_curve ["fpr" ],
468+ roc_curve ["tpr" ],
469+ ** line_kwargs_validated ,
456470 )
457- (line ,) = self .ax_ .plot (fpr_est , tpr_est , ** line_kwargs_validated )
458471 lines .append (line )
459472
460473 info_pos_label = (
@@ -471,16 +484,18 @@ def _plot_comparison_estimator(
471484 est_color = class_colors [est_idx ]
472485
473486 for class_idx , class_label in enumerate (labels ):
474- roc_curve_estimator = self .roc_curve [
475- (self .roc_curve ["label" ] == class_label )
476- & (self .roc_curve ["estimator_name" ] == est_name )
477- ]
478- fpr_est_class = roc_curve_estimator ["fpr" ]
479- tpr_est_class = roc_curve_estimator ["tpr" ]
480- roc_auc_mean = self .roc_auc [
481- (self .roc_auc ["label" ] == class_label )
482- & (self .roc_auc ["estimator_name" ] == est_name )
483- ]["roc_auc" ].iloc [0 ]
487+ roc_curve = _filter_by (
488+ self .roc_curve ,
489+ label = class_label ,
490+ estimator_name = est_name ,
491+ )
492+
493+ roc_auc = _filter_by (
494+ self .roc_auc ,
495+ label = class_label ,
496+ estimator_name = est_name ,
497+ )["roc_auc" ].iloc [0 ]
498+
484499 class_linestyle = LINESTYLE [(class_idx % len (LINESTYLE ))][1 ]
485500
486501 line_kwargs ["color" ] = est_color
@@ -492,11 +507,11 @@ def _plot_comparison_estimator(
492507 )
493508 line_kwargs_validated ["label" ] = (
494509 f"{ est_name } - { str (class_label ).title ()} "
495- f"(AUC = { roc_auc_mean :0.2f} )"
510+ f"(AUC = { roc_auc :0.2f} )"
496511 )
497512
498513 (line ,) = self .ax_ .plot (
499- fpr_est_class , tpr_est_class , ** line_kwargs_validated
514+ roc_curve [ "fpr" ], roc_curve [ "tpr" ] , ** line_kwargs_validated
500515 )
501516 lines .append (line )
502517
@@ -564,28 +579,30 @@ def _plot_comparison_cross_validation(
564579 10 if len (estimator_names ) < 10 else len (estimator_names ),
565580 )
566581 for report_idx , estimator_name in enumerate (estimator_names ):
567- roc_auc_estimator = self .roc_auc [
568- self .roc_auc ["estimator_name" ] == estimator_name
569- ]["roc_auc" ]
582+ roc_curve = _filter_by (
583+ self .roc_curve ,
584+ label = self .pos_label ,
585+ estimator_name = estimator_name ,
586+ )
587+
588+ roc_auc = _filter_by (
589+ self .roc_auc ,
590+ estimator_name = estimator_name ,
591+ )["roc_auc" ]
570592
571593 line_kwargs_validated = _validate_style_kwargs (
572594 line_kwargs , roc_curve_kwargs [report_idx ]
573595 )
574596 line_kwargs_validated ["color" ] = colors [report_idx ]
575597 line_kwargs_validated ["alpha" ] = 0.6
576598
577- roc_curve_estimator = self .roc_curve [
578- (self .roc_curve ["label" ] == self .pos_label )
579- & (self .roc_curve ["estimator_name" ] == estimator_name )
580- ]
581-
582- for split_index , segment in roc_curve_estimator .groupby ("split_index" ):
599+ for split_index , segment in roc_curve .groupby ("split_index" ):
583600 if split_index == 0 :
584601 label_kwargs = {
585602 "label" : (
586603 f"{ estimator_name } "
587- f"(AUC = { roc_auc_estimator .mean ():0.2f} +/- "
588- f"{ roc_auc_estimator .std ():0.2f} )"
604+ f"(AUC = { roc_auc .mean ():0.2f} +/- "
605+ f"{ roc_auc .std ():0.2f} )"
589606 )
590607 }
591608 else :
@@ -630,31 +647,31 @@ def _plot_comparison_cross_validation(
630647 est_color = colors [est_idx ]
631648
632649 for label_idx , label in enumerate (labels ):
633- roc_auc_estimator = self .roc_auc [
634- (self .roc_auc ["label" ] == label )
635- & (self .roc_auc ["estimator_name" ] == estimator_name )
636- ]["roc_auc" ]
650+ roc_curve = _filter_by (
651+ self .roc_curve ,
652+ label = label ,
653+ estimator_name = estimator_name ,
654+ )
655+
656+ roc_auc = _filter_by (
657+ self .roc_auc ,
658+ label = label ,
659+ estimator_name = estimator_name ,
660+ )["roc_auc" ]
637661
638662 line_kwargs_validated = _validate_style_kwargs (
639663 line_kwargs , roc_curve_kwargs [est_idx ]
640664 )
641665 line_kwargs_validated ["color" ] = est_color
642666 line_kwargs_validated ["alpha" ] = 0.6
643667
644- roc_curve_estimator = self .roc_curve [
645- (self .roc_curve ["label" ] == label )
646- & (self .roc_curve ["estimator_name" ] == estimator_name )
647- ]
648-
649- for split_index , segment in roc_curve_estimator .groupby (
650- "split_index"
651- ):
668+ for split_index , segment in roc_curve .groupby ("split_index" ):
652669 if split_index == 0 :
653670 label_kwargs = {
654671 "label" : (
655672 f"{ estimator_name } "
656- f"(AUC = { roc_auc_estimator .mean ():0.2f} +/- "
657- f"{ roc_auc_estimator .std ():0.2f} )"
673+ f"(AUC = { roc_auc .mean ():0.2f} +/- "
674+ f"{ roc_auc .std ():0.2f} )"
658675 )
659676 }
660677 else :
0 commit comments