Skip to content

Commit c0069eb

Browse files
authored
lin reg: mention missing & superfluous predictors (#359)
1 parent 80856c7 commit c0069eb

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

Diff for: mesmer/stats/_linear_regression.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,15 @@ def predict(
8080
required_predictors = set(params.data_vars) - non_predictor_vars - exclude
8181
available_predictors = set(predictors.keys())
8282

83-
if required_predictors != available_predictors:
84-
raise ValueError("Missing or superfluous predictors.")
83+
if required_predictors - available_predictors:
84+
missing = sorted(required_predictors - available_predictors)
85+
missing = "', '".join(missing)
86+
raise ValueError(f"Missing predictors: '{missing}'")
87+
88+
if available_predictors - required_predictors:
89+
superfluous = sorted(available_predictors - required_predictors)
90+
superfluous = "', '".join(superfluous)
91+
raise ValueError(f"Superfluous predictors: '{superfluous}'")
8592

8693
if "intercept" in exclude:
8794
prediction = xr.zeros_like(params.intercept)

Diff for: tests/unit/test_linear_regression.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def LinearRegression_fit_wrapper(*args, **kwargs):
3636
# TEST LinearRegression class
3737

3838

39-
def test_LR_params():
39+
def test_lr_params():
4040

4141
lr = mesmer.stats.LinearRegression()
4242

@@ -80,20 +80,14 @@ def test_LR_params():
8080

8181

8282
@pytest.mark.parametrize("as_2D", [True, False])
83-
def test_LR_predict(as_2D):
83+
def test_lr_predict(as_2D):
8484
lr = mesmer.stats.LinearRegression()
8585

8686
params = xr.Dataset(
8787
data_vars={"intercept": ("x", [5]), "fit_intercept": True, "tas": ("x", [3])}
8888
)
8989
lr.params = params if as_2D else params.squeeze()
9090

91-
with pytest.raises(ValueError, match="Missing or superfluous predictors"):
92-
lr.predict({})
93-
94-
with pytest.raises(ValueError, match="Missing or superfluous predictors"):
95-
lr.predict({"tas": None, "something else": None})
96-
9791
tas = xr.DataArray([0, 1, 2], dims="time")
9892

9993
result = lr.predict({"tas": tas})
@@ -103,6 +97,32 @@ def test_LR_predict(as_2D):
10397
xr.testing.assert_equal(result, expected)
10498

10599

100+
def test_lr_predict_missing_superfluous():
101+
lr = mesmer.stats.LinearRegression()
102+
103+
params = xr.Dataset(
104+
data_vars={
105+
"intercept": ("x", [5]),
106+
"fit_intercept": True,
107+
"tas": ("x", [3]),
108+
"tas2": ("x", [1]),
109+
}
110+
)
111+
lr.params = params
112+
113+
with pytest.raises(ValueError, match="Missing predictors: 'tas', 'tas2'"):
114+
lr.predict({})
115+
116+
with pytest.raises(ValueError, match="Missing predictors: 'tas'"):
117+
lr.predict({"tas2": None})
118+
119+
with pytest.raises(ValueError, match="Superfluous predictors: 'something else'"):
120+
lr.predict({"tas": None, "tas2": None, "something else": None})
121+
122+
with pytest.raises(ValueError, match="Superfluous predictors: 'bar', 'foo'"):
123+
lr.predict({"tas": None, "tas2": None, "foo": None, "bar": None})
124+
125+
106126
@pytest.mark.parametrize("as_2D", [True, False])
107127
def test_lr_predict_exclude(as_2D):
108128
lr = mesmer.stats.LinearRegression()

0 commit comments

Comments
 (0)