Skip to content

Commit 055e9af

Browse files
fix(EstimatorReport): Use sklearn's FrozenEstimator
1 parent 4c819ec commit 055e9af

File tree

2 files changed

+164
-1
lines changed

2 files changed

+164
-1
lines changed

skore/src/skore/externals/_sklearn_compat.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,9 +837,170 @@ def parametrize_with_checks(
837837

838838
return parametrize_with_checks(estimators)
839839

840+
from sklearn.base import BaseEstimator
841+
from sklearn.utils.metaestimators import available_if
842+
from sklearn.utils.metaestimators import available_if
843+
from sklearn.utils.validation import check_is_fitted
844+
845+
def _estimator_has(attr):
846+
"""Check that final_estimator has `attr`.
847+
848+
Used together with `available_if`.
849+
"""
850+
851+
def check(self):
852+
# raise original `AttributeError` if `attr` does not exist
853+
getattr(self.estimator, attr)
854+
return True
855+
856+
return check
857+
858+
class FrozenEstimator(BaseEstimator):
859+
"""Estimator that wraps a fitted estimator to prevent re-fitting.
860+
861+
This meta-estimator takes an estimator and freezes it, in the sense that calling
862+
`fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
863+
All other methods are delegated to the original estimator and original estimator's
864+
attributes are accessible as well.
865+
866+
This is particularly useful when you have a fitted or a pre-trained model as a
867+
transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
868+
step.
869+
870+
Parameters
871+
----------
872+
estimator : estimator
873+
The estimator which is to be kept frozen.
874+
875+
See Also
876+
--------
877+
None: No similar entry in the scikit-learn documentation.
878+
879+
Examples
880+
--------
881+
>>> from sklearn.datasets import make_classification
882+
>>> from sklearn.frozen import FrozenEstimator
883+
>>> from sklearn.linear_model import LogisticRegression
884+
>>> X, y = make_classification(random_state=0)
885+
>>> clf = LogisticRegression(random_state=0).fit(X, y)
886+
>>> frozen_clf = FrozenEstimator(clf)
887+
>>> frozen_clf.fit(X, y) # No-op
888+
FrozenEstimator(estimator=LogisticRegression(random_state=0))
889+
>>> frozen_clf.predict(X) # Predictions from `clf.predict`
890+
array(...)
891+
"""
892+
893+
def __init__(self, estimator):
894+
self.estimator = estimator
895+
896+
@available_if(_estimator_has("__getitem__"))
897+
def __getitem__(self, *args, **kwargs):
898+
"""__getitem__ is defined in :class:`~sklearn.pipeline.Pipeline` and \
899+
:class:`~sklearn.compose.ColumnTransformer`.
900+
"""
901+
return self.estimator.__getitem__(*args, **kwargs)
902+
903+
def __getattr__(self, name):
904+
# `estimator`'s attributes are now accessible except `fit_predict` and
905+
# `fit_transform`
906+
if name in ["fit_predict", "fit_transform"]:
907+
raise AttributeError(f"{name} is not available for frozen estimators.")
908+
return getattr(self.estimator, name)
909+
910+
def __sklearn_clone__(self):
911+
return self
912+
913+
def __sklearn_is_fitted__(self):
914+
try:
915+
check_is_fitted(self.estimator)
916+
return True
917+
except NotFittedError:
918+
return False
919+
920+
def fit(self, X, y, *args, **kwargs):
921+
"""No-op.
922+
923+
As a frozen estimator, calling `fit` has no effect.
924+
925+
Parameters
926+
----------
927+
X : object
928+
Ignored.
929+
930+
y : object
931+
Ignored.
932+
933+
*args : tuple
934+
Additional positional arguments. Ignored, but present for API compatibility
935+
with `self.estimator`.
936+
937+
**kwargs : dict
938+
Additional keyword arguments. Ignored, but present for API compatibility
939+
with `self.estimator`.
940+
941+
Returns
942+
-------
943+
self : object
944+
Returns the instance itself.
945+
"""
946+
breakpoint()
947+
check_is_fitted(self.estimator)
948+
return self
949+
950+
def set_params(self, **kwargs):
951+
"""Set the parameters of this estimator.
952+
953+
The only valid key here is `estimator`. You cannot set the parameters of the
954+
inner estimator.
955+
956+
Parameters
957+
----------
958+
**kwargs : dict
959+
Estimator parameters.
960+
961+
Returns
962+
-------
963+
self : FrozenEstimator
964+
This estimator.
965+
"""
966+
estimator = kwargs.pop("estimator", None)
967+
if estimator is not None:
968+
self.estimator = estimator
969+
if kwargs:
970+
raise ValueError(
971+
"You cannot set parameters of the inner estimator in a frozen "
972+
"estimator since calling `fit` has no effect. You can use "
973+
"`frozenestimator.estimator.set_params` to set parameters of the inner "
974+
"estimator."
975+
)
976+
977+
def get_params(self, deep=True):
978+
"""Get parameters for this estimator.
979+
980+
Returns a `{"estimator": estimator}` dict. The parameters of the inner
981+
estimator are not included.
982+
983+
Parameters
984+
----------
985+
deep : bool, default=True
986+
Ignored.
987+
988+
Returns
989+
-------
990+
params : dict
991+
Parameter names mapped to their values.
992+
"""
993+
return {"estimator": self.estimator}
994+
995+
def __sklearn_tags__(self):
996+
tags = deepcopy(get_tags(self.estimator))
997+
tags._skip_test = True
998+
return tags
999+
8401000
else:
8411001
# base
8421002
from sklearn.base import is_clusterer # noqa: F401
1003+
from sklearn.frozen import FrozenEstimator # noqa: F401
8431004

8441005
# test_common
8451006
# tags infrastructure

skore/src/skore/sklearn/_estimator/report.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sklearn.utils.validation import check_is_fitted
1414

1515
from skore.externals._pandas_accessors import DirNamesMixin
16-
from skore.externals._sklearn_compat import is_clusterer
16+
from skore.externals._sklearn_compat import FrozenEstimator, is_clusterer
1717
from skore.sklearn._estimator.base import _BaseAccessor, _HelpMixin
1818
from skore.sklearn.find_ml_task import _find_ml_task
1919

@@ -100,6 +100,8 @@ def __init__(
100100
else: # fit is False
101101
self._estimator = estimator
102102

103+
self._estimator = FrozenEstimator(self._estimator)
104+
103105
# private storage to be able to invalidate the cache when the user alters
104106
# those attributes
105107
self._X_train = X_train

0 commit comments

Comments
 (0)