Skip to content

Commit 92a4f89

Browse files
committed
fix: Cache all possible predictions in cache_predictions
1 parent ad34c2d commit 92a4f89

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

skore/src/skore/sklearn/_estimator/report.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
163163
if response_methods == "auto":
164164
response_methods = ("predict",)
165165
if hasattr(self._estimator, "predict_proba"):
166-
response_methods = ("predict_proba",)
166+
response_methods += ("predict_proba",)
167167
if hasattr(self._estimator, "decision_function"):
168-
response_methods = ("decision_function",)
168+
response_methods += ("decision_function",)
169169
pos_labels = self._estimator.classes_
170170
else:
171171
if response_methods == "auto":
@@ -175,8 +175,8 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
175175
data_sources = ("test",)
176176
Xs = (self._X_test,)
177177
if self._X_train is not None:
178-
data_sources = ("train",)
179-
Xs = (self._X_train,)
178+
data_sources += ("train",)
179+
Xs += (self._X_train,)
180180

181181
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")
182182
generator = parallel(
@@ -188,8 +188,8 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
188188
pos_label=pos_label,
189189
data_source=data_source,
190190
)
191-
for response_method, pos_label, data_source, X in product(
192-
response_methods, pos_labels, data_sources, Xs
191+
for response_method, pos_label, (data_source, X) in product(
192+
response_methods, pos_labels, zip(data_sources, Xs)
193193
)
194194
)
195195
# trigger the computation

skore/tests/unit/sklearn/test_estimator.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,17 +289,33 @@ def test_estimator_report_repr(binary_classification_data):
289289

290290

291291
@pytest.mark.parametrize(
292-
"fixture_name", ["binary_classification_data", "regression_data"]
292+
"fixture_name, pass_train_data, expected_n_keys",
293+
[
294+
("binary_classification_data", True, 6),
295+
("binary_classification_data_svc", True, 6),
296+
("multiclass_classification_data", True, 8),
297+
("regression_data", True, 2),
298+
("binary_classification_data", False, 3),
299+
("binary_classification_data_svc", False, 3),
300+
("multiclass_classification_data", False, 4),
301+
("regression_data", False, 1),
302+
],
293303
)
294-
def test_estimator_report_cache_predictions(request, fixture_name):
304+
def test_estimator_report_cache_predictions(
305+
request, fixture_name, pass_train_data, expected_n_keys
306+
):
295307
"""Check that calling cache_predictions fills the cache."""
296308
estimator, X_test, y_test = request.getfixturevalue(fixture_name)
297-
report = EstimatorReport(
298-
estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
299-
)
309+
if pass_train_data:
310+
report = EstimatorReport(
311+
estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
312+
)
313+
else:
314+
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
300315

301316
assert report._cache == {}
302317
report.cache_predictions()
318+
assert len(report._cache) == expected_n_keys
303319
assert report._cache != {}
304320
stored_cache = deepcopy(report._cache)
305321
report.cache_predictions()

0 commit comments

Comments
 (0)