Skip to content

Commit 696dfdb

Browse files
committed
test: add ROC curve subplot test
1 parent 4941dab commit 696dfdb

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed

skore/tests/unit/sklearn/plot/test_roc_curve.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,190 @@ def test_roc_curve_display_wrong_report_type(pyplot, binary_classification_data)
671671
)
672672
with pytest.raises(ValueError, match=err_msg):
673673
display.plot()
674+
675+
676+
def test_roc_curve_display_subplots_basic_binary(pyplot, binary_classification_data):
677+
"""Test that subplots=True creates multiple subplots with default parameters
678+
for binary classification."""
679+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
680+
681+
# Create a comparison report with multiple estimators
682+
est1 = clone(estimator)
683+
est2 = clone(estimator)
684+
est1.fit(X_train, y_train)
685+
est2.fit(X_train, y_train)
686+
687+
report = ComparisonReport(
688+
reports={
689+
"estimator 1": EstimatorReport(
690+
est1,
691+
X_train=X_train,
692+
y_train=y_train,
693+
X_test=X_test,
694+
y_test=y_test,
695+
),
696+
"estimator 2": EstimatorReport(
697+
est2,
698+
X_train=X_train,
699+
y_train=y_train,
700+
X_test=X_test,
701+
y_test=y_test,
702+
),
703+
},
704+
)
705+
display = report.metrics.roc()
706+
display.plot(subplots=True)
707+
708+
assert hasattr(display, "figure_")
709+
710+
axes = display.figure_.get_axes()
711+
assert len(axes) == 2
712+
713+
assert "Model: estimator 1" in axes[0].get_title()
714+
assert "Model: estimator 2" in axes[1].get_title()
715+
716+
# Each subplot should have correct labels
717+
for ax in axes:
718+
assert "False Positive Rate" in ax.get_xlabel()
719+
assert "True Positive Rate" in ax.get_ylabel()
720+
assert ax.get_aspect() in ("equal", 1.0)
721+
722+
723+
def test_roc_curve_display_subplots_basic_multiclass(
724+
pyplot, multiclass_classification_data
725+
):
726+
"""Test that subplots=True creates multiple subplots with default parameters
727+
for multiclass classification."""
728+
estimator, X_train, X_test, y_train, y_test = multiclass_classification_data
729+
report = EstimatorReport(
730+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
731+
)
732+
733+
# In multiclass case, we should get one subplot per class
734+
display = report.metrics.roc()
735+
display.plot(subplots=True)
736+
737+
assert hasattr(display, "figure_")
738+
739+
# Check correct number of subplots (one per class)
740+
axes = display.figure_.get_axes()
741+
assert len(axes) == len(estimator.classes_)
742+
743+
for i, class_label in enumerate(estimator.classes_):
744+
assert f"Class: {class_label}" in axes[i].get_title()
745+
746+
# Each subplot should have correct labels
747+
for ax in axes:
748+
assert "False Positive Rate" in ax.get_xlabel()
749+
assert "True Positive Rate" in ax.get_ylabel()
750+
assert ax.get_aspect() in ("equal", 1.0)
751+
752+
753+
def test_roc_curve_display_subplots_cv_binary(
754+
pyplot, binary_classification_data_no_split
755+
):
756+
"""Test subplots with cross-validation for binary classification."""
757+
(estimator, X, y), cv = binary_classification_data_no_split, 3
758+
report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv)
759+
display = report.metrics.roc()
760+
display.plot(subplots=True)
761+
762+
assert hasattr(display, "figure_")
763+
764+
# Check number of subplots matches number of CV folds
765+
axes = display.figure_.get_axes()
766+
assert len(axes) == cv
767+
768+
# Check titles for each fold
769+
for i, ax in enumerate(axes):
770+
assert f"Fold #{i + 1}" in ax.get_title()
771+
772+
# Each subplot should have correct labels
773+
for ax in axes:
774+
assert "False Positive Rate" in ax.get_xlabel()
775+
assert "True Positive Rate" in ax.get_ylabel()
776+
assert ax.get_aspect() in ("equal", 1.0)
777+
778+
779+
def test_roc_curve_display_subplots_custom_layout(pyplot, binary_classification_data):
780+
"""Test subplots with custom layout parameters."""
781+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
782+
783+
# Create a comparison report with multiple estimators
784+
est1 = clone(estimator)
785+
est2 = clone(estimator)
786+
est3 = clone(estimator)
787+
est1.fit(X_train, y_train)
788+
est2.fit(X_train, y_train)
789+
est3.fit(X_train, y_train)
790+
791+
report = ComparisonReport(
792+
reports={
793+
"estimator 1": EstimatorReport(
794+
est1,
795+
X_train=X_train,
796+
y_train=y_train,
797+
X_test=X_test,
798+
y_test=y_test,
799+
),
800+
"estimator 2": EstimatorReport(
801+
est2,
802+
X_train=X_train,
803+
y_train=y_train,
804+
X_test=X_test,
805+
y_test=y_test,
806+
),
807+
"estimator 3": EstimatorReport(
808+
est3,
809+
X_train=X_train,
810+
y_train=y_train,
811+
X_test=X_test,
812+
y_test=y_test,
813+
),
814+
},
815+
)
816+
display = report.metrics.roc()
817+
818+
# Test with custom nrows and ncols
819+
figsize = (10, 8)
820+
display.plot(subplots=True, nrows=1, ncols=3, figsize=figsize)
821+
822+
# Check figure was created with correct size
823+
assert hasattr(display, "figure_")
824+
assert display.figure_.get_size_inches()[0] == figsize[0]
825+
assert display.figure_.get_size_inches()[1] == figsize[1]
826+
827+
# Check layout is correct
828+
axes = display.figure_.get_axes()
829+
assert len(axes) == 3
830+
831+
# Check subplot arrangement (1 row, 3 columns)
832+
pos1 = axes[0].get_position()
833+
pos2 = axes[1].get_position()
834+
pos3 = axes[2].get_position()
835+
836+
# Same row (similar y positions)
837+
assert abs(pos1.y0 - pos2.y0) < 0.1
838+
assert abs(pos2.y0 - pos3.y0) < 0.1
839+
840+
# Different columns (increasing x positions)
841+
assert pos1.x0 < pos2.x0
842+
assert pos2.x0 < pos3.x0
843+
844+
845+
def test_roc_curve_display_ax_and_subplots_error(pyplot, binary_classification_data):
846+
"""Test that an error is raised when both ax and subplots=True are specified."""
847+
estimator, X_train, X_test, y_train, y_test = binary_classification_data
848+
report = EstimatorReport(
849+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
850+
)
851+
display = report.metrics.roc()
852+
853+
# Create a figure and axis to pass
854+
fig, ax = pyplot.subplots()
855+
856+
# Test that error is raised when both ax and subplots=True are specified
857+
with pytest.raises(
858+
ValueError, match="Cannot specify both 'ax' and 'subplots=True'"
859+
):
860+
display.plot(ax=ax, subplots=True)

0 commit comments

Comments
 (0)