@@ -289,17 +289,33 @@ def test_estimator_report_repr(binary_classification_data):
289289
290290
291291@pytest .mark .parametrize (
292- "fixture_name" , ["binary_classification_data" , "regression_data" ]
292+ "fixture_name, pass_train_data, expected_n_keys" ,
293+ [
294+ ("binary_classification_data" , True , 6 ),
295+ ("binary_classification_data_svc" , True , 6 ),
296+ ("multiclass_classification_data" , True , 8 ),
297+ ("regression_data" , True , 2 ),
298+ ("binary_classification_data" , False , 3 ),
299+ ("binary_classification_data_svc" , False , 3 ),
300+ ("multiclass_classification_data" , False , 4 ),
301+ ("regression_data" , False , 1 ),
302+ ],
293303)
294- def test_estimator_report_cache_predictions (request , fixture_name ):
304+ def test_estimator_report_cache_predictions (
305+ request , fixture_name , pass_train_data , expected_n_keys
306+ ):
295307 """Check that calling cache_predictions fills the cache."""
296308 estimator , X_test , y_test = request .getfixturevalue (fixture_name )
297- report = EstimatorReport (
298- estimator , X_train = X_test , y_train = y_test , X_test = X_test , y_test = y_test
299- )
309+ if pass_train_data :
310+ report = EstimatorReport (
311+ estimator , X_train = X_test , y_train = y_test , X_test = X_test , y_test = y_test
312+ )
313+ else :
314+ report = EstimatorReport (estimator , X_test = X_test , y_test = y_test )
300315
301316 assert report ._cache == {}
302317 report .cache_predictions ()
318+ assert len (report ._cache ) == expected_n_keys
303319 assert report ._cache != {}
304320 stored_cache = deepcopy (report ._cache )
305321 report .cache_predictions ()
0 commit comments