@@ -216,7 +216,7 @@ def _plot_single_estimator(
216216 )
217217
218218 else : # multiclass-classification
219- labels = self .roc_curve ["label" ].unique ()
219+ labels = self .roc_curve ["label" ].cat . categories
220220 class_colors = sample_mpl_colormap (
221221 colormaps .get_cmap ("tab10" ), 10 if len (labels ) < 10 else len (labels )
222222 )
@@ -309,7 +309,7 @@ def _plot_cross_validated_estimator(
309309 line_kwargs : dict [str , Any ] = {}
310310
311311 if self .ml_task == "binary-classification" :
312- for split_idx in self .roc_curve ["split_index" ].unique () :
312+ for split_idx in self .roc_curve ["split_index" ].cat . categories :
313313 roc_curve = self .roc_curve .query (
314314 f"label == { self .pos_label } & split_index == { split_idx } "
315315 )
@@ -338,7 +338,7 @@ def _plot_cross_validated_estimator(
338338 )
339339 else : # multiclass-classification
340340 info_pos_label = None # irrelevant for multiclass
341- labels = self .roc_curve ["label" ].unique ()
341+ labels = self .roc_curve ["label" ].cat . categories
342342 class_colors = sample_mpl_colormap (
343343 colormaps .get_cmap ("tab10" ), 10 if len (labels ) < 10 else len (labels )
344344 )
@@ -347,7 +347,7 @@ def _plot_cross_validated_estimator(
347347 roc_auc = self .roc_auc .query (f"label == { class_label } " )["roc_auc" ]
348348 roc_curve_kwargs_class = roc_curve_kwargs [class_idx ]
349349
350- for split_idx in self .roc_curve ["split_index" ].unique () :
350+ for split_idx in self .roc_curve ["split_index" ].cat . categories :
351351 roc_curve_label = self .roc_curve .query (
352352 f"label == { class_label } & split_index == { split_idx } "
353353 )
@@ -461,7 +461,7 @@ def _plot_comparison_estimator(
461461 )
462462 else : # multiclass-classification
463463 info_pos_label = None # irrelevant for multiclass
464- labels = self .roc_curve ["label" ].unique ()
464+ labels = self .roc_curve ["label" ].cat . categories
465465 class_colors = sample_mpl_colormap (
466466 colormaps .get_cmap ("tab10" ), 10 if len (labels ) < 10 else len (labels )
467467 )
@@ -555,7 +555,7 @@ def _plot_comparison_cross_validation(
555555 line_kwargs : dict [str , Any ] = {}
556556
557557 if self .ml_task == "binary-classification" :
558- labels = self .roc_curve ["label" ].unique ()
558+ labels = self .roc_curve ["label" ].cat . categories
559559 colors = sample_mpl_colormap (
560560 colormaps .get_cmap ("tab10" ),
561561 10 if len (estimator_names ) < 10 else len (estimator_names ),
@@ -575,7 +575,9 @@ def _plot_comparison_cross_validation(
575575 line_kwargs , roc_curve_kwargs [report_idx ]
576576 )
577577
578- for split_index , segment in roc_curve .groupby ("split_index" ):
578+ for split_index , segment in roc_curve .groupby (
579+ "split_index" , observed = True
580+ ):
579581 if split_index == 0 :
580582 label_kwargs = {
581583 "label" : (
@@ -616,7 +618,7 @@ def _plot_comparison_cross_validation(
616618
617619 else : # multiclass-classification
618620 info_pos_label = None # irrelevant for multiclass
619- labels = self .roc_curve ["label" ].unique ()
621+ labels = self .roc_curve ["label" ].cat . categories
620622 colors = sample_mpl_colormap (
621623 colormaps .get_cmap ("tab10" ),
622624 10 if len (estimator_names ) < 10 else len (estimator_names ),
@@ -635,7 +637,9 @@ def _plot_comparison_cross_validation(
635637 f"label == { label } & estimator_name == '{ estimator_name } '"
636638 )["roc_auc" ]
637639
638- for split_index , segment in roc_curve .groupby ("split_index" ):
640+ for split_index , segment in roc_curve .groupby (
641+ "split_index" , observed = True
642+ ):
639643 if split_index == 0 :
640644 label_kwargs = {
641645 "label" : (
@@ -740,7 +744,7 @@ def plot(
740744 self .report_type == "comparison-cross-validation"
741745 and self .ml_task == "multiclass-classification"
742746 ):
743- n_labels = len (self .roc_auc ["label" ].unique () )
747+ n_labels = len (self .roc_auc ["label" ].cat . categories )
744748 self .figure_ , self .ax_ = plt .subplots (ncols = n_labels )
745749 else :
746750 self .figure_ , self .ax_ = plt .subplots ()
@@ -762,31 +766,37 @@ def plot(
762766
763767 if self .report_type == "estimator" :
764768 self .ax_ , self .lines_ , info_pos_label = self ._plot_single_estimator (
765- estimator_name = estimator_name or self .roc_auc ["estimator_name" ][0 ],
769+ estimator_name = (
770+ estimator_name
771+ or self .roc_auc ["estimator_name" ].cat .categories .item ()
772+ ),
766773 roc_curve_kwargs = roc_curve_kwargs ,
767774 plot_chance_level = plot_chance_level ,
768775 chance_level_kwargs = chance_level_kwargs ,
769776 )
770777 elif self .report_type == "cross-validation" :
771778 self .ax_ , self .lines_ , info_pos_label = (
772779 self ._plot_cross_validated_estimator (
773- estimator_name = estimator_name or self .roc_auc ["estimator_name" ][0 ],
780+ estimator_name = (
781+ estimator_name
782+ or self .roc_auc ["estimator_name" ].cat .categories .item ()
783+ ),
774784 roc_curve_kwargs = roc_curve_kwargs ,
775785 plot_chance_level = plot_chance_level ,
776786 chance_level_kwargs = chance_level_kwargs ,
777787 )
778788 )
779789 elif self .report_type == "comparison-estimator" :
780790 self .ax_ , self .lines_ , info_pos_label = self ._plot_comparison_estimator (
781- estimator_names = self .roc_auc ["estimator_name" ].unique () ,
791+ estimator_names = self .roc_auc ["estimator_name" ].cat . categories ,
782792 roc_curve_kwargs = roc_curve_kwargs ,
783793 plot_chance_level = plot_chance_level ,
784794 chance_level_kwargs = chance_level_kwargs ,
785795 )
786796 elif self .report_type == "comparison-cross-validation" :
787797 self .ax_ , self .lines_ , info_pos_label = (
788798 self ._plot_comparison_cross_validation (
789- estimator_names = self .roc_auc ["estimator_name" ].unique () ,
799+ estimator_names = self .roc_auc ["estimator_name" ].cat . categories ,
790800 roc_curve_kwargs = roc_curve_kwargs ,
791801 plot_chance_level = plot_chance_level ,
792802 chance_level_kwargs = chance_level_kwargs ,
@@ -943,9 +953,15 @@ def _compute_data_for_display(
943953 }
944954 )
945955
956+ dtypes = {
957+ "estimator_name" : "category" ,
958+ "split_index" : "category" ,
959+ "label" : "category" ,
960+ }
961+
946962 return cls (
947- roc_curve = DataFrame .from_records (roc_curve_records ),
948- roc_auc = DataFrame .from_records (roc_auc_records ),
963+ roc_curve = DataFrame .from_records (roc_curve_records ). astype ( dtypes ) ,
964+ roc_auc = DataFrame .from_records (roc_auc_records ). astype ( dtypes ) ,
949965 pos_label = pos_label_validated ,
950966 data_source = data_source ,
951967 ml_task = ml_task ,
0 commit comments