@@ -344,10 +344,37 @@ def test_frame_multiclass_classification(
344344 assert df ["label" ].nunique () == len (estimator .classes_ )
345345
346346 if with_roc_auc :
347- for ( _ ) , group in df .groupby (["label" ], observed = True ):
347+ for _ , group in df .groupby (["label" ], observed = True ):
348348 assert group ["roc_auc" ].nunique () == 1
349349
350350
351+ @pytest .mark .parametrize ("with_roc_auc" , [False , True ])
352+ def test_frame_multiclass_classification_data_source_both (
353+ logistic_multiclass_classification_with_train_test , with_roc_auc
354+ ):
355+ """
356+ Test the frame method with multiclass classification data and data_source="both".
357+ """
358+ estimator , X_train , X_test , y_train , y_test = (
359+ logistic_multiclass_classification_with_train_test
360+ )
361+ report = EstimatorReport (
362+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
363+ )
364+ df = report .metrics .roc (data_source = "both" ).frame (with_roc_auc = with_roc_auc )
365+ expected_index = ["data_source" , "label" ]
366+ expected_columns = ["threshold" , "fpr" , "tpr" ]
367+ if with_roc_auc :
368+ expected_columns .append ("roc_auc" )
369+
370+ check_frame_structure (df , expected_index , expected_columns )
371+ assert df ["label" ].nunique () == len (estimator .classes_ )
372+
373+ if with_roc_auc :
374+ for _ , group in df .groupby (["label" ], observed = True ):
375+ assert group ["roc_auc" ].nunique () == 2
376+
377+
351378def test_legend (
352379 pyplot ,
353380 logistic_binary_classification_with_train_test ,
0 commit comments