Skip to content

Commit ee1be03

Browse files
add test for regression
1 parent a48295a commit ee1be03

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

skore/tests/unit/sklearn/comparison/test_compare_cross_validation_reports.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import pandas as pd
22
import pytest
33
from pandas.testing import assert_index_equal
4-
from sklearn.datasets import make_classification
5-
from sklearn.dummy import DummyClassifier
4+
from sklearn.datasets import make_classification, make_regression
5+
from sklearn.dummy import DummyClassifier, DummyRegressor
66
from skore import CrossValidationComparisonReport, CrossValidationReport
77

88

@@ -162,12 +162,47 @@ def case_aggregate_different_split_numbers():
162162
return report, kwargs, expected_index, expected_columns
163163

164164

165+
def comparison_report_regression():
166+
X, y = make_regression(random_state=42)
167+
168+
report = CrossValidationComparisonReport(
169+
[
170+
CrossValidationReport(DummyRegressor(), X, y),
171+
CrossValidationReport(DummyRegressor(), X, y, cv_splitter=3),
172+
]
173+
)
174+
175+
return report
176+
177+
178+
def case_regression():
179+
kwargs = {}
180+
181+
expected_columns = pd.Index(["mean", "std"])
182+
expected_index = pd.MultiIndex.from_tuples(
183+
[
184+
("R²", "DummyRegressor_1"),
185+
("RMSE", "DummyRegressor_1"),
186+
("Fit time", "DummyRegressor_1"),
187+
("Predict time", "DummyRegressor_1"),
188+
("R²", "DummyRegressor_2"),
189+
("RMSE", "DummyRegressor_2"),
190+
("Fit time", "DummyRegressor_2"),
191+
("Predict time", "DummyRegressor_2"),
192+
],
193+
names=["Metric", "Estimator"],
194+
)
195+
196+
return comparison_report_regression(), kwargs, expected_index, expected_columns
197+
198+
165199
@pytest.mark.parametrize(
166200
"case",
167201
[
168202
case_different_split_numbers,
169203
case_flat_index_different_split_numbers,
170204
case_aggregate_different_split_numbers,
205+
case_regression,
171206
],
172207
)
173208
def test_report_metrics(case):

0 commit comments

Comments
 (0)