Skip to content

Commit 97f7295

Browse files
committed
test: add cases for different subplots functionality for PredictionErrorDisplay
1 parent d547033 commit 97f7295

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

skore/tests/unit/sklearn/plot/test_prediction_error.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,143 @@ def test_prediction_error_display_wrong_report_type(pyplot, regression_data):
577577
)
578578
with pytest.raises(ValueError, match=err_msg):
579579
display.plot()
580+
581+
582+
def test_prediction_error_display_subplots_basic(pyplot, regression_data):
583+
"""Test that subplots=True creates multiple subplots with default parameters."""
584+
estimator, X_train, X_test, y_train, y_test = regression_data
585+
report = ComparisonReport(
586+
reports={
587+
"estimator 1": EstimatorReport(
588+
estimator,
589+
X_train=X_train,
590+
y_train=y_train,
591+
X_test=X_test,
592+
y_test=y_test,
593+
),
594+
"estimator 2": EstimatorReport(
595+
estimator,
596+
X_train=X_train,
597+
y_train=y_train,
598+
X_test=X_test,
599+
y_test=y_test,
600+
),
601+
},
602+
)
603+
display = report.metrics.prediction_error()
604+
display.plot(subplots=True)
605+
606+
assert hasattr(display, "figure_")
607+
assert len(display.scatter_) == 2
608+
609+
axes = display.figure_.get_axes()
610+
assert len(axes) == 2
611+
assert "Model: estimator 1" in axes[0].get_title()
612+
assert "Model: estimator 2" in axes[1].get_title()
613+
614+
615+
def test_prediction_error_display_subplots_custom_layout(pyplot, regression_data):
616+
"""Test subplots with custom layout parameters."""
617+
estimator, X_train, X_test, y_train, y_test = regression_data
618+
report = ComparisonReport(
619+
reports={
620+
"estimator 1": EstimatorReport(
621+
estimator,
622+
X_train=X_train,
623+
y_train=y_train,
624+
X_test=X_test,
625+
y_test=y_test,
626+
),
627+
"estimator 2": EstimatorReport(
628+
estimator,
629+
X_train=X_train,
630+
y_train=y_train,
631+
X_test=X_test,
632+
y_test=y_test,
633+
),
634+
"estimator 3": EstimatorReport(
635+
estimator,
636+
X_train=X_train,
637+
y_train=y_train,
638+
X_test=X_test,
639+
y_test=y_test,
640+
),
641+
},
642+
)
643+
display = report.metrics.prediction_error()
644+
645+
figsize = (10, 8)
646+
display.plot(subplots=True, nrows=3, ncols=1, figsize=figsize)
647+
648+
assert hasattr(display, "figure_")
649+
assert display.figure_.get_size_inches()[0] == figsize[0]
650+
assert display.figure_.get_size_inches()[1] == figsize[1]
651+
652+
axes = display.figure_.get_axes()
653+
assert len(axes) == 3
654+
655+
pos1 = axes[0].get_position()
656+
pos2 = axes[1].get_position()
657+
pos3 = axes[2].get_position()
658+
659+
# Same column (similar x positions)
660+
assert abs(pos1.x0 - pos2.x0) < 0.1
661+
assert abs(pos2.x0 - pos3.x0) < 0.1
662+
663+
# Different rows (decreasing y positions)
664+
assert pos1.y0 > pos2.y0
665+
assert pos2.y0 > pos3.y0
666+
667+
668+
def test_prediction_error_display_subplots_cross_validation(
669+
pyplot, regression_data_no_split
670+
):
671+
"""Test subplots with cross-validation data."""
672+
(estimator, X, y), cv = regression_data_no_split, 3
673+
report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=cv)
674+
display = report.metrics.prediction_error()
675+
display.plot(subplots=True)
676+
677+
assert hasattr(display, "figure_")
678+
679+
# Check number of subplots matches number of CV folds
680+
axes = display.figure_.get_axes()
681+
assert len(axes) == cv
682+
683+
# Check titles for each fold
684+
for i, ax in enumerate(axes):
685+
assert f"Fold #{i + 1}" in ax.get_title()
686+
687+
688+
def test_prediction_error_display_ax_and_subplots_error(pyplot, regression_data):
689+
"""Test that an error is raised when both ax and subplots=True are specified."""
690+
estimator, X_train, X_test, y_train, y_test = regression_data
691+
report = EstimatorReport(
692+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
693+
)
694+
display = report.metrics.prediction_error()
695+
696+
# Create a figure and axis to pass
697+
fig, ax = pyplot.subplots()
698+
699+
# Test that error is raised when both ax and subplots=True are specified
700+
with pytest.raises(
701+
ValueError, match="Cannot specify both 'ax' and 'subplots=True'"
702+
):
703+
display.plot(ax=ax, subplots=True)
704+
705+
706+
def test_prediction_error_display_subplots_estimator_report(pyplot, regression_data):
707+
"""Test subplots with simple estimator report (should be a single plot)."""
708+
estimator, X_train, X_test, y_train, y_test = regression_data
709+
report = EstimatorReport(
710+
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
711+
)
712+
display = report.metrics.prediction_error()
713+
display.plot(subplots=True)
714+
715+
# For a single estimator, we should get a single plot
716+
assert hasattr(display, "figure_")
717+
axes = display.figure_.get_axes()
718+
assert len(axes) == 1
719+
assert "Model: LinearRegression" in axes[0].get_title()

0 commit comments

Comments
 (0)