@@ -54,13 +54,6 @@ def test_crps_ensemble(estimator, backend):
5454 res = np .asarray (res )
5555 assert not np .any (res - 0.0 > 0.0001 )
5656
57- # test equivalence of different estimators
58- res_nrg = sr .crps_ensemble (obs , fct , estimator = "nrg" , backend = backend )
59- res_pwm = sr .crps_ensemble (obs , fct , estimator = "pwm" , backend = backend )
60- res_qd = sr .crps_ensemble (obs , fct , estimator = "qd" , backend = backend )
61- assert np .allclose (res_nrg , res_pwm )
62- assert np .allclose (res_nrg , res_qd )
63-
6457
6558@pytest .mark .parametrize ("backend" , BACKENDS )
6659def test_crps_ensemble_corr (backend ):
@@ -73,15 +66,23 @@ def test_crps_ensemble_corr(backend):
7366 res_nrg = sr .crps_ensemble (obs , fct , estimator = "nrg" , backend = backend )
7467 res_pwm = sr .crps_ensemble (obs , fct , estimator = "pwm" , backend = backend )
7568 res_qd = sr .crps_ensemble (obs , fct , estimator = "qd" , backend = backend )
76- assert np .allclose (res_nrg , res_pwm )
77- assert np .allclose (res_nrg , res_qd )
69+ if backend == "torch" :
70+ assert np .allclose (res_nrg , res_pwm , rtol = 1e-04 )
71+ assert np .allclose (res_nrg , res_qd , rtol = 1e-04 )
72+ else :
73+ assert np .allclose (res_nrg , res_pwm )
74+ assert np .allclose (res_nrg , res_qd )
7875
7976 w = np .abs (np .random .randn (N , ENSEMBLE_SIZE ) * sigma [..., None ])
8077 res_nrg = sr .crps_ensemble (obs , fct , ens_w = w , estimator = "nrg" , backend = backend )
8178 res_pwm = sr .crps_ensemble (obs , fct , ens_w = w , estimator = "pwm" , backend = backend )
8279 res_qd = sr .crps_ensemble (obs , fct , ens_w = w , estimator = "qd" , backend = backend )
83- assert np .allclose (res_nrg , res_pwm )
84- assert np .allclose (res_nrg , res_qd )
80+ if backend == "torch" :
81+ assert np .allclose (res_nrg , res_pwm , rtol = 1e-04 )
82+ assert np .allclose (res_nrg , res_qd , rtol = 1e-04 )
83+ else :
84+ assert np .allclose (res_nrg , res_pwm )
85+ assert np .allclose (res_nrg , res_qd )
8586
8687 # test correctness
8788 obs = - 0.6042506
0 commit comments