Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import time
import warnings
from itertools import product

import joblib
Expand Down Expand Up @@ -27,7 +29,8 @@ class EstimatorReport(_HelpMixin, DirNamesMixin):
Parameters
----------
estimator : estimator object
Estimator to make report from.
Estimator to make the report from. When the estimator is not fitted,
it is deep-copied to avoid side-effects. If it is fitted, it is cloned instead.

fit : {"auto", True, False}, default="auto"
Whether to fit the estimator on the training data. If "auto", the estimator
Expand Down Expand Up @@ -79,6 +82,21 @@ def _fit_estimator(estimator, X_train, y_train):
)
return clone(estimator).fit(X_train, y_train)

@classmethod
def _copy_estimator(cls, estimator):
try:
return copy.deepcopy(estimator)
except Exception as e:
warnings.warn(
"Deepcopy failed; using estimator as-is. "
"Be aware that modifying the estimator outside of "
f"{cls.__name__} will modify the internal estimator. "
"Consider using a FrozenEstimator from scikit-learn to prevent this. "
f"Original error: {e}",
stacklevel=1,
)
return estimator

def __init__(
self,
estimator,
Expand All @@ -92,13 +110,13 @@ def __init__(
if fit == "auto":
try:
check_is_fitted(estimator)
self._estimator = estimator
self._estimator = self._copy_estimator(estimator)
except NotFittedError:
self._estimator = self._fit_estimator(estimator, X_train, y_train)
elif fit is True:
self._estimator = self._fit_estimator(estimator, X_train, y_train)
else: # fit is False
self._estimator = estimator
self._estimator = self._copy_estimator(estimator)

# private storage to be able to invalidate the cache when the user alters
# those attributes
Expand Down
53 changes: 51 additions & 2 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
estimator, X, y = binary_classification_data
report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, RandomForestClassifier)
assert report.X_train is None
assert report.y_train is None
assert report.X_test is X
Expand All @@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
estimator, X, y = binary_classification_data_pipeline
report = EstimatorReport(estimator, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, Pipeline)
assert report.estimator_name == estimator[-1].__class__.__name__
assert report.X_train is None
assert report.y_train is None
Expand Down Expand Up @@ -925,3 +927,50 @@ def test_estimator_report_get_X_y_and_data_source_hash(data_source):
assert X is X_test
assert y is y_test
assert data_source_hash == joblib.hash((X_test, y_test))


@pytest.mark.parametrize("prefit_estimator", [True, False])
def test_estimator_has_side_effects(prefit_estimator):
"""Re-fitting the estimator outside the EstimatorReport
should not have an effect on the EstimatorReport's internal estimator."""
X, y = make_classification(n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

estimator = LogisticRegression()
if prefit_estimator:
estimator.fit(X_train, y_train)

report = EstimatorReport(
estimator,
X_train=X_train,
X_test=X_test,
y_train=y_train,
y_test=y_test,
)

predictions_before = report.estimator.predict_proba(X_test)
estimator.fit(X_test, y_test)
predictions_after = report.estimator.predict_proba(X_test)
np.testing.assert_array_equal(predictions_before, predictions_after)


def test_estimator_has_no_deep_copy():
"""Check that we raise a warning if the deep copy failed with a fitted
estimator."""
X, y = make_classification(n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

estimator = LogisticRegression()
# Make it so deepcopy does not work
estimator.__reduce_ex__ = None
estimator.__reduce__ = None

with pytest.warns(UserWarning, match="Deepcopy failed"):
EstimatorReport(
estimator,
fit=False,
X_train=X_train,
X_test=X_test,
y_train=y_train,
y_test=y_test,
)
Loading