@@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
187187 estimator , X , y = binary_classification_data
188188 report = EstimatorReport (estimator , fit = fit , X_test = X , y_test = y )
189189
190- assert report .estimator is estimator # we should not clone the estimator
190+ check_is_fitted (report .estimator )
191+ assert isinstance (report .estimator , RandomForestClassifier )
191192 assert report .X_train is None
192193 assert report .y_train is None
193194 assert report .X_test is X
@@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
209210 estimator , X , y = binary_classification_data_pipeline
210211 report = EstimatorReport (estimator , X_test = X , y_test = y )
211212
212- assert report .estimator is estimator # we should not clone the estimator
213+ check_is_fitted (report .estimator )
214+ assert isinstance (report .estimator , Pipeline )
213215 assert report .estimator_name == estimator [- 1 ].__class__ .__name__
214216 assert report .X_train is None
215217 assert report .y_train is None
@@ -925,3 +927,50 @@ def test_estimator_report_get_X_y_and_data_source_hash(data_source):
925927 assert X is X_test
926928 assert y is y_test
927929 assert data_source_hash == joblib .hash ((X_test , y_test ))
930+
931+
932+ @pytest .mark .parametrize ("prefit_estimator" , [True , False ])
933+ def test_estimator_has_side_effects (prefit_estimator ):
934+ """Re-fitting the estimator outside the EstimatorReport
935+ should not have an effect on the EstimatorReport's internal estimator."""
936+ X , y = make_classification (n_classes = 2 , random_state = 42 )
937+ X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 42 )
938+
939+ estimator = LogisticRegression ()
940+ if prefit_estimator :
941+ estimator .fit (X_train , y_train )
942+
943+ report = EstimatorReport (
944+ estimator ,
945+ X_train = X_train ,
946+ X_test = X_test ,
947+ y_train = y_train ,
948+ y_test = y_test ,
949+ )
950+
951+ predictions_before = report .estimator .predict_proba (X_test )
952+ estimator .fit (X_test , y_test )
953+ predictions_after = report .estimator .predict_proba (X_test )
954+ np .testing .assert_array_equal (predictions_before , predictions_after )
955+
956+
957+ def test_estimator_has_no_deep_copy ():
958+ """Check that we raise a warning if the deep copy failed with a fitted
959+ estimator."""
960+ X , y = make_classification (n_classes = 2 , random_state = 42 )
961+ X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 42 )
962+
963+ estimator = LogisticRegression ()
964+ # Make it so deepcopy does not work
965+ estimator .__reduce_ex__ = None
966+ estimator .__reduce__ = None
967+
968+ with pytest .warns (UserWarning , match = "Deepcopy failed" ):
969+ EstimatorReport (
970+ estimator ,
971+ fit = False ,
972+ X_train = X_train ,
973+ X_test = X_test ,
974+ y_train = y_train ,
975+ y_test = y_test ,
976+ )
0 commit comments