Skip to content

Commit 67b8882

Browse files
committed
fix bugs in variogram scores
1 parent 9562a6a commit 67b8882

File tree

4 files changed

+16
-26
lines changed

4 files changed

+16
-26
lines changed

scoringrules/_variogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def vs_ensemble(
1515
m_axis: int = -2,
1616
v_axis: int = -1,
1717
*,
18-
p: float = 1.0,
18+
p: float = 0.5,
1919
estimator: str = "nrg",
2020
backend: "Backend" = None,
2121
) -> "Array":

scoringrules/core/variogram/_score.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def vs_ensemble(
2323
obs_diff = B.abs(B.expand_dims(obs, -2) - B.expand_dims(obs, -1)) ** p # (... D D)
2424

2525
if estimator == "nrg":
26-
vfct = B.sum(B.abs(fct_diff) ** p, axis=-3) / M # (... D D)
26+
vfct = B.sum(fct_diff, axis=-3) / M # (... D D)
2727
out = B.sum((obs_diff - vfct) ** 2, axis=(-2, -1)) # (...)
2828

2929
elif estimator == "fair":
@@ -52,29 +52,20 @@ def owvs_ensemble(
5252
) -> "Array":
5353
"""Compute the Outcome-Weighted Variogram Score for a multivariate finite ensemble."""
5454
B = backends.active if backend is None else backends[backend]
55-
M: int = fct.shape[-2]
56-
wbar = B.mean(fw, axis=-1)
57-
58-
fct_diff = B.expand_dims(fct, -2) - B.expand_dims(fct, -1) # (... M D D)
59-
fct_diff = B.abs(fct_diff) ** p # (... M D D)
55+
M = fct.shape[-2]
56+
wbar = B.sum(fw, -1) / M
6057

61-
obs_diff = B.expand_dims(obs, -2) - B.expand_dims(obs, -1) # (... D D)
62-
obs_diff = B.abs(obs_diff) ** p # (... D D)
63-
del obs, fct
64-
65-
E_1 = (fct_diff - B.expand_dims(obs_diff, -3)) ** 2 # (... M D D)
66-
E_1 = B.sum(E_1, axis=(-2, -1)) # (... M)
67-
E_1 = B.sum(E_1 * fw * B.expand_dims(ow, -1), axis=-1) / (M * wbar) # (...)
58+
fct_diff = (
59+
B.abs(B.expand_dims(fct, -2) - B.expand_dims(fct, -1)) ** p
60+
) # (... M D D)
61+
obs_diff = B.abs(B.expand_dims(obs, -2) - B.expand_dims(obs, -1)) ** p # (... D D)
6862

69-
fct_diff_spread = B.expand_dims(fct_diff, -3) - B.expand_dims(
70-
fct_diff, -4
71-
) # (... M M D D)
72-
fw_prod = B.expand_dims(fw, -2) * B.expand_dims(fw, -1) # (... M M)
73-
E_2 = B.sum(fct_diff_spread**2, axis=(-2, -1)) # (... M M)
74-
E_2 *= fw_prod * B.expand_dims(ow, (-2, -1)) # (... M M)
75-
E_2 = B.sum(E_2, axis=(-2, -1)) / (M**2 * wbar**2) # (...)
63+
vfct = B.sum(fct_diff * B.expand_dims(fw, (-2, -1)), axis=-3) / (
64+
M * B.expand_dims(wbar, (-2, -1))
65+
) # (... D D)
66+
out = B.sum(((obs_diff - vfct) ** 2), axis=(-2, -1)) * ow # (...)
7667

77-
return E_1 - 0.5 * E_2
68+
return out
7869

7970

8071
def vrvs_ensemble(

tests/test_variogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ def test_variogram_score_correctness(backend):
4949
np.testing.assert_allclose(res, 0.05083489, rtol=1e-5)
5050

5151
res = sr.vs_ensemble(obs, fct.T, p=1.0, backend=backend)
52-
np.testing.assert_allclose(res, 0.04856365, rtol=1e-5)
52+
np.testing.assert_allclose(res, 0.04856366, rtol=1e-5)

tests/test_wvariogram.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
45
import scoringrules as sr
56
from scoringrules.backend import backends
67

@@ -23,9 +24,7 @@ def test_owvs_vs_vs(backend):
2324
lambda x: backends[backend].mean(x) * 0.0 + 1.0,
2425
backend=backend,
2526
)
26-
np.testing.assert_allclose(
27-
res, resw, rtol=1e-3
28-
) # TODO: not sure why tolerance must be so high
27+
np.testing.assert_allclose(res, resw, rtol=1e-3)
2928

3029

3130
@pytest.mark.parametrize("backend", BACKENDS)

0 commit comments

Comments
 (0)