22from typing import Any , Literal , Optional , Union , cast
33
44import matplotlib .pyplot as plt
5+ import numpy as np
56from matplotlib import colormaps
67from matplotlib .axes import Axes
78from matplotlib .lines import Line2D
@@ -387,6 +388,113 @@ def _plot_cross_validated_estimator(
387388
388389 return self .ax_ , lines , info_pos_label
389390
391+ def _plot_average_cross_validated_binary_estimator (
392+ self ,
393+ * ,
394+ estimator_name : str ,
395+ roc_curve_kwargs : list [dict [str , Any ]],
396+ plot_chance_level : bool = True ,
397+ chance_level_kwargs : Optional [dict [str , Any ]],
398+ ) -> tuple [Axes , list [Line2D ], Union [str , None ]]:
399+ """Plot ROC curve for a cross-validated estimator.
400+
401+ Parameters
402+ ----------
403+ estimator_name : str
404+ The name of the estimator.
405+
406+ roc_curve_kwargs : list of dict
407+ List of dictionaries containing keyword arguments to customize the ROC
408+ curves. The length of the list should match the number of curves to plot.
409+
410+ plot_chance_level : bool, default=True
411+ Whether to plot the chance level.
412+
413+ chance_level_kwargs : dict, default=None
414+ Keyword arguments to be passed to matplotlib's `plot` for rendering
415+ the chance level line.
416+
417+ Returns
418+ -------
419+ ax : matplotlib.axes.Axes
420+ The axes with the ROC curves plotted.
421+
422+ lines : list of matplotlib.lines.Line2D
423+ The plotted ROC curve lines.
424+
425+ info_pos_label : str or None
426+ String containing positive label information for binary classification,
427+ None for multiclass.
428+ """
429+ lines : list [Line2D ] = []
430+ average_type = self .roc_curve ["average" ].cat .categories .item ()
431+ n_folds : int = 0
432+
433+ for split_idx in self .roc_curve ["split_index" ].cat .categories :
434+ if split_idx is None :
435+ continue
436+ split_idx = int (split_idx )
437+ query = f"label == { self .pos_label !r} & split_index == { split_idx } "
438+ roc_curve = self .roc_curve .query (query )
439+
440+ line_kwargs_validated = _validate_style_kwargs (
441+ {"color" : "grey" , "alpha" : 0.3 , "lw" : 0.75 }, roc_curve_kwargs [split_idx ]
442+ )
443+
444+ (line ,) = self .ax_ .plot (
445+ roc_curve ["fpr" ],
446+ roc_curve ["tpr" ],
447+ ** line_kwargs_validated ,
448+ )
449+ lines .append (line )
450+ n_folds += 1
451+
452+ info_pos_label = (
453+ f"\n (Positive label: { self .pos_label } )"
454+ if self .pos_label is not None
455+ else ""
456+ )
457+
458+ query = f"label == { self .pos_label !r} & average == '{ average_type } '"
459+ average_roc_curve = self .roc_curve .query (query )
460+ average_roc_auc = self .roc_auc .query (query )["roc_auc" ].item ()
461+
462+ line_kwargs_validated = _validate_style_kwargs ({}, {})
463+ line_kwargs_validated ["label" ] = (
464+ f"{ average_type .capitalize ()} average of { n_folds } folds"
465+ f"(AUC = { average_roc_auc :0.2f} )"
466+ )
467+
468+ (line ,) = self .ax_ .plot (
469+ average_roc_curve ["fpr" ],
470+ average_roc_curve ["tpr" ],
471+ ** line_kwargs_validated ,
472+ )
473+ lines .append (line )
474+
475+ info_pos_label = (
476+ f"\n (Positive label: { self .pos_label } )"
477+ if self .pos_label is not None
478+ else ""
479+ )
480+
481+ if plot_chance_level :
482+ self .chance_level_ = _add_chance_level (
483+ self .ax_ ,
484+ chance_level_kwargs ,
485+ self ._default_chance_level_kwargs ,
486+ )
487+ else :
488+ self .chance_level_ = None
489+
490+ if self .data_source in ("train" , "test" ):
491+ title = f"{ estimator_name } on $\\ bf{{{ self .data_source } }}$ set"
492+ else :
493+ title = f"{ estimator_name } on $\\ bf{{external}}$ set"
494+ self .ax_ .legend (bbox_to_anchor = (1.02 , 1 ), title = title )
495+
496+ return self .ax_ , lines , info_pos_label
497+
390498 def _plot_comparison_estimator (
391499 self ,
392500 * ,
@@ -760,17 +868,30 @@ def plot(
760868 chance_level_kwargs = chance_level_kwargs ,
761869 )
762870 elif self .report_type == "cross-validation" :
763- self .ax_ , self .lines_ , info_pos_label = (
764- self ._plot_cross_validated_estimator (
765- estimator_name = (
766- estimator_name
767- or self .roc_auc ["estimator_name" ].cat .categories .item ()
768- ),
769- roc_curve_kwargs = roc_curve_kwargs ,
770- plot_chance_level = plot_chance_level ,
771- chance_level_kwargs = chance_level_kwargs ,
871+ if "average" in self .roc_auc .columns :
872+ self .ax_ , self .lines_ , info_pos_label = (
873+ self ._plot_average_cross_validated_binary_estimator (
874+ estimator_name = (
875+ estimator_name
876+ or self .roc_auc ["estimator_name" ].cat .categories .item ()
877+ ),
878+ roc_curve_kwargs = roc_curve_kwargs ,
879+ plot_chance_level = plot_chance_level ,
880+ chance_level_kwargs = chance_level_kwargs ,
881+ )
882+ )
883+ else :
884+ self .ax_ , self .lines_ , info_pos_label = (
885+ self ._plot_cross_validated_estimator (
886+ estimator_name = (
887+ estimator_name
888+ or self .roc_auc ["estimator_name" ].cat .categories .item ()
889+ ),
890+ roc_curve_kwargs = roc_curve_kwargs ,
891+ plot_chance_level = plot_chance_level ,
892+ chance_level_kwargs = chance_level_kwargs ,
893+ )
772894 )
773- )
774895 elif self .report_type == "comparison-estimator" :
775896 self .ax_ , self .lines_ , info_pos_label = self ._plot_comparison_estimator (
776897 estimator_names = self .roc_auc ["estimator_name" ].cat .categories ,
@@ -812,6 +933,7 @@ def _compute_data_for_display(
812933 cls ,
813934 y_true : Sequence [YPlotData ],
814935 y_pred : Sequence [YPlotData ],
936+ average : Optional [Literal ["threshold" ]] = None ,
815937 * ,
816938 report_type : ReportType ,
817939 estimators : Sequence [BaseEstimator ],
@@ -869,6 +991,7 @@ def _compute_data_for_display(
869991 roc_auc_records = []
870992
871993 if ml_task == "binary-classification" :
994+ pos_label_validated = cast (PositiveLabel , pos_label_validated )
872995 for y_true_i , y_pred_i in zip (y_true , y_pred ):
873996 fpr_i , tpr_i , thresholds_i = roc_curve (
874997 y_true_i .y ,
@@ -878,8 +1001,6 @@ def _compute_data_for_display(
8781001 )
8791002 roc_auc_i = auc (fpr_i , tpr_i )
8801003
881- pos_label_validated = cast (PositiveLabel , pos_label_validated )
882-
8831004 for fpr , tpr , threshold in zip (fpr_i , tpr_i , thresholds_i ):
8841005 roc_curve_records .append (
8851006 {
@@ -900,8 +1021,63 @@ def _compute_data_for_display(
9001021 "roc_auc" : roc_auc_i ,
9011022 }
9021023 )
1024+ if average is not None :
1025+ if average == "threshold" :
1026+ all_thresholds = []
1027+ all_fprs = []
1028+ all_tprs = []
1029+
1030+ roc_curves_df = DataFrame .from_records (roc_curve_records )
1031+ for _ , group in roc_curves_df .groupby ("split_index" ):
1032+ sorted_group = group .sort_values ("threshold" , ascending = False )
1033+ all_thresholds .append (
1034+ np .array (sorted_group ["threshold" ].values )
1035+ )
1036+ all_fprs .append (np .array (sorted_group ["fpr" ].values ))
1037+ all_tprs .append (np .array (sorted_group ["tpr" ].values ))
1038+
1039+ average_fpr , average_tpr , average_threshold = (
1040+ cls ._threshold_average (
1041+ xs = all_fprs ,
1042+ ys = all_tprs ,
1043+ thresholds = all_thresholds ,
1044+ )
1045+ )
1046+ else :
1047+ raise TypeError (
1048+ "'threshold' is the only supported option for `average`,"
1049+ f"but got { average } instead"
1050+ )
1051+ average_roc_auc = auc (average_fpr , average_tpr )
1052+ for fpr , tpr , threshold in zip (
1053+ average_fpr , average_tpr , average_threshold
1054+ ):
1055+ roc_curve_records .append (
1056+ {
1057+ "estimator_name" : y_true_i .estimator_name ,
1058+ "split_index" : None ,
1059+ "label" : pos_label_validated ,
1060+ "threshold" : threshold ,
1061+ "fpr" : fpr ,
1062+ "tpr" : tpr ,
1063+ "average" : "threshold" ,
1064+ }
1065+ )
1066+ roc_auc_records .append (
1067+ {
1068+ "estimator_name" : y_true_i .estimator_name ,
1069+ "split_index" : None ,
1070+ "label" : pos_label_validated ,
1071+ "roc_auc" : average_roc_auc ,
1072+ "average" : "threshold" ,
1073+ }
1074+ )
9031075
9041076 else : # multiclass-classification
1077+ if average is not None :
1078+ raise ValueError (
1079+ "Averaging is not implemented for multi class classification"
1080+ )
9051081 # OvR fashion to collect fpr, tpr, and roc_auc
9061082 for y_true_i , y_pred_i , est in zip (y_true , y_pred , estimators ):
9071083 label_binarizer = LabelBinarizer ().fit (est .classes_ )
@@ -942,7 +1118,7 @@ def _compute_data_for_display(
9421118 "estimator_name" : "category" ,
9431119 "split_index" : "category" ,
9441120 "label" : "category" ,
945- }
1121+ } | ({ "average" : "category" } if average is not None else {})
9461122
9471123 return cls (
9481124 roc_curve = DataFrame .from_records (roc_curve_records ).astype (dtypes ),
0 commit comments