Skip to content

Commit b177651

Browse files
fix(EstimatorReport): Make a deep copy of fitted estimator in constructor to avoid side-effect (#1085)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 7e03c9c commit b177651

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

skore/src/skore/sklearn/_estimator/report.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import copy
12
import inspect
23
import time
4+
import warnings
35
from itertools import product
46

57
import joblib
@@ -27,7 +29,8 @@ class EstimatorReport(_HelpMixin, DirNamesMixin):
2729
Parameters
2830
----------
2931
estimator : estimator object
30-
Estimator to make report from.
32+
Estimator to make the report from. When the estimator is used as provided
33+
(not refitted), it is deep-copied to avoid side-effects. If it is fitted by the report, a clone is fitted instead.
3134
3235
fit : {"auto", True, False}, default="auto"
3336
Whether to fit the estimator on the training data. If "auto", the estimator
@@ -79,6 +82,21 @@ def _fit_estimator(estimator, X_train, y_train):
7982
)
8083
return clone(estimator).fit(X_train, y_train)
8184

85+
@classmethod
86+
def _copy_estimator(cls, estimator):
87+
try:
88+
return copy.deepcopy(estimator)
89+
except Exception as e:
90+
warnings.warn(
91+
"Deepcopy failed; using estimator as-is. "
92+
"Be aware that modifying the estimator outside of "
93+
f"{cls.__name__} will modify the internal estimator. "
94+
"Consider using a FrozenEstimator from scikit-learn to prevent this. "
95+
f"Original error: {e}",
96+
stacklevel=1,
97+
)
98+
return estimator
99+
82100
def __init__(
83101
self,
84102
estimator,
@@ -92,13 +110,13 @@ def __init__(
92110
if fit == "auto":
93111
try:
94112
check_is_fitted(estimator)
95-
self._estimator = estimator
113+
self._estimator = self._copy_estimator(estimator)
96114
except NotFittedError:
97115
self._estimator = self._fit_estimator(estimator, X_train, y_train)
98116
elif fit is True:
99117
self._estimator = self._fit_estimator(estimator, X_train, y_train)
100118
else: # fit is False
101-
self._estimator = estimator
119+
self._estimator = self._copy_estimator(estimator)
102120

103121
# private storage to be able to invalidate the cache when the user alters
104122
# those attributes

skore/tests/unit/sklearn/test_estimator.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
187187
estimator, X, y = binary_classification_data
188188
report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y)
189189

190-
assert report.estimator is estimator # we should not clone the estimator
190+
check_is_fitted(report.estimator)
191+
assert isinstance(report.estimator, RandomForestClassifier)
191192
assert report.X_train is None
192193
assert report.y_train is None
193194
assert report.X_test is X
@@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
209210
estimator, X, y = binary_classification_data_pipeline
210211
report = EstimatorReport(estimator, X_test=X, y_test=y)
211212

212-
assert report.estimator is estimator # we should not clone the estimator
213+
check_is_fitted(report.estimator)
214+
assert isinstance(report.estimator, Pipeline)
213215
assert report.estimator_name == estimator[-1].__class__.__name__
214216
assert report.X_train is None
215217
assert report.y_train is None
@@ -925,3 +927,50 @@ def test_estimator_report_get_X_y_and_data_source_hash(data_source):
925927
assert X is X_test
926928
assert y is y_test
927929
assert data_source_hash == joblib.hash((X_test, y_test))
930+
931+
932+
@pytest.mark.parametrize("prefit_estimator", [True, False])
933+
def test_estimator_has_side_effects(prefit_estimator):
934+
"""Re-fitting the estimator outside the EstimatorReport
935+
should not have an effect on the EstimatorReport's internal estimator."""
936+
X, y = make_classification(n_classes=2, random_state=42)
937+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
938+
939+
estimator = LogisticRegression()
940+
if prefit_estimator:
941+
estimator.fit(X_train, y_train)
942+
943+
report = EstimatorReport(
944+
estimator,
945+
X_train=X_train,
946+
X_test=X_test,
947+
y_train=y_train,
948+
y_test=y_test,
949+
)
950+
951+
predictions_before = report.estimator.predict_proba(X_test)
952+
estimator.fit(X_test, y_test)
953+
predictions_after = report.estimator.predict_proba(X_test)
954+
np.testing.assert_array_equal(predictions_before, predictions_after)
955+
956+
957+
def test_estimator_has_no_deep_copy():
958+
"""Check that we raise a warning if the deep copy of the estimator fails
959+
when the estimator is used as provided (``fit=False``)."""
960+
X, y = make_classification(n_classes=2, random_state=42)
961+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
962+
963+
estimator = LogisticRegression()
964+
# Make it so deepcopy does not work
965+
estimator.__reduce_ex__ = None
966+
estimator.__reduce__ = None
967+
968+
with pytest.warns(UserWarning, match="Deepcopy failed"):
969+
EstimatorReport(
970+
estimator,
971+
fit=False,
972+
X_train=X_train,
973+
X_test=X_test,
974+
y_train=y_train,
975+
y_test=y_test,
976+
)

0 commit comments

Comments
 (0)