Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import time
import warnings
from itertools import product

import joblib
Expand Down Expand Up @@ -100,6 +102,18 @@ def __init__(
else: # fit is False
self._estimator = estimator

try:
self._estimator = copy.deepcopy(self._estimator)
except Exception as e:
warnings.warn(
"Deepcopy failed; using estimator as-is. "
"Be aware that modifying the estimator outside of "
f"{self.__class__.__name__} will modify the internal estimator. "
"Consider using a FrozenEstimator from scikit-learn to prevent this. "
f"Original error: {e}",
stacklevel=1,
)

# private storage to be able to invalidate the cache when the user alters
# those attributes
self._X_train = X_train
Expand Down
51 changes: 49 additions & 2 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
estimator, X, y = binary_classification_data
report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, RandomForestClassifier)
assert report.X_train is None
assert report.y_train is None
assert report.X_test is X
Expand All @@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
estimator, X, y = binary_classification_data_pipeline
report = EstimatorReport(estimator, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, Pipeline)
assert report.estimator_name == estimator[-1].__class__.__name__
assert report.X_train is None
assert report.y_train is None
Expand Down Expand Up @@ -925,3 +927,48 @@ def test_estimator_report_get_X_y_and_data_source_hash(data_source):
assert X is X_test
assert y is y_test
assert data_source_hash == joblib.hash((X_test, y_test))


def test_estimator_has_side_effects():
    """Check that the report's estimator is isolated from the caller's one.

    A fitted estimator is handed to ``EstimatorReport``; the very same
    estimator object is then re-fitted on a different training split.  The
    probabilities predicted by ``report.estimator`` on the test set must be
    identical before and after that external re-fit.
    """
    data, target = make_classification(n_classes=2, random_state=42)

    # Split used to build the report.
    X_train, X_test, y_train, y_test = train_test_split(
        data, target, random_state=42
    )
    # A second split with another seed; only its training part is needed
    # to re-fit the user-side estimator later on.
    other_split = train_test_split(data, target, random_state=420)
    X_refit, y_refit = other_split[0], other_split[2]

    user_estimator = LogisticRegression()
    user_estimator.fit(X_train, y_train)

    report = EstimatorReport(
        user_estimator,
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
    )

    proba_before_refit = report.estimator.predict_proba(X_test)

    # Alter the estimator outside of the report.
    user_estimator.fit(X_refit, y_refit)

    proba_after_refit = report.estimator.predict_proba(X_test)

    np.testing.assert_array_equal(proba_before_refit, proba_after_refit)


def test_estimator_has_no_deep_copy():
    """Check the fallback taken when the estimator cannot be deep-copied.

    ``EstimatorReport`` is expected to emit a ``UserWarning`` matching
    "Deepcopy failed" and, as that warning states, keep using the estimator
    it was given as-is.
    """
    X, y = make_classification(n_classes=2, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    estimator = LogisticRegression()
    # Make it so deepcopy does not work: with both reduction hooks shadowed
    # by ``None`` on the instance, ``copy.deepcopy`` has nothing to rebuild
    # the object from and raises.
    estimator.__reduce_ex__ = None
    estimator.__reduce__ = None

    with pytest.warns(UserWarning, match="Deepcopy failed"):
        report = EstimatorReport(
            estimator,
            fit=False,
            X_train=X_train,
            X_test=X_test,
            y_train=y_train,
            y_test=y_test,
        )

    # The warning alone is not enough: make sure the report really fell back
    # to the original object instead of dropping or replacing it.
    assert report.estimator is estimator
Loading