Skip to content

Commit 5b9b55a

Browse files
fix(EstimatorReport): Deepcopy estimator
1 parent 4c819ec commit 5b9b55a

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

skore/src/skore/sklearn/_estimator/report.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import copy
12
import inspect
23
import time
4+
import warnings
35
from itertools import product
46

57
import joblib
@@ -100,6 +102,18 @@ def __init__(
100102
else: # fit is False
101103
self._estimator = estimator
102104

105+
try:
106+
self._estimator = copy.deepcopy(self._estimator)
107+
except Exception as e:
108+
warnings.warn(
109+
"Deepcopy failed; using estimator as-is. "
110+
"Be aware that modifying the estimator outside of "
111+
f"{self.__class__.__name__} will modify the internal estimator. "
112+
"Consider using a FrozenEstimator from scikit-learn to prevent this. "
113+
f"Original error: {e}",
114+
stacklevel=1,
115+
)
116+
103117
# private storage to be able to invalidate the cache when the user alters
104118
# those attributes
105119
self._X_train = X_train

skore/tests/unit/sklearn/test_estimator.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
187187
estimator, X, y = binary_classification_data
188188
report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y)
189189

190-
assert report.estimator is estimator # we should not clone the estimator
190+
check_is_fitted(report.estimator)
191+
assert isinstance(report.estimator, RandomForestClassifier)
191192
assert report.X_train is None
192193
assert report.y_train is None
193194
assert report.X_test is X
@@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
209210
estimator, X, y = binary_classification_data_pipeline
210211
report = EstimatorReport(estimator, X_test=X, y_test=y)
211212

212-
assert report.estimator is estimator # we should not clone the estimator
213+
check_is_fitted(report.estimator)
214+
assert isinstance(report.estimator, Pipeline)
213215
assert report.estimator_name == estimator[-1].__class__.__name__
214216
assert report.X_train is None
215217
assert report.y_train is None
@@ -950,3 +952,23 @@ def test_estimator_has_side_effects():
950952
predictions_after = report.estimator.predict_proba(X_test)
951953

952954
np.testing.assert_array_equal(predictions_before, predictions_after)
955+
956+
957+
def test_estimator_has_no_deep_copy():
    """Check the fallback path taken when the estimator cannot be deep-copied.

    `EstimatorReport` is expected to emit a "Deepcopy failed" `UserWarning`
    and keep using the very estimator object it was given.
    """
    X, y = make_classification(n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    estimator = LogisticRegression()
    # `copy.deepcopy` falls back on `__reduce_ex__` and then `__reduce__`;
    # shadowing both with `None` on the instance leaves it with no way to
    # copy the object, so it raises.
    estimator.__reduce_ex__ = None
    estimator.__reduce__ = None

    with pytest.warns(UserWarning, match="Deepcopy failed"):
        report = EstimatorReport(
            estimator,
            fit=False,
            X_train=X_train,
            X_test=X_test,
            y_train=y_train,
            y_test=y_test,
        )

    # The warning promises the estimator is used "as-is": verify that the
    # report holds the original object rather than a copy or nothing at all.
    assert report.estimator is estimator

0 commit comments

Comments
 (0)