@@ -36,7 +36,7 @@ def LinearRegression_fit_wrapper(*args, **kwargs):
36
36
# TEST LinearRegression class
37
37
38
38
39
- def test_LR_params ():
39
+ def test_lr_params ():
40
40
41
41
lr = mesmer .stats .LinearRegression ()
42
42
@@ -80,20 +80,14 @@ def test_LR_params():
80
80
81
81
82
82
@pytest .mark .parametrize ("as_2D" , [True , False ])
83
- def test_LR_predict (as_2D ):
83
+ def test_lr_predict (as_2D ):
84
84
lr = mesmer .stats .LinearRegression ()
85
85
86
86
params = xr .Dataset (
87
87
data_vars = {"intercept" : ("x" , [5 ]), "fit_intercept" : True , "tas" : ("x" , [3 ])}
88
88
)
89
89
lr .params = params if as_2D else params .squeeze ()
90
90
91
- with pytest .raises (ValueError , match = "Missing or superfluous predictors" ):
92
- lr .predict ({})
93
-
94
- with pytest .raises (ValueError , match = "Missing or superfluous predictors" ):
95
- lr .predict ({"tas" : None , "something else" : None })
96
-
97
91
tas = xr .DataArray ([0 , 1 , 2 ], dims = "time" )
98
92
99
93
result = lr .predict ({"tas" : tas })
@@ -103,6 +97,32 @@ def test_LR_predict(as_2D):
103
97
xr .testing .assert_equal (result , expected )
104
98
105
99
100
+ def test_lr_predict_missing_superfluous ():
101
+ lr = mesmer .stats .LinearRegression ()
102
+
103
+ params = xr .Dataset (
104
+ data_vars = {
105
+ "intercept" : ("x" , [5 ]),
106
+ "fit_intercept" : True ,
107
+ "tas" : ("x" , [3 ]),
108
+ "tas2" : ("x" , [1 ]),
109
+ }
110
+ )
111
+ lr .params = params
112
+
113
+ with pytest .raises (ValueError , match = "Missing predictors: 'tas', 'tas2'" ):
114
+ lr .predict ({})
115
+
116
+ with pytest .raises (ValueError , match = "Missing predictors: 'tas'" ):
117
+ lr .predict ({"tas2" : None })
118
+
119
+ with pytest .raises (ValueError , match = "Superfluous predictors: 'something else'" ):
120
+ lr .predict ({"tas" : None , "tas2" : None , "something else" : None })
121
+
122
+ with pytest .raises (ValueError , match = "Superfluous predictors: 'bar', 'foo'" ):
123
+ lr .predict ({"tas" : None , "tas2" : None , "foo" : None , "bar" : None })
124
+
125
+
106
126
@pytest .mark .parametrize ("as_2D" , [True , False ])
107
127
def test_lr_predict_exclude (as_2D ):
108
128
lr = mesmer .stats .LinearRegression ()
0 commit comments