@@ -631,3 +631,228 @@ def test_precision_recall_curve_display_wrong_report_type(
631631 )
632632 with pytest .raises (ValueError , match = err_msg ):
633633 display .plot ()
634+
635+
636+ def test_precision_recall_curve_display_subplots_basic_binary (
637+ pyplot , binary_classification_data
638+ ):
639+ """Test that subplots=True creates multiple subplots with default parameters
640+ for binary classification."""
641+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
642+
643+ # Create a comparison report with multiple estimators
644+ est1 = clone (estimator )
645+ est2 = clone (estimator )
646+ est1 .fit (X_train , y_train )
647+ est2 .fit (X_train , y_train )
648+
649+ report = ComparisonReport (
650+ reports = {
651+ "estimator 1" : EstimatorReport (
652+ est1 ,
653+ X_train = X_train ,
654+ y_train = y_train ,
655+ X_test = X_test ,
656+ y_test = y_test ,
657+ ),
658+ "estimator 2" : EstimatorReport (
659+ est2 ,
660+ X_train = X_train ,
661+ y_train = y_train ,
662+ X_test = X_test ,
663+ y_test = y_test ,
664+ ),
665+ },
666+ )
667+ display = report .metrics .precision_recall ()
668+ display .plot (subplots = True )
669+
670+ assert hasattr (display , "figure_" )
671+
672+ # Check correct number of subplots
673+ axes = display .figure_ .get_axes ()
674+ assert len (axes ) == 2
675+
676+ # Check titles were set correctly
677+ assert "Model: estimator 1" in axes [0 ].get_title ()
678+ assert "Model: estimator 2" in axes [1 ].get_title ()
679+
680+ # Each subplot should have correct labels
681+ for ax in axes :
682+ assert "Recall" in ax .get_xlabel ()
683+ assert "Precision" in ax .get_ylabel ()
684+ assert ax .get_aspect () in ("equal" , 1.0 )
685+
686+
687+ def test_precision_recall_curve_display_subplots_basic_multiclass (
688+ pyplot , multiclass_classification_data
689+ ):
690+ """Test that subplots=True creates multiple subplots with default parameters
691+ for multiclass classification."""
692+ estimator , X_train , X_test , y_train , y_test = multiclass_classification_data
693+ report = EstimatorReport (
694+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
695+ )
696+
697+ # In multiclass case, we get one subplot per class
698+ display = report .metrics .precision_recall ()
699+ display .plot (subplots = True )
700+
701+ assert hasattr (display , "figure_" )
702+
703+ # Check correct number of subplots (one per class)
704+ axes = display .figure_ .get_axes ()
705+ assert len (axes ) == len (estimator .classes_ )
706+
707+ # Check titles were set correctly
708+ for i , class_label in enumerate (estimator .classes_ ):
709+ assert f"Class: { class_label } " in axes [i ].get_title ()
710+
711+ # Each subplot should have correct labels
712+ for ax in axes :
713+ assert "Recall" in ax .get_xlabel ()
714+ assert "Precision" in ax .get_ylabel ()
715+ assert ax .get_aspect () in ("equal" , 1.0 )
716+
717+
718+ def test_precision_recall_curve_display_subplots_cv_binary (
719+ pyplot , binary_classification_data_no_split
720+ ):
721+ """Test subplots with cross-validation for binary classification."""
722+ (estimator , X , y ), cv = binary_classification_data_no_split , 3
723+ report = CrossValidationReport (estimator , X = X , y = y , cv_splitter = cv )
724+ display = report .metrics .precision_recall ()
725+ display .plot (subplots = True )
726+
727+ assert hasattr (display , "figure_" )
728+
729+ # Check number of subplots matches number of CV folds
730+ axes = display .figure_ .get_axes ()
731+ assert len (axes ) == cv
732+
733+ # Check titles for each fold
734+ for i , ax in enumerate (axes ):
735+ assert f"Fold #{ i + 1 } " in ax .get_title ()
736+
737+ # Each subplot should have correct labels
738+ for ax in axes :
739+ assert "Recall" in ax .get_xlabel ()
740+ assert "Precision" in ax .get_ylabel ()
741+ assert ax .get_aspect () in ("equal" , 1.0 )
742+
743+
744+ def test_precision_recall_curve_display_subplots_custom_layout (
745+ pyplot , binary_classification_data
746+ ):
747+ """Test subplots with custom layout parameters."""
748+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
749+
750+ # Create a comparison report with multiple estimators
751+ est1 = clone (estimator )
752+ est2 = clone (estimator )
753+ est3 = clone (estimator )
754+ est4 = clone (estimator )
755+ est1 .fit (X_train , y_train )
756+ est2 .fit (X_train , y_train )
757+ est3 .fit (X_train , y_train )
758+ est4 .fit (X_train , y_train )
759+
760+ report = ComparisonReport (
761+ reports = {
762+ "estimator 1" : EstimatorReport (
763+ est1 ,
764+ X_train = X_train ,
765+ y_train = y_train ,
766+ X_test = X_test ,
767+ y_test = y_test ,
768+ ),
769+ "estimator 2" : EstimatorReport (
770+ est2 ,
771+ X_train = X_train ,
772+ y_train = y_train ,
773+ X_test = X_test ,
774+ y_test = y_test ,
775+ ),
776+ "estimator 3" : EstimatorReport (
777+ est3 ,
778+ X_train = X_train ,
779+ y_train = y_train ,
780+ X_test = X_test ,
781+ y_test = y_test ,
782+ ),
783+ "estimator 4" : EstimatorReport (
784+ est4 ,
785+ X_train = X_train ,
786+ y_train = y_train ,
787+ X_test = X_test ,
788+ y_test = y_test ,
789+ ),
790+ },
791+ )
792+ display = report .metrics .precision_recall ()
793+
794+ # Test with custom nrows and ncols
795+ figsize = (12 , 10 )
796+ display .plot (subplots = True , nrows = 2 , ncols = 2 , figsize = figsize )
797+
798+ # Check figure was created with correct size
799+ assert hasattr (display , "figure_" )
800+ assert display .figure_ .get_size_inches ()[0 ] == figsize [0 ]
801+ assert display .figure_ .get_size_inches ()[1 ] == figsize [1 ]
802+
803+ # Check layout is correct
804+ axes = display .figure_ .get_axes ()
805+ assert len (axes ) == 4
806+
807+ # Check subplot arrangement (2 rows, 2 columns)
808+ pos1 = axes [0 ].get_position ()
809+ pos2 = axes [1 ].get_position ()
810+ pos3 = axes [2 ].get_position ()
811+ pos4 = axes [3 ].get_position ()
812+
813+ # First row: similar y positions for axes 0 and 1
814+ assert abs (pos1 .y0 - pos2 .y0 ) < 0.1
815+ # Second row: similar y positions for axes 2 and 3
816+ assert abs (pos3 .y0 - pos4 .y0 ) < 0.1
817+ # First column: similar x positions for axes 0 and 2
818+ assert abs (pos1 .x0 - pos3 .x0 ) < 0.1
819+ # Second column: similar x positions for axes 1 and 3
820+ assert abs (pos2 .x0 - pos4 .x0 ) < 0.1
821+
822+
823+ def test_precision_recall_curve_display_ax_and_subplots_error (
824+ pyplot , binary_classification_data
825+ ):
826+ """Test that an error is raised when both ax and subplots=True are specified."""
827+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
828+ report = EstimatorReport (
829+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
830+ )
831+ display = report .metrics .precision_recall ()
832+
833+ # Create a figure and axis to pass
834+ fig , ax = pyplot .subplots ()
835+
836+ # Test that error is raised when both ax and subplots=True are specified
837+ with pytest .raises (
838+ ValueError , match = "Cannot specify both 'ax' and 'subplots=True'"
839+ ):
840+ display .plot (ax = ax , subplots = True )
841+
842+
843+ def test_precision_recall_curve_display_subplots_estimator_report (
844+ pyplot , binary_classification_data
845+ ):
846+ """Test subplots with simple estimator report (should be a single plot)."""
847+ estimator , X_train , X_test , y_train , y_test = binary_classification_data
848+ report = EstimatorReport (
849+ estimator , X_train = X_train , y_train = y_train , X_test = X_test , y_test = y_test
850+ )
851+ display = report .metrics .precision_recall ()
852+ display .plot (subplots = True )
853+
854+ # For a single estimator, we should get a single plot
855+ assert hasattr (display , "figure_" )
856+ axes = display .figure_ .get_axes ()
857+ assert len (axes ) == 1
858+ assert "Model: LogisticRegression" in axes [0 ].get_title ()
0 commit comments