@@ -577,3 +577,143 @@ def test_prediction_error_display_wrong_report_type(pyplot, regression_data):
577577 )
578578 with pytest .raises (ValueError , match = err_msg ):
579579 display .plot ()
580+
581+
582+ def test_prediction_error_display_subplots_basic (pyplot , regression_data ):
583+ """Test that subplots=True creates multiple subplots with default parameters."""
584+ estimator , X_train , X_test , y_train , y_test = regression_data
585+ report = ComparisonReport (
586+ reports = {
587+ "estimator 1" : EstimatorReport (
588+ estimator ,
589+ X_train = X_train ,
590+ y_train = y_train ,
591+ X_test = X_test ,
592+ y_test = y_test ,
593+ ),
594+ "estimator 2" : EstimatorReport (
595+ estimator ,
596+ X_train = X_train ,
597+ y_train = y_train ,
598+ X_test = X_test ,
599+ y_test = y_test ,
600+ ),
601+ },
602+ )
603+ display = report .metrics .prediction_error ()
604+ display .plot (subplots = True )
605+
606+ assert hasattr (display , "figure_" )
607+ assert len (display .scatter_ ) == 2
608+
609+ axes = display .figure_ .get_axes ()
610+ assert len (axes ) == 2
611+ assert "Model: estimator 1" in axes [0 ].get_title ()
612+ assert "Model: estimator 2" in axes [1 ].get_title ()
613+
614+
615+ def test_prediction_error_display_subplots_custom_layout (pyplot , regression_data ):
616+ """Test subplots with custom layout parameters."""
617+ estimator , X_train , X_test , y_train , y_test = regression_data
618+ report = ComparisonReport (
619+ reports = {
620+ "estimator 1" : EstimatorReport (
621+ estimator ,
622+ X_train = X_train ,
623+ y_train = y_train ,
624+ X_test = X_test ,
625+ y_test = y_test ,
626+ ),
627+ "estimator 2" : EstimatorReport (
628+ estimator ,
629+ X_train = X_train ,
630+ y_train = y_train ,
631+ X_test = X_test ,
632+ y_test = y_test ,
633+ ),
634+ "estimator 3" : EstimatorReport (
635+ estimator ,
636+ X_train = X_train ,
637+ y_train = y_train ,
638+ X_test = X_test ,
639+ y_test = y_test ,
640+ ),
641+ },
642+ )
643+ display = report .metrics .prediction_error ()
644+
645+ figsize = (10 , 8 )
646+ display .plot (subplots = True , nrows = 3 , ncols = 1 , figsize = figsize )
647+
648+ assert hasattr (display , "figure_" )
649+ assert display .figure_ .get_size_inches ()[0 ] == figsize [0 ]
650+ assert display .figure_ .get_size_inches ()[1 ] == figsize [1 ]
651+
652+ axes = display .figure_ .get_axes ()
653+ assert len (axes ) == 3
654+
655+ pos1 = axes [0 ].get_position ()
656+ pos2 = axes [1 ].get_position ()
657+ pos3 = axes [2 ].get_position ()
658+
659+ # Same column (similar x positions)
660+ assert abs (pos1 .x0 - pos2 .x0 ) < 0.1
661+ assert abs (pos2 .x0 - pos3 .x0 ) < 0.1
662+
663+ # Different rows (decreasing y positions)
664+ assert pos1 .y0 > pos2 .y0
665+ assert pos2 .y0 > pos3 .y0
666+
667+
668+ def test_prediction_error_display_subplots_cross_validation (
669+ pyplot , regression_data_no_split
670+ ):
671+ """Test subplots with cross-validation data."""
672+ (estimator , X , y ), cv = regression_data_no_split , 3
673+ report = CrossValidationReport (estimator , X = X , y = y , cv_splitter = cv )
674+ display = report .metrics .prediction_error ()
675+ display .plot (subplots = True )
676+
677+ assert hasattr (display , "figure_" )
678+
679+ # Check number of subplots matches number of CV folds
680+ axes = display .figure_ .get_axes ()
681+ assert len (axes ) == cv
682+
683+ # Check titles for each fold
684+ for i , ax in enumerate (axes ):
685+ assert f"Fold #{ i + 1 } " in ax .get_title ()
686+
687+
688+ def test_prediction_error_display_ax_and_subplots_error (pyplot , regression_data ):
689+ """Test that an error is raised when both ax and subplots=True are specified."""
690+ estimator , X_train , X_test , y_train , y_test = regression_data
691+ report = EstimatorReport (
692+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
693+ )
694+ display = report .metrics .prediction_error ()
695+
696+ # Create a figure and axis to pass
697+ fig , ax = pyplot .subplots ()
698+
699+ # Test that error is raised when both ax and subplots=True are specified
700+ with pytest .raises (
701+ ValueError , match = "Cannot specify both 'ax' and 'subplots=True'"
702+ ):
703+ display .plot (ax = ax , subplots = True )
704+
705+
706+ def test_prediction_error_display_subplots_estimator_report (pyplot , regression_data ):
707+ """Test subplots with simple estimator report (should be a single plot)."""
708+ estimator , X_train , X_test , y_train , y_test = regression_data
709+ report = EstimatorReport (
710+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
711+ )
712+ display = report .metrics .prediction_error ()
713+ display .plot (subplots = True )
714+
715+ # For a single estimator, we should get a single plot
716+ assert hasattr (display , "figure_" )
717+ axes = display .figure_ .get_axes ()
718+ assert len (axes ) == 1
719+ assert "Model: LinearRegression" in axes [0 ].get_title ()
0 commit comments