Skip to content

Commit 0f98025

Browse files
committed
minor jax tolerance bug fix
1 parent fa4edce commit 0f98025

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

tests/test_crps.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,6 @@ def test_crps_ensemble(estimator, backend):
5454
res = np.asarray(res)
5555
assert not np.any(res - 0.0 > 0.0001)
5656

57-
# test equivalence of different estimators
58-
res_nrg = sr.crps_ensemble(obs, fct, estimator="nrg", backend=backend)
59-
res_pwm = sr.crps_ensemble(obs, fct, estimator="pwm", backend=backend)
60-
res_qd = sr.crps_ensemble(obs, fct, estimator="qd", backend=backend)
61-
assert np.allclose(res_nrg, res_pwm)
62-
assert np.allclose(res_nrg, res_qd)
63-
6457

6558
@pytest.mark.parametrize("backend", BACKENDS)
6659
def test_crps_ensemble_corr(backend):
@@ -73,15 +66,23 @@ def test_crps_ensemble_corr(backend):
7366
res_nrg = sr.crps_ensemble(obs, fct, estimator="nrg", backend=backend)
7467
res_pwm = sr.crps_ensemble(obs, fct, estimator="pwm", backend=backend)
7568
res_qd = sr.crps_ensemble(obs, fct, estimator="qd", backend=backend)
76-
assert np.allclose(res_nrg, res_pwm)
77-
assert np.allclose(res_nrg, res_qd)
69+
if backend == "torch":
70+
assert np.allclose(res_nrg, res_pwm, rtol=1e-04)
71+
assert np.allclose(res_nrg, res_qd, rtol=1e-04)
72+
else:
73+
assert np.allclose(res_nrg, res_pwm)
74+
assert np.allclose(res_nrg, res_qd)
7875

7976
w = np.abs(np.random.randn(N, ENSEMBLE_SIZE) * sigma[..., None])
8077
res_nrg = sr.crps_ensemble(obs, fct, ens_w=w, estimator="nrg", backend=backend)
8178
res_pwm = sr.crps_ensemble(obs, fct, ens_w=w, estimator="pwm", backend=backend)
8279
res_qd = sr.crps_ensemble(obs, fct, ens_w=w, estimator="qd", backend=backend)
83-
assert np.allclose(res_nrg, res_pwm)
84-
assert np.allclose(res_nrg, res_qd)
80+
if backend == "torch":
81+
assert np.allclose(res_nrg, res_pwm, rtol=1e-04)
82+
assert np.allclose(res_nrg, res_qd, rtol=1e-04)
83+
else:
84+
assert np.allclose(res_nrg, res_pwm)
85+
assert np.allclose(res_nrg, res_qd)
8586

8687
# test correctness
8788
obs = -0.6042506

0 commit comments

Comments
 (0)