44import pytest
55
66from sklearn .base import clone
7+ from sklearn .cluster import DBSCAN , KMeans
78from sklearn .datasets import (
89 load_iris ,
910 make_classification ,
1011 make_multilabel_classification ,
11- make_regression ,
1212)
1313from sklearn .ensemble import IsolationForest
14- from sklearn .linear_model import (
15- LinearRegression ,
16- LogisticRegression ,
17- )
14+ from sklearn .linear_model import LinearRegression , LogisticRegression
1815from sklearn .multioutput import ClassifierChain
1916from sklearn .preprocessing import scale
2017from sklearn .tree import DecisionTreeClassifier , DecisionTreeRegressor
21- from sklearn .utils ._mocking import _MockEstimatorOnOffPrediction
2218from sklearn .utils ._response import _get_response_values , _get_response_values_binary
2319from sklearn .utils ._testing import assert_allclose , assert_array_equal
2420
2925
3026
3127@pytest .mark .parametrize (
32- "response_method" , ["decision_function" , "predict_proba" , "predict_log_proba" ]
28+ "estimator, response_method" ,
29+ [
30+ (DecisionTreeRegressor (), "predict_proba" ),
31+ (DecisionTreeRegressor (), ["predict_proba" , "decision_function" ]),
32+ (KMeans (n_clusters = 2 , n_init = 1 ), "predict_proba" ),
33+ (KMeans (n_clusters = 2 , n_init = 1 ), ["predict_proba" , "decision_function" ]),
34+ (DBSCAN (), "predict" ),
35+ (IsolationForest (random_state = 0 ), "predict_proba" ),
36+ (IsolationForest (random_state = 0 ), ["predict_proba" , "score" ]),
37+ ],
3338)
34- def test_get_response_values_regressor_error (response_method ):
35- """Check the error message with regressor an not supported response
36- method."""
37- my_estimator = _MockEstimatorOnOffPrediction (response_methods = [response_method ])
38- X = "mocking_data" , "mocking_target"
39- err_msg = f"{ my_estimator .__class__ .__name__ } should either be a classifier"
40- with pytest .raises (ValueError , match = err_msg ):
41- _get_response_values (my_estimator , X , response_method = response_method )
42-
43-
44- @pytest .mark .parametrize ("return_response_method_used" , [True , False ])
45- def test_get_response_values_regressor (return_response_method_used ):
46- """Check the behaviour of `_get_response_values` with regressor."""
47- X , y = make_regression (n_samples = 10 , random_state = 0 )
48- regressor = LinearRegression ().fit (X , y )
49- results = _get_response_values (
50- regressor ,
51- X ,
52- response_method = "predict" ,
53- return_response_method_used = return_response_method_used ,
54- )
55- assert_array_equal (results [0 ], regressor .predict (X ))
56- assert results [1 ] is None
57- if return_response_method_used :
58- assert results [2 ] == "predict"
def test_estimator_unsupported_response(estimator, response_method):
    """Check the error raised when a response method is not supported.

    Each parametrized estimator is fitted on a small random dataset and then
    passed to `_get_response_values` together with one method name, or a list
    of method names. The call is expected to raise an `AttributeError` whose
    message contains "has none of the following attributes:".
    """
    # No plotting happens in this test, so the `pyplot` fixture is
    # intentionally not requested.
    rng = np.random.RandomState(0)
    X, y = rng.randn(10, 2), np.array([0, 1] * 5)
    # Fit a clone so that the instance created in the parametrization stays
    # unfitted and is never mutated by the test.
    estimator = clone(estimator)
    estimator.fit(X, y)
    err_msg = "has none of the following attributes:"
    with pytest.raises(AttributeError, match=err_msg):
        _get_response_values(estimator, X, response_method=response_method)
5950
6051
6152@pytest .mark .parametrize (
62- "response_method" ,
63- ["predict" , "decision_function" , ["decision_function" , "predict" ]],
53+ "estimator, response_method" ,
54+ [
55+ (LinearRegression (), "predict" ),
56+ (KMeans (n_clusters = 2 , n_init = 1 ), "predict" ),
57+ (KMeans (n_clusters = 2 , n_init = 1 ), "score" ),
58+ (KMeans (n_clusters = 2 , n_init = 1 ), ["predict" , "score" ]),
59+ (IsolationForest (random_state = 0 ), "predict" ),
60+ (IsolationForest (random_state = 0 ), "decision_function" ),
61+ (IsolationForest (random_state = 0 ), ["decision_function" , "predict" ]),
62+ ],
6463)
6564@pytest .mark .parametrize ("return_response_method_used" , [True , False ])
66- def test_get_response_values_outlier_detection (
67- response_method , return_response_method_used
65+ def test_estimator_get_response_values (
66+ estimator , response_method , return_response_method_used
6867):
69- """Check the behaviour of `_get_response_values` with outlier detector ."""
70- X , y = make_classification ( n_samples = 50 , random_state = 0 )
71- outlier_detector = IsolationForest ( random_state = 0 ) .fit (X , y )
68+ """Check the behaviour of `_get_response_values`."""
69+ X , y = np . random . RandomState ( 0 ). randn ( 10 , 2 ), np . array ([ 0 , 1 ] * 5 )
70+ estimator .fit (X , y )
7271 results = _get_response_values (
73- outlier_detector ,
72+ estimator ,
7473 X ,
7574 response_method = response_method ,
7675 return_response_method_used = return_response_method_used ,
7776 )
7877 chosen_response_method = (
7978 response_method [0 ] if isinstance (response_method , list ) else response_method
8079 )
81- prediction_method = getattr (outlier_detector , chosen_response_method )
80+ prediction_method = getattr (estimator , chosen_response_method )
8281 assert_array_equal (results [0 ], prediction_method (X ))
8382 assert results [1 ] is None
8483 if return_response_method_used :
@@ -417,6 +416,8 @@ def test_response_values_type_of_target_on_classes_no_warning():
417416 (IsolationForest (), "predict" , "multiclass" , (10 ,)),
418417 (DecisionTreeRegressor (), "predict" , "binary" , (10 ,)),
419418 (DecisionTreeRegressor (), "predict" , "multiclass" , (10 ,)),
419+ (KMeans (n_clusters = 2 , n_init = 1 ), "predict" , "binary" , (10 ,)),
420+ (KMeans (n_clusters = 2 , n_init = 1 ), "predict" , "multiclass" , (10 ,)),
420421 ],
421422)
422423def test_response_values_output_shape_ (
@@ -430,8 +431,8 @@ def test_response_values_output_shape_(
430431 - with response_method="predict", it is a 1d array of shape `(n_samples,)`;
431432 - otherwise, it is a 2d array of shape `(n_samples, n_classes)`;
432433 - for multilabel classification, it is a 2d array of shape `(n_samples, n_outputs)`;
433- - for outlier detection, it is a 1d array of shape `(n_samples,)`;
434- - for regression, it is a 1d array of shape `(n_samples,)`.
434+ - for outlier detection, regression and clustering,
435+ it is a 1d array of shape `(n_samples,)`.
435436 """
436437 X = np .random .RandomState (0 ).randn (10 , 2 )
437438 if target_type == "binary" :
0 commit comments