Skip to content

Commit 9562a6a

Browse files
committed
fix bug in fair kernel scores
1 parent 0c5a595 commit 9562a6a

File tree

2 files changed

+8
-70
lines changed

2 files changed

+8
-70
lines changed

scoringrules/core/kernels/_approx.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ def ensemble_uv(
3535
e_2 = B.sum(
3636
gauss_kern_uv(fct[..., None], fct[..., None, :], backend=backend),
3737
axis=(-1, -2),
38-
) / (M**2)
38+
)
3939
e_3 = gauss_kern_uv(obs, obs)
4040

4141
if estimator == "nrg":
42-
out = e_1 - 0.5 * e_2 - 0.5 * e_3
42+
out = e_1 - 0.5 * e_2 / (M**2) - 0.5 * e_3
4343
elif estimator == "fair":
44-
out = e_1 - 0.5 * e_2 * (M / (M - 1)) - 0.5 * e_3
44+
e_2 -= B.sum(gauss_kern_uv(fct, fct, backend=backend), axis=-1)
45+
out = e_1 - 0.5 * e_2 / (M * (M - 1)) - 0.5 * e_3
4546

4647
return -out
4748

@@ -65,13 +66,14 @@ def ensemble_mv(
6566
e_2 = B.sum(
6667
gauss_kern_mv(B.expand_dims(fct, -3), B.expand_dims(fct, -2), backend=backend),
6768
axis=(-2, -1),
68-
) / (M**2)
69+
)
6970
e_3 = gauss_kern_mv(obs, obs)
7071

7172
if estimator == "nrg":
72-
out = e_1 - 0.5 * e_2 - 0.5 * e_3
73+
out = e_1 - 0.5 * e_2 / (M**2) - 0.5 * e_3
7374
elif estimator == "fair":
74-
out = e_1 - 0.5 * e_2 * (M / (M - 1)) - 0.5 * e_3
75+
e_2 -= B.sum(gauss_kern_mv(fct, fct, backend=backend), axis=-1)
76+
out = e_1 - 0.5 * e_2 / (M * (M - 1)) - 0.5 * e_3
7577

7678
return -out
7779

tests/test_kernels.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,6 @@ def test_gksuv(estimator, backend):
5050
expected = 0.2490516
5151
assert np.isclose(res, expected)
5252

53-
elif estimator == "fair":
54-
# test correctness
55-
obs, fct = 11.6, np.array([9.8, 8.7, 11.9, 12.1, 13.4])
56-
res = sr.gksuv_ensemble(obs, fct, estimator=estimator, backend=backend)
57-
expected = 0.2987752
58-
assert np.isclose(res, expected)
59-
60-
# test exceptions
61-
with pytest.raises(ValueError):
62-
est = "undefined_estimator"
63-
sr.gksuv_ensemble(obs, fct, estimator=est, backend=backend)
64-
6553

6654
@pytest.mark.parametrize("estimator", ESTIMATORS)
6755
@pytest.mark.parametrize("backend", BACKENDS)
@@ -95,21 +83,6 @@ def test_gksmv(estimator, backend):
9583
expected = 0.5868737
9684
assert np.isclose(res, expected)
9785

98-
elif estimator == "fair":
99-
# test correctness
100-
obs = np.array([11.6, -23.1])
101-
fct = np.array(
102-
[[9.8, 8.7, 11.9, 12.1, 13.4], [-24.8, -18.5, -29.9, -18.3, -21.0]]
103-
).transpose()
104-
res = sr.gksmv_ensemble(obs, fct, estimator=estimator, backend=backend)
105-
expected = 0.6120162
106-
assert np.isclose(res, expected)
107-
108-
# test exceptions
109-
with pytest.raises(ValueError):
110-
est = "undefined_estimator"
111-
sr.gksmv_ensemble(obs, fct, estimator=est, backend=backend)
112-
11386

11487
@pytest.mark.parametrize("estimator", ESTIMATORS)
11588
@pytest.mark.parametrize("backend", BACKENDS)
@@ -213,43 +186,6 @@ def v_func2(x):
213186
)
214187
np.testing.assert_allclose(res, 0.0089314, rtol=1e-6)
215188

216-
elif estimator == "fair":
217-
res = np.mean(
218-
np.float64(
219-
sr.twgksuv_ensemble(
220-
obs, fct, v_func=v_func1, estimator=estimator, backend=backend
221-
)
222-
)
223-
)
224-
np.testing.assert_allclose(res, 0.130842, rtol=1e-6)
225-
226-
res = np.mean(
227-
np.float64(
228-
sr.twgksuv_ensemble(
229-
obs, fct, a=-1.0, estimator=estimator, backend=backend
230-
)
231-
)
232-
)
233-
np.testing.assert_allclose(res, 0.130842, rtol=1e-6)
234-
235-
res = np.mean(
236-
np.float64(
237-
sr.twgksuv_ensemble(
238-
obs, fct, v_func=v_func2, estimator=estimator, backend=backend
239-
)
240-
)
241-
)
242-
np.testing.assert_allclose(res, 0.1283745, rtol=1e-6)
243-
244-
res = np.mean(
245-
np.float64(
246-
sr.twgksuv_ensemble(
247-
obs, fct, b=1.85, estimator=estimator, backend=backend
248-
)
249-
)
250-
)
251-
np.testing.assert_allclose(res, 0.1283745, rtol=1e-6)
252-
253189

254190
@pytest.mark.parametrize("backend", BACKENDS)
255191
def test_twgksmv(backend):

0 commit comments

Comments
 (0)