|
| 1 | +import re |
| 2 | + |
1 | 3 | import joblib |
2 | 4 | import numpy as np |
3 | 5 | import pytest |
4 | 6 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin |
5 | | -from skore.sklearn._base import _get_cached_response_values |
| 7 | +from sklearn.cluster import KMeans |
| 8 | +from sklearn.datasets import make_classification |
| 9 | +from sklearn.linear_model import LogisticRegression |
| 10 | +from sklearn.model_selection import train_test_split |
| 11 | +from skore.sklearn._base import _BaseAccessor, _get_cached_response_values |
6 | 12 |
|
7 | 13 |
|
8 | 14 | class MockClassifier(ClassifierMixin, BaseEstimator): |
@@ -168,3 +174,126 @@ def test_get_cached_response_values_different_data_source_hash( |
168 | 174 | f"Passing a hash not present in the cache keys should trigger new " |
169 | 175 | f"computation for {response_method}" |
170 | 176 | ) |
| 177 | + |
| 178 | + |
class MockReport:
    """Mock a reporter with the minimal required attributes.

    The constructor arguments are stored on private attributes and exposed
    through read-only properties of the same name (no setters are defined).

    Parameters
    ----------
    estimator : object
        Object returned by the ``estimator`` property.
    X_train, y_train : object, default=None
        Values returned by the ``X_train`` and ``y_train`` properties.
    X_test, y_test : object, default=None
        Values returned by the ``X_test`` and ``y_test`` properties.
    """

    def __init__(self, estimator, X_train=None, y_train=None, X_test=None, y_test=None):
        self._estimator = estimator
        self._X_train = X_train
        self._y_train = y_train
        self._X_test = X_test
        self._y_test = y_test

    @property
    def estimator(self):
        """Estimator passed at construction."""
        return self._estimator

    @property
    def X_train(self):
        """Training data passed at construction, or None."""
        return self._X_train

    @property
    def y_train(self):
        """Training target passed at construction, or None."""
        return self._y_train

    @property
    def X_test(self):
        """Test data passed at construction, or None."""
        return self._X_test

    @property
    def y_test(self):
        """Test target passed at construction, or None."""
        return self._y_test
| 207 | + |
| 208 | + |
def test_base_accessor_get_X_y_and_data_source_hash_error():
    """Check the errors raised by `_BaseAccessor._get_X_y_and_data_source_hash`.

    The scenarios covered, in order:

    - an unknown `data_source`;
    - `data_source` in ("train", "test") while the report holds no data;
    - `data_source` in ("train", "test") while `X` and `y` are also passed;
    - `data_source="X_y"` without passing `X` and `y`;
    - the same missing-data situations with a `KMeans` estimator.
    """
    X, y = make_classification(n_samples=10, n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    # Report that holds a fitted estimator but no train/test data.
    estimator = LogisticRegression().fit(X_train, y_train)
    report = MockReport(estimator, X_train=None, y_train=None, X_test=None, y_test=None)
    accessor = _BaseAccessor(parent=report, icon="")

    err_msg = re.escape(
        "Invalid data source: unknown. Possible values are: test, train, X_y."
    )
    with pytest.raises(ValueError, match=err_msg):
        accessor._get_X_y_and_data_source_hash(data_source="unknown")

    for data_source in ("train", "test"):
        err_msg = re.escape(
            f"No {data_source} data (i.e. X_{data_source} and y_{data_source}) were "
            f"provided when creating the reporter. Please provide the {data_source} "
            "data either when creating the reporter or by setting data_source to "
            "'X_y' and providing X and y."
        )
        with pytest.raises(ValueError, match=err_msg):
            accessor._get_X_y_and_data_source_hash(data_source=data_source)

    # Report that holds both the train and the test data.
    report = MockReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    accessor = _BaseAccessor(parent=report, icon="")

    # Passing X and y on top of a "train"/"test" data source must be rejected.
    for data_source in ("train", "test"):
        err_msg = f"X and y must be None when data_source is {data_source}."
        with pytest.raises(ValueError, match=err_msg):
            accessor._get_X_y_and_data_source_hash(
                data_source=data_source, X=X_test, y=y_test
            )

    # "X_y" without any data passed must be rejected.
    err_msg = "X and y must be provided."
    with pytest.raises(ValueError, match=err_msg):
        accessor._get_X_y_and_data_source_hash(data_source="X_y")

    # FIXME: once we choose some basic metrics for clustering, then we don't need to
    # use `custom_metric` for them.
    # With a KMeans estimator, the expected messages below mention X only.
    estimator = KMeans(n_clusters=2).fit(X_train)
    report = MockReport(estimator, X_test=X_test)
    accessor = _BaseAccessor(parent=report, icon="")
    err_msg = "X must be provided."
    with pytest.raises(ValueError, match=err_msg):
        accessor._get_X_y_and_data_source_hash(data_source="X_y")

    report = MockReport(estimator)
    accessor = _BaseAccessor(parent=report, icon="")
    for data_source in ("train", "test"):
        err_msg = re.escape(
            f"No {data_source} data (i.e. X_{data_source}) were provided when "
            f"creating the reporter. Please provide the {data_source} data either "
            "when creating the reporter or by setting data_source to 'X_y' and "
            "providing X and y."
        )
        with pytest.raises(ValueError, match=err_msg):
            accessor._get_X_y_and_data_source_hash(data_source=data_source)
| 270 | + |
| 271 | + |
@pytest.mark.parametrize("data_source", ("train", "test", "X_y"))
def test_base_accessor_get_X_y_and_data_source_hash(data_source):
    """Check the general behaviour of `_get_X_y_and_data_source_hash`.

    For "train" and "test", the data stored on the report must be returned
    as-is (same objects) and the hash must be None. For "X_y", the data passed
    by the caller must be returned as-is and the hash must equal
    `joblib.hash((X, y))`.
    """
    X, y = make_classification(n_samples=10, n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    estimator = LogisticRegression().fit(X_train, y_train)
    report = MockReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    accessor = _BaseAccessor(parent=report, icon="")

    # X and y are only passed explicitly for the "X_y" data source.
    kwargs = {"X": X_test, "y": y_test} if data_source == "X_y" else {}
    # Distinct names so the returned values do not rebind the `X`/`y` used above.
    X_out, y_out, data_source_hash = accessor._get_X_y_and_data_source_hash(
        data_source=data_source, **kwargs
    )

    if data_source == "train":
        assert X_out is X_train
        assert y_out is y_train
        assert data_source_hash is None
    elif data_source == "test":
        assert X_out is X_test
        assert y_out is y_test
        assert data_source_hash is None
    elif data_source == "X_y":
        assert X_out is X_test
        assert y_out is y_test
        assert data_source_hash == joblib.hash((X_test, y_test))
0 commit comments