Skip to content

Commit d547033

Browse files
committed
test: add cases for different subplots functionality for PrecisionRecallCurveDisplay
1 parent 696dfdb commit d547033

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

skore/tests/unit/sklearn/plot/test_precision_recall_curve.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,3 +631,228 @@ def test_precision_recall_curve_display_wrong_report_type(
631631
)
632632
with pytest.raises(ValueError, match=err_msg):
633633
display.plot()
634+
635+
636+
def test_precision_recall_curve_display_subplots_basic_binary(
637+
pyplot, binary_classification_data
638+
):
639+
"""Test that subplots=True creates multiple subplots with default parameters
640+
for binary classification."""
641+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
642+
643+
# Create a comparison report with multiple estimators
644+
est1 = clone(estimator)
645+
est2 = clone(estimator)
646+
est1.fit(X_train, y_train)
647+
est2.fit(X_train, y_train)
648+
649+
report = ComparisonReport(
650+
reports={
651+
"estimator 1": EstimatorReport(
652+
est1,
653+
X_train=X_train,
654+
y_train=y_train,
655+
X_test=X_test,
656+
y_test=y_test,
657+
),
658+
"estimator 2": EstimatorReport(
659+
est2,
660+
X_train=X_train,
661+
y_train=y_train,
662+
X_test=X_test,
663+
y_test=y_test,
664+
),
665+
},
666+
)
667+
display = report.metrics.precision_recall()
668+
display.plot(subplots=True)
669+
670+
assert hasattr(display, "figure_")
671+
672+
# Check correct number of subplots
673+
axes = display.figure_.get_axes()
674+
assert len(axes) == 2
675+
676+
# Check titles were set correctly
677+
assert "Model: estimator 1" in axes[0].get_title()
678+
assert "Model: estimator 2" in axes[1].get_title()
679+
680+
# Each subplot should have correct labels
681+
for ax in axes:
682+
assert "Recall" in ax.get_xlabel()
683+
assert "Precision" in ax.get_ylabel()
684+
assert ax.get_aspect() in ("equal", 1.0)
685+
686+
687+
def test_precision_recall_curve_display_subplots_basic_multiclass(
688+
pyplot, multiclass_classification_data
689+
):
690+
"""Test that subplots=True creates multiple subplots with default parameters
691+
for multiclass classification."""
692+
estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
693+
report = EstimatorReport(
694+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
695+
)
696+
697+
# In multiclass case, we get one subplot per class
698+
display = report.metrics.precision_recall()
699+
display.plot(subplots=True)
700+
701+
assert hasattr(display, "figure_")
702+
703+
# Check correct number of subplots (one per class)
704+
axes = display.figure_.get_axes()
705+
assert len(axes) == len(estimator.classes_)
706+
707+
# Check titles were set correctly
708+
for i, class_label in enumerate(estimator.classes_):
709+
assert f"Class: {class_label}" in axes[i].get_title()
710+
711+
# Each subplot should have correct labels
712+
for ax in axes:
713+
assert "Recall" in ax.get_xlabel()
714+
assert "Precision" in ax.get_ylabel()
715+
assert ax.get_aspect() in ("equal", 1.0)
716+
717+
718+
def test_precision_recall_curve_display_subplots_cv_binary(
719+
pyplot, binary_classification_data_no_split
720+
):
721+
"""Test subplots with cross-validation for binary classification."""
722+
(estimator, X, y), cv = binary_classification_data_no_split, 3
723+
report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv)
724+
display = report.metrics.precision_recall()
725+
display.plot(subplots=True)
726+
727+
assert hasattr(display, "figure_")
728+
729+
# Check number of subplots matches number of CV folds
730+
axes = display.figure_.get_axes()
731+
assert len(axes) == cv
732+
733+
# Check titles for each fold
734+
for i, ax in enumerate(axes):
735+
assert f"Fold #{i + 1}" in ax.get_title()
736+
737+
# Each subplot should have correct labels
738+
for ax in axes:
739+
assert "Recall" in ax.get_xlabel()
740+
assert "Precision" in ax.get_ylabel()
741+
assert ax.get_aspect() in ("equal", 1.0)
742+
743+
744+
def test_precision_recall_curve_display_subplots_custom_layout(
745+
pyplot, binary_classification_data
746+
):
747+
"""Test subplots with custom layout parameters."""
748+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
749+
750+
# Create a comparison report with multiple estimators
751+
est1 = clone(estimator)
752+
est2 = clone(estimator)
753+
est3 = clone(estimator)
754+
est4 = clone(estimator)
755+
est1.fit(X_train, y_train)
756+
est2.fit(X_train, y_train)
757+
est3.fit(X_train, y_train)
758+
est4.fit(X_train, y_train)
759+
760+
report = ComparisonReport(
761+
reports={
762+
"estimator 1": EstimatorReport(
763+
est1,
764+
X_train=X_train,
765+
y_train=y_train,
766+
X_test=X_test,
767+
y_test=y_test,
768+
),
769+
"estimator 2": EstimatorReport(
770+
est2,
771+
X_train=X_train,
772+
y_train=y_train,
773+
X_test=X_test,
774+
y_test=y_test,
775+
),
776+
"estimator 3": EstimatorReport(
777+
est3,
778+
X_train=X_train,
779+
y_train=y_train,
780+
X_test=X_test,
781+
y_test=y_test,
782+
),
783+
"estimator 4": EstimatorReport(
784+
est4,
785+
X_train=X_train,
786+
y_train=y_train,
787+
X_test=X_test,
788+
y_test=y_test,
789+
),
790+
},
791+
)
792+
display = report.metrics.precision_recall()
793+
794+
# Test with custom nrows and ncols
795+
figsize = (12, 10)
796+
display.plot(subplots=True, nrows=2, ncols=2, figsize=figsize)
797+
798+
# Check figure was created with correct size
799+
assert hasattr(display, "figure_")
800+
assert display.figure_.get_size_inches()[0] == figsize[0]
801+
assert display.figure_.get_size_inches()[1] == figsize[1]
802+
803+
# Check layout is correct
804+
axes = display.figure_.get_axes()
805+
assert len(axes) == 4
806+
807+
# Check subplot arrangement (2 rows, 2 columns)
808+
pos1 = axes[0].get_position()
809+
pos2 = axes[1].get_position()
810+
pos3 = axes[2].get_position()
811+
pos4 = axes[3].get_position()
812+
813+
# First row: similar y positions for axes 0 and 1
814+
assert abs(pos1.y0 - pos2.y0) < 0.1
815+
# Second row: similar y positions for axes 2 and 3
816+
assert abs(pos3.y0 - pos4.y0) < 0.1
817+
# First column: similar x positions for axes 0 and 2
818+
assert abs(pos1.x0 - pos3.x0) < 0.1
819+
# Second column: similar x positions for axes 1 and 3
820+
assert abs(pos2.x0 - pos4.x0) < 0.1
821+
822+
823+
def test_precision_recall_curve_display_ax_and_subplots_error(
824+
pyplot, binary_classification_data
825+
):
826+
"""Test that an error is raised when both ax and subplots=True are specified."""
827+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
828+
report = EstimatorReport(
829+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
830+
)
831+
display = report.metrics.precision_recall()
832+
833+
# Create a figure and axis to pass
834+
fig, ax = pyplot.subplots()
835+
836+
# Test that error is raised when both ax and subplots=True are specified
837+
with pytest.raises(
838+
ValueError, match="Cannot specify both 'ax' and 'subplots=True'"
839+
):
840+
display.plot(ax=ax, subplots=True)
841+
842+
843+
def test_precision_recall_curve_display_subplots_estimator_report(
844+
pyplot, binary_classification_data
845+
):
846+
"""Test subplots with simple estimator report (should be a single plot)."""
847+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
848+
report = EstimatorReport(
849+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
850+
)
851+
display = report.metrics.precision_recall()
852+
display.plot(subplots=True)
853+
854+
# For a single estimator, we should get a single plot
855+
assert hasattr(display, "figure_")
856+
axes = display.figure_.get_axes()
857+
assert len(axes) == 1
858+
assert "Model: LogisticRegression" in axes[0].get_title()

0 commit comments

Comments
 (0)