|
1 | 1 | import pandas as pd |
2 | 2 | import pytest |
3 | 3 | from pandas.testing import assert_index_equal |
4 | | -from sklearn.datasets import make_classification |
5 | | -from sklearn.dummy import DummyClassifier |
| 4 | +from sklearn.datasets import make_classification, make_regression |
| 5 | +from sklearn.dummy import DummyClassifier, DummyRegressor |
6 | 6 | from skore import CrossValidationComparisonReport, CrossValidationReport |
7 | 7 |
|
8 | 8 |
|
@@ -162,12 +162,47 @@ def case_aggregate_different_split_numbers(): |
162 | 162 | return report, kwargs, expected_index, expected_columns |
163 | 163 |
|
164 | 164 |
|
| 165 | +def comparison_report_regression(): |
| 166 | + X, y = make_regression(random_state=42) |
| 167 | + |
| 168 | + report = CrossValidationComparisonReport( |
| 169 | + [ |
| 170 | + CrossValidationReport(DummyRegressor(), X, y), |
| 171 | + CrossValidationReport(DummyRegressor(), X, y, cv_splitter=3), |
| 172 | + ] |
| 173 | + ) |
| 174 | + |
| 175 | + return report |
| 176 | + |
| 177 | + |
| 178 | +def case_regression(): |
| 179 | + kwargs = {} |
| 180 | + |
| 181 | + expected_columns = pd.Index(["mean", "std"]) |
| 182 | + expected_index = pd.MultiIndex.from_tuples( |
| 183 | + [ |
| 184 | + ("R²", "DummyRegressor_1"), |
| 185 | + ("RMSE", "DummyRegressor_1"), |
| 186 | + ("Fit time", "DummyRegressor_1"), |
| 187 | + ("Predict time", "DummyRegressor_1"), |
| 188 | + ("R²", "DummyRegressor_2"), |
| 189 | + ("RMSE", "DummyRegressor_2"), |
| 190 | + ("Fit time", "DummyRegressor_2"), |
| 191 | + ("Predict time", "DummyRegressor_2"), |
| 192 | + ], |
| 193 | + names=["Metric", "Estimator"], |
| 194 | + ) |
| 195 | + |
| 196 | + return comparison_report_regression(), kwargs, expected_index, expected_columns |
| 197 | + |
| 198 | + |
165 | 199 | @pytest.mark.parametrize( |
166 | 200 | "case", |
167 | 201 | [ |
168 | 202 | case_different_split_numbers, |
169 | 203 | case_flat_index_different_split_numbers, |
170 | 204 | case_aggregate_different_split_numbers, |
| 205 | + case_regression, |
171 | 206 | ], |
172 | 207 | ) |
173 | 208 | def test_report_metrics(case): |
|
0 commit comments