Skip to content

Commit 5cdd2e4

Browse files
committed
fix tolerance in crps jax tests
1 parent b7fa375 commit 5cdd2e4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_crps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_crps_ensemble_corr(backend):
6666
res_nrg = sr.crps_ensemble(obs, fct, estimator="nrg", backend=backend)
6767
res_pwm = sr.crps_ensemble(obs, fct, estimator="pwm", backend=backend)
6868
res_qd = sr.crps_ensemble(obs, fct, estimator="qd", backend=backend)
69-
if backend == "torch":
69+
if backend in ["torch", "jax"]:
7070
assert np.allclose(res_nrg, res_pwm, rtol=1e-03)
7171
assert np.allclose(res_nrg, res_qd, rtol=1e-03)
7272
else:
@@ -77,7 +77,7 @@ def test_crps_ensemble_corr(backend):
7777
res_nrg = sr.crps_ensemble(obs, fct, ens_w=w, estimator="nrg", backend=backend)
7878
res_pwm = sr.crps_ensemble(obs, fct, ens_w=w, estimator="pwm", backend=backend)
7979
res_qd = sr.crps_ensemble(obs, fct, ens_w=w, estimator="qd", backend=backend)
80-
if backend == "torch":
80+
if backend in ["torch", "jax"]:
8181
assert np.allclose(res_nrg, res_pwm, rtol=1e-03)
8282
assert np.allclose(res_nrg, res_qd, rtol=1e-03)
8383
else:

0 commit comments

Comments
 (0)