diff --git a/scoringrules/core/crps/_closed.py b/scoringrules/core/crps/_closed.py index 7edc13e..4fc7822 100644 --- a/scoringrules/core/crps/_closed.py +++ b/scoringrules/core/crps/_closed.py @@ -628,8 +628,12 @@ def mixnorm( s_X = B.sqrt(s[..., None] ** 2 + s[..., None, :] ** 2) w_X = w[..., None] * w[..., None, :] - A_y = m_y * (2 * _norm_cdf(m_y / s) - 1) + 2 * s * _norm_pdf(m_y / s) - A_X = m_X * (2 * _norm_cdf(m_X / s_X) - 1) + 2 * s_X * _norm_pdf(m_X / s_X) + A_y = m_y * (2 * _norm_cdf(m_y / s, backend=backend) - 1) + 2 * s * _norm_pdf( + m_y / s, backend=backend + ) + A_X = m_X * (2 * _norm_cdf(m_X / s_X, backend=backend) - 1) + 2 * s_X * _norm_pdf( + m_X / s_X, backend=backend + ) sc_1 = B.sum(w * A_y, axis=-1) sc_2 = B.sum(w_X * A_X, axis=(-1, -2))