@@ -153,7 +153,7 @@ def _plot_single_estimator(
153153 estimator_name : str ,
154154 roc_curve_kwargs : list [dict [str , Any ]],
155155 plot_chance_level : bool = True ,
156- chance_level_kwargs : Optional [dict [str , Any ]] = None ,
156+ chance_level_kwargs : Optional [dict [str , Any ]],
157157 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
158158 """Plot ROC curve for a single estimator.
159159
@@ -272,7 +272,7 @@ def _plot_cross_validated_estimator(
272272 estimator_name : str ,
273273 roc_curve_kwargs : list [dict [str , Any ]],
274274 plot_chance_level : bool = True ,
275- chance_level_kwargs : Optional [dict [str , Any ]] = None ,
275+ chance_level_kwargs : Optional [dict [str , Any ]],
276276 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
277277 """Plot ROC curve for a cross-validated estimator.
278278
@@ -398,7 +398,7 @@ def _plot_comparison_estimator(
398398 estimator_names : list [str ],
399399 roc_curve_kwargs : list [dict [str , Any ]],
400400 plot_chance_level : bool = True ,
401- chance_level_kwargs : Optional [dict [str , Any ]] = None ,
401+ chance_level_kwargs : Optional [dict [str , Any ]],
402402 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
403403 """Plot ROC curve of several estimators.
404404
@@ -518,7 +518,7 @@ def _plot_comparison_cross_validation(
518518 estimator_names : list [str ],
519519 roc_curve_kwargs : list [dict [str , Any ]],
520520 plot_chance_level : bool = True ,
521- chance_level_kwargs : Optional [dict [str , Any ]] = None ,
521+ chance_level_kwargs : Optional [dict [str , Any ]],
522522 ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
523523 """Plot ROC curve of several cross-validations.
524524
@@ -568,11 +568,11 @@ def _plot_comparison_cross_validation(
568568 "roc_auc"
569569 ]
570570
571+ line_kwargs ["color" ] = colors [report_idx ]
572+ line_kwargs ["alpha" ] = 0.6
571573 line_kwargs_validated = _validate_style_kwargs (
572574 line_kwargs , roc_curve_kwargs [report_idx ]
573575 )
574- line_kwargs_validated ["color" ] = colors [report_idx ]
575- line_kwargs_validated ["alpha" ] = 0.6
576576
577577 for split_index , segment in roc_curve .groupby ("split_index" ):
578578 if split_index == 0 :
@@ -620,6 +620,7 @@ def _plot_comparison_cross_validation(
620620 colormaps .get_cmap ("tab10" ),
621621 10 if len (estimator_names ) < 10 else len (estimator_names ),
622622 )
623+ idx = 0
623624
624625 for est_idx , estimator_name in enumerate (estimator_names ):
625626 est_color = colors [est_idx ]
@@ -633,12 +634,6 @@ def _plot_comparison_cross_validation(
633634 f"label == { label } & estimator_name == '{ estimator_name } '"
634635 )["roc_auc" ]
635636
636- line_kwargs_validated = _validate_style_kwargs (
637- line_kwargs , roc_curve_kwargs [est_idx ]
638- )
639- line_kwargs_validated ["color" ] = est_color
640- line_kwargs_validated ["alpha" ] = 0.6
641-
642637 for split_index , segment in roc_curve .groupby ("split_index" ):
643638 if split_index == 0 :
644639 label_kwargs = {
@@ -651,13 +646,21 @@ def _plot_comparison_cross_validation(
651646 else :
652647 label_kwargs = {}
653648
649+ line_kwargs ["color" ] = est_color
650+ line_kwargs ["alpha" ] = 0.6
651+ line_kwargs_validated = _validate_style_kwargs (
652+ line_kwargs , roc_curve_kwargs [idx ]
653+ )
654+
654655 (line ,) = self .ax_ [label_idx ].plot (
655656 segment ["fpr" ],
656657 segment ["tpr" ],
657658 ** (line_kwargs_validated | label_kwargs ),
658659 )
659660 lines .append (line )
660661
662+ idx = idx + 1
663+
661664 info_pos_label = f"\n (Positive label: { label } )"
662665 _set_axis_labels (self .ax_ [label_idx ], info_pos_label )
663666
@@ -784,6 +787,8 @@ def plot(
784787 self ._plot_comparison_cross_validation (
785788 estimator_names = self .roc_auc ["estimator_name" ].unique (),
786789 roc_curve_kwargs = roc_curve_kwargs ,
790+ plot_chance_level = plot_chance_level ,
791+ chance_level_kwargs = chance_level_kwargs ,
787792 )
788793 )
789794 else :
0 commit comments