Skip to content

Commit 2302006

Browse files
committed
add test for base class
1 parent bfcbd59 commit 2302006

File tree

1 file changed

+130
-1
lines changed

1 file changed

+130
-1
lines changed

skore/tests/unit/sklearn/test_base.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import re
2+
13
import joblib
24
import numpy as np
35
import pytest
46
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
5-
from skore.sklearn._base import _get_cached_response_values
7+
from sklearn.cluster import KMeans
8+
from sklearn.datasets import make_classification
9+
from sklearn.linear_model import LogisticRegression
10+
from sklearn.model_selection import train_test_split
11+
from skore.sklearn._base import _BaseAccessor, _get_cached_response_values
612

713

814
class MockClassifier(ClassifierMixin, BaseEstimator):
@@ -168,3 +174,126 @@ def test_get_cached_response_values_different_data_source_hash(
168174
f"Passing a hash not present in the cache keys should trigger new "
169175
f"computation for {response_method}"
170176
)
177+
178+
179+
class MockReport:
    """Lightweight stand-in for a reporter used by the accessor tests.

    It only stores an estimator together with optional train/test data and
    exposes each of them through a read-only property.
    """

    def __init__(self, estimator, X_train=None, y_train=None, X_test=None, y_test=None):
        """Mock a reporter with the minimal required attributes."""
        # Store everything on private attributes; the public names below are
        # read-only views on them.
        (
            self._estimator,
            self._X_train,
            self._y_train,
            self._X_test,
            self._y_test,
        ) = (estimator, X_train, y_train, X_test, y_test)

    @property
    def estimator(self):
        """Estimator given at construction."""
        return self._estimator

    @property
    def X_train(self):
        """Training data given at construction, or None."""
        return self._X_train

    @property
    def y_train(self):
        """Training target given at construction, or None."""
        return self._y_train

    @property
    def X_test(self):
        """Test data given at construction, or None."""
        return self._X_test

    @property
    def y_test(self):
        """Test target given at construction, or None."""
        return self._y_test
207+
208+
209+
def test_base_accessor_get_X_y_and_data_source_hash_error():
    """Check that `_get_X_y_and_data_source_hash` raises the proper errors.

    The scenarios are checked in this order:

    - an unknown `data_source` value;
    - `data_source` set to "train" or "test" while the report holds no data;
    - `X` and `y` passed together with `data_source` set to "train" or "test";
    - `data_source="X_y"` without passing `X` and `y`;
    - the missing-data cases again with a `KMeans` estimator, for which the
      expected messages mention `X` only.
    """
    X, y = make_classification(n_samples=10, n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    # Report without any data attached.
    estimator = LogisticRegression().fit(X_train, y_train)
    report = MockReport(estimator, X_train=None, y_train=None, X_test=None, y_test=None)
    accessor = _BaseAccessor(parent=report, icon="")

    err_msg = re.escape(
        "Invalid data source: unknown. Possible values are: test, train, X_y."
    )
    with pytest.raises(ValueError, match=err_msg):
        accessor._get_X_y_and_data_source_hash(data_source="unknown")

    for data_source in ("train", "test"):
        err_msg = re.escape(
            f"No {data_source} data (i.e. X_{data_source} and y_{data_source}) were "
            f"provided when creating the reporter. Please provide the {data_source} "
            "data either when creating the reporter or by setting data_source to "
            "'X_y' and providing X and y."
        )
        with pytest.raises(ValueError, match=err_msg):
            accessor._get_X_y_and_data_source_hash(data_source=data_source)

    # Report with both train and test data attached.
    report = MockReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    accessor = _BaseAccessor(parent=report, icon="")

    for data_source in ("train", "test"):
        # NOTE(review): this pattern is not passed through `re.escape`, so the
        # trailing "." matches any character -- confirm whether that is intended.
        err_msg = f"X and y must be None when data_source is {data_source}."
        with pytest.raises(ValueError, match=err_msg):
            accessor._get_X_y_and_data_source_hash(
                data_source=data_source, X=X_test, y=y_test
            )

    # NOTE(review): unescaped pattern, the trailing "." matches any character.
    err_msg = "X and y must be provided."
    with pytest.raises(ValueError, match=err_msg):
        accessor._get_X_y_and_data_source_hash(data_source="X_y")

    # FIXME: once we choose some basic metrics for clustering, then we don't need to
    # use `custom_metric` for them.
    # NOTE(review): no `custom_metric` is used in this test -- confirm whether the
    # FIXME above still applies here.
    estimator = KMeans(n_clusters=2).fit(X_train)
    report = MockReport(estimator, X_test=X_test)
    accessor = _BaseAccessor(parent=report, icon="")
    # NOTE(review): unescaped pattern, the trailing "." matches any character.
    err_msg = "X must be provided."
    with pytest.raises(ValueError, match=err_msg):
        accessor._get_X_y_and_data_source_hash(data_source="X_y")

    # Same estimator, but the report holds no data at all.
    report = MockReport(estimator)
    accessor = _BaseAccessor(parent=report, icon="")
    for data_source in ("train", "test"):
        err_msg = re.escape(
            f"No {data_source} data (i.e. X_{data_source}) were provided when "
            f"creating the reporter. Please provide the {data_source} data either "
            "when creating the reporter or by setting data_source to 'X_y' and "
            "providing X and y."
        )
        with pytest.raises(ValueError, match=err_msg):
            accessor._get_X_y_and_data_source_hash(data_source=data_source)
270+
271+
272+
@pytest.mark.parametrize("data_source", ("train", "test", "X_y"))
def test_base_accessor_get_X_y_and_data_source_hash(data_source):
    """Check the general behaviour of `_get_X_y_and_data_source_hash`.

    For "train" and "test" the returned objects must be the very arrays stored
    on the report and the returned hash must be None. For "X_y" the returned
    objects must be the arrays passed in and the hash must equal
    `joblib.hash((X, y))` of those arrays.
    """
    X, y = make_classification(n_samples=10, n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    estimator = LogisticRegression().fit(X_train, y_train)
    report = MockReport(
        estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    accessor = _BaseAccessor(parent=report, icon="")

    # Expected (X, y) per data source. Indexing raises KeyError for a data source
    # without an expectation, so a newly parametrized value cannot pass silently.
    expected_data = {
        "train": (X_train, y_train),
        "test": (X_test, y_test),
        "X_y": (X_test, y_test),
    }
    expected_X, expected_y = expected_data[data_source]

    if data_source == "X_y":
        kwargs = {"X": X_test, "y": y_test}
        expected_hash = joblib.hash((X_test, y_test))
    else:
        kwargs = {}
        expected_hash = None

    # Distinct names for the outputs, so the generated `X`/`y` are not shadowed.
    X_out, y_out, data_source_hash = accessor._get_X_y_and_data_source_hash(
        data_source=data_source, **kwargs
    )

    # Identity checks: the accessor must hand back the same objects, not copies.
    assert X_out is expected_X
    assert y_out is expected_y
    if expected_hash is None:
        assert data_source_hash is None
    else:
        assert data_source_hash == expected_hash

0 commit comments

Comments
 (0)