@@ -671,3 +671,190 @@ def test_roc_curve_display_wrong_report_type(pyplot, binary_classification_data)
671671 )
672672 with pytest .raises (ValueError , match = err_msg ):
673673 display .plot ()
674+
675+
676+ def test_roc_curve_display_subplots_basic_binary (pyplot , binary_classification_data ):
677+ """Test that subplots=True creates multiple subplots with default parameters
678+ for binary classification."""
679+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
680+
681+ # Create a comparison report with multiple estimators
682+ est1 = clone (estimator )
683+ est2 = clone (estimator )
684+ est1 .fit (X_train , y_train )
685+ est2 .fit (X_train , y_train )
686+
687+ report = ComparisonReport (
688+ reports = {
689+ "estimator 1" : EstimatorReport (
690+ est1 ,
691+ X_train = X_train ,
692+ y_train = y_train ,
693+ X_test = X_test ,
694+ y_test = y_test ,
695+ ),
696+ "estimator 2" : EstimatorReport (
697+ est2 ,
698+ X_train = X_train ,
699+ y_train = y_train ,
700+ X_test = X_test ,
701+ y_test = y_test ,
702+ ),
703+ },
704+ )
705+ display = report .metrics .roc ()
706+ display .plot (subplots = True )
707+
708+ assert hasattr (display , "figure_" )
709+
710+ axes = display .figure_ .get_axes ()
711+ assert len (axes ) == 2
712+
713+ assert "Model: estimator 1" in axes [0 ].get_title ()
714+ assert "Model: estimator 2" in axes [1 ].get_title ()
715+
716+ # Each subplot should have correct labels
717+ for ax in axes :
718+ assert "False Positive Rate" in ax .get_xlabel ()
719+ assert "True Positive Rate" in ax .get_ylabel ()
720+ assert ax .get_aspect () in ("equal" , 1.0 )
721+
722+
723+ def test_roc_curve_display_subplots_basic_multiclass (
724+ pyplot , multiclass_classification_data
725+ ):
726+ """Test that subplots=True creates multiple subplots with default parameters
727+ for multiclass classification."""
728+ estimator , X_train , X_test , y_train , y_test = multiclass_classification_data
729+ report = EstimatorReport (
730+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
731+ )
732+
733+ # In multiclass case, we should get one subplot per class
734+ display = report .metrics .roc ()
735+ display .plot (subplots = True )
736+
737+ assert hasattr (display , "figure_" )
738+
739+ # Check correct number of subplots (one per class)
740+ axes = display .figure_ .get_axes ()
741+ assert len (axes ) == len (estimator .classes_ )
742+
743+ for i , class_label in enumerate (estimator .classes_ ):
744+ assert f"Class: { class_label } " in axes [i ].get_title ()
745+
746+ # Each subplot should have correct labels
747+ for ax in axes :
748+ assert "False Positive Rate" in ax .get_xlabel ()
749+ assert "True Positive Rate" in ax .get_ylabel ()
750+ assert ax .get_aspect () in ("equal" , 1.0 )
751+
752+
753+ def test_roc_curve_display_subplots_cv_binary (
754+ pyplot , binary_classification_data_no_split
755+ ):
756+ """Test subplots with cross-validation for binary classification."""
757+ (estimator , X , y ), cv = binary_classification_data_no_split , 3
758+ report = CrossValidationReport (estimator , X = X , y = y , cv_splitter = cv )
759+ display = report .metrics .roc ()
760+ display .plot (subplots = True )
761+
762+ assert hasattr (display , "figure_" )
763+
764+ # Check number of subplots matches number of CV folds
765+ axes = display .figure_ .get_axes ()
766+ assert len (axes ) == cv
767+
768+ # Check titles for each fold
769+ for i , ax in enumerate (axes ):
770+ assert f"Fold #{ i + 1 } " in ax .get_title ()
771+
772+ # Each subplot should have correct labels
773+ for ax in axes :
774+ assert "False Positive Rate" in ax .get_xlabel ()
775+ assert "True Positive Rate" in ax .get_ylabel ()
776+ assert ax .get_aspect () in ("equal" , 1.0 )
777+
778+
779+ def test_roc_curve_display_subplots_custom_layout (pyplot , binary_classification_data ):
780+ """Test subplots with custom layout parameters."""
781+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
782+
783+ # Create a comparison report with multiple estimators
784+ est1 = clone (estimator )
785+ est2 = clone (estimator )
786+ est3 = clone (estimator )
787+ est1 .fit (X_train , y_train )
788+ est2 .fit (X_train , y_train )
789+ est3 .fit (X_train , y_train )
790+
791+ report = ComparisonReport (
792+ reports = {
793+ "estimator 1" : EstimatorReport (
794+ est1 ,
795+ X_train = X_train ,
796+ y_train = y_train ,
797+ X_test = X_test ,
798+ y_test = y_test ,
799+ ),
800+ "estimator 2" : EstimatorReport (
801+ est2 ,
802+ X_train = X_train ,
803+ y_train = y_train ,
804+ X_test = X_test ,
805+ y_test = y_test ,
806+ ),
807+ "estimator 3" : EstimatorReport (
808+ est3 ,
809+ X_train = X_train ,
810+ y_train = y_train ,
811+ X_test = X_test ,
812+ y_test = y_test ,
813+ ),
814+ },
815+ )
816+ display = report .metrics .roc ()
817+
818+ # Test with custom nrows and ncols
819+ figsize = (10 , 8 )
820+ display .plot (subplots = True , nrows = 1 , ncols = 3 , figsize = figsize )
821+
822+ # Check figure was created with correct size
823+ assert hasattr (display , "figure_" )
824+ assert display .figure_ .get_size_inches ()[0 ] == figsize [0 ]
825+ assert display .figure_ .get_size_inches ()[1 ] == figsize [1 ]
826+
827+ # Check layout is correct
828+ axes = display .figure_ .get_axes ()
829+ assert len (axes ) == 3
830+
831+ # Check subplot arrangement (1 row, 3 columns)
832+ pos1 = axes [0 ].get_position ()
833+ pos2 = axes [1 ].get_position ()
834+ pos3 = axes [2 ].get_position ()
835+
836+ # Same row (similar y positions)
837+ assert abs (pos1 .y0 - pos2 .y0 ) < 0.1
838+ assert abs (pos2 .y0 - pos3 .y0 ) < 0.1
839+
840+ # Different columns (increasing x positions)
841+ assert pos1 .x0 < pos2 .x0
842+ assert pos2 .x0 < pos3 .x0
843+
844+
845+ def test_roc_curve_display_ax_and_subplots_error (pyplot , binary_classification_data ):
846+ """Test that an error is raised when both ax and subplots=True are specified."""
847+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
848+ report = EstimatorReport (
849+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
850+ )
851+ display = report .metrics .roc ()
852+
853+ # Create a figure and axis to pass
854+ fig , ax = pyplot .subplots ()
855+
856+ # Test that error is raised when both ax and subplots=True are specified
857+ with pytest .raises (
858+ ValueError , match = "Cannot specify both 'ax' and 'subplots=True'"
859+ ):
860+ display .plot (ax = ax , subplots = True )
0 commit comments