|
5 | 5 | from sklearn import datasets, linear_model |
6 | 6 | from sklearn.cluster import KMeans |
7 | 7 | from sklearn.ensemble import RandomForestClassifier |
| 8 | +from sklearn.linear_model import LinearRegression, LogisticRegression |
| 9 | +from sklearn.model_selection import cross_validate |
8 | 10 | from sklearn.multiclass import OneVsOneClassifier |
9 | 11 | from sklearn.svm import SVC |
10 | 12 | from skore import CrossValidationReporter |
11 | 13 | from skore.item.cross_validation_item import CrossValidationItem |
| 14 | +from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add |
12 | 15 |
|
13 | 16 |
|
14 | 17 | @pytest.fixture |
@@ -155,23 +158,103 @@ def test_cross_validate(self, in_memory_project, get_args, fake_cross_validate): |
155 | 158 | assert all(len(v) == 5 for v in cv_results.values()) |
156 | 159 |
|
157 | 160 |
|
@pytest.fixture
def binary_classifier():
    """Return a (model, X, y) triple for a 2-class classification task."""
    X, y = datasets.make_classification(n_classes=2, random_state=42)
    classifier = LogisticRegression()
    return classifier, X, y
160 | 166 |
|
@pytest.fixture
def multiclass_classifier():
    """Return a (model, X, y) triple for a 3-class classification task."""
    # One cluster per class keeps the synthetic problem well separated.
    X, y = datasets.make_classification(
        n_classes=3, n_clusters_per_class=1, random_state=42
    )
    classifier = LogisticRegression()
    return classifier, X, y
|
@pytest.fixture
def single_output_regression():
    """Return a (model, X, y) triple for a single-target regression task."""
    X, y = datasets.make_regression(n_targets=1, random_state=42)
    regressor = LinearRegression()
    return regressor, X, y
|
@pytest.fixture
def multi_output_regression():
    """Return a (model, X, y) triple for a two-target regression task."""
    X, y = datasets.make_regression(n_targets=2, random_state=42)
    regressor = LinearRegression()
    return regressor, X, y
| 186 | + |
@pytest.mark.parametrize(
    "fixture_name",
    [
        "binary_classifier",
        "multiclass_classifier",
        "single_output_regression",
        "multi_output_regression",
    ],
)
def test_cross_validation_reporter(in_memory_project, fixture_name, request):
    """A `CrossValidationReporter` round-trips through the project store.

    Putting a reporter must persist it so that retrieval yields a
    `CrossValidationItem`.
    """
    estimator, X, y = request.getfixturevalue(fixture_name)
    reporter = CrossValidationReporter(estimator, X, y, cv=3)
    in_memory_project.put("cross-validation", reporter)

    stored = in_memory_project.get_item("cross-validation")
    assert isinstance(stored, CrossValidationItem)
| 206 | + |
@pytest.mark.parametrize(
    "fixture_name",
    [
        "binary_classifier",
        "multiclass_classifier",
        "single_output_regression",
        "multi_output_regression",
    ],
)
def test_cross_validation_reporter_equivalence_cross_validate(
    in_memory_project, fixture_name, request
):
    """The reporter's scores match a plain `sklearn` `cross_validate` run."""
    # Translate each skore scorer name into its scikit-learn scoring string.
    map_skore_to_sklearn = {
        "r2": "r2",
        "root_mean_squared_error": "neg_root_mean_squared_error",
        "roc_auc": "roc_auc",
        "brier_score_loss": "neg_brier_score",
        "recall": "recall",
        "precision": "precision",
        "recall_weighted": "recall_weighted",
        "precision_weighted": "precision_weighted",
        "roc_auc_ovr_weighted": "roc_auc_ovr_weighted",
        "log_loss": "neg_log_loss",
    }

    model, X, y = request.getfixturevalue(fixture_name)
    reporter = CrossValidationReporter(
        model, X, y, cv=3, return_estimator=True, return_train_score=True
    )

    # Run vanilla scikit-learn cross-validation with the same scorers the
    # reporter selected for this estimator/target combination.
    skore_scorers = _get_scorers_to_add(model, y)
    sklearn_scoring = [map_skore_to_sklearn[name] for name in skore_scorers]
    sklearn_results = cross_validate(
        model,
        X,
        y,
        cv=3,
        scoring=sklearn_scoring,
        return_estimator=True,
        return_train_score=True,
    )

    # Compare split-by-split scores; scikit-learn reports losses negated
    # (the "neg_*" naming convention), so flip the sign for those.
    for skore_name in skore_scorers:
        sklearn_name = map_skore_to_sklearn[skore_name]
        sign = -1 if sklearn_name.startswith("neg_") else 1
        for split in ["test", "train"]:
            numpy.testing.assert_allclose(
                reporter._cv_results[f"{split}_{skore_name}"],
                sign * sklearn_results[f"{split}_{sklearn_name}"],
            )
0 commit comments