@@ -837,9 +837,170 @@ def parametrize_with_checks(
837837
838838 return parametrize_with_checks (estimators )
839839
840+ from sklearn .base import BaseEstimator
841+ from sklearn .utils .metaestimators import available_if
842+ from sklearn .utils .metaestimators import available_if
843+ from sklearn .utils .validation import check_is_fitted
844+
845+ def _estimator_has (attr ):
846+ """Check that final_estimator has `attr`.
847+
848+ Used together with `available_if`.
849+ """
850+
851+ def check (self ):
852+ # raise original `AttributeError` if `attr` does not exist
853+ getattr (self .estimator , attr )
854+ return True
855+
856+ return check
857+
858+ class FrozenEstimator (BaseEstimator ):
859+ """Estimator that wraps a fitted estimator to prevent re-fitting.
860+
861+ This meta-estimator takes an estimator and freezes it, in the sense that calling
862+ `fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
863+ All other methods are delegated to the original estimator and original estimator's
864+ attributes are accessible as well.
865+
866+ This is particularly useful when you have a fitted or a pre-trained model as a
867+ transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
868+ step.
869+
870+ Parameters
871+ ----------
872+ estimator : estimator
873+ The estimator which is to be kept frozen.
874+
875+ See Also
876+ --------
877+ None: No similar entry in the scikit-learn documentation.
878+
879+ Examples
880+ --------
881+ >>> from sklearn.datasets import make_classification
882+ >>> from sklearn.frozen import FrozenEstimator
883+ >>> from sklearn.linear_model import LogisticRegression
884+ >>> X, y = make_classification(random_state=0)
885+ >>> clf = LogisticRegression(random_state=0).fit(X, y)
886+ >>> frozen_clf = FrozenEstimator(clf)
887+ >>> frozen_clf.fit(X, y) # No-op
888+ FrozenEstimator(estimator=LogisticRegression(random_state=0))
889+ >>> frozen_clf.predict(X) # Predictions from `clf.predict`
890+ array(...)
891+ """
892+
893+ def __init__ (self , estimator ):
894+ self .estimator = estimator
895+
896+ @available_if (_estimator_has ("__getitem__" ))
897+ def __getitem__ (self , * args , ** kwargs ):
898+ """__getitem__ is defined in :class:`~sklearn.pipeline.Pipeline` and \
899+ :class:`~sklearn.compose.ColumnTransformer`.
900+ """
901+ return self .estimator .__getitem__ (* args , ** kwargs )
902+
903+ def __getattr__ (self , name ):
904+ # `estimator`'s attributes are now accessible except `fit_predict` and
905+ # `fit_transform`
906+ if name in ["fit_predict" , "fit_transform" ]:
907+ raise AttributeError (f"{ name } is not available for frozen estimators." )
908+ return getattr (self .estimator , name )
909+
910+ def __sklearn_clone__ (self ):
911+ return self
912+
913+ def __sklearn_is_fitted__ (self ):
914+ try :
915+ check_is_fitted (self .estimator )
916+ return True
917+ except NotFittedError :
918+ return False
919+
920+ def fit (self , X , y , * args , ** kwargs ):
921+ """No-op.
922+
923+ As a frozen estimator, calling `fit` has no effect.
924+
925+ Parameters
926+ ----------
927+ X : object
928+ Ignored.
929+
930+ y : object
931+ Ignored.
932+
933+ *args : tuple
934+ Additional positional arguments. Ignored, but present for API compatibility
935+ with `self.estimator`.
936+
937+ **kwargs : dict
938+ Additional keyword arguments. Ignored, but present for API compatibility
939+ with `self.estimator`.
940+
941+ Returns
942+ -------
943+ self : object
944+ Returns the instance itself.
945+ """
946+ breakpoint ()
947+ check_is_fitted (self .estimator )
948+ return self
949+
950+ def set_params (self , ** kwargs ):
951+ """Set the parameters of this estimator.
952+
953+ The only valid key here is `estimator`. You cannot set the parameters of the
954+ inner estimator.
955+
956+ Parameters
957+ ----------
958+ **kwargs : dict
959+ Estimator parameters.
960+
961+ Returns
962+ -------
963+ self : FrozenEstimator
964+ This estimator.
965+ """
966+ estimator = kwargs .pop ("estimator" , None )
967+ if estimator is not None :
968+ self .estimator = estimator
969+ if kwargs :
970+ raise ValueError (
971+ "You cannot set parameters of the inner estimator in a frozen "
972+ "estimator since calling `fit` has no effect. You can use "
973+ "`frozenestimator.estimator.set_params` to set parameters of the inner "
974+ "estimator."
975+ )
976+
977+ def get_params (self , deep = True ):
978+ """Get parameters for this estimator.
979+
980+ Returns a `{"estimator": estimator}` dict. The parameters of the inner
981+ estimator are not included.
982+
983+ Parameters
984+ ----------
985+ deep : bool, default=True
986+ Ignored.
987+
988+ Returns
989+ -------
990+ params : dict
991+ Parameter names mapped to their values.
992+ """
993+ return {"estimator" : self .estimator }
994+
995+ def __sklearn_tags__ (self ):
996+ tags = deepcopy (get_tags (self .estimator ))
997+ tags ._skip_test = True
998+ return tags
999+
8401000else :
8411001 # base
8421002 from sklearn .base import is_clusterer # noqa: F401
1003+ from sklearn .frozen import FrozenEstimator # noqa: F401
8431004
8441005 # test_common
8451006 # tags infrastructure
0 commit comments