Skip to content

Commit 487ff1d

Browse files
authored
bugfix numba gufuncs (#85)
1 parent 6e39f5a commit 487ff1d

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

scoringrules/_crps.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ def crps_ensemble(
4949
B = backends.active if backend is None else backends[backend]
5050
observations, forecasts = map(B.asarray, (observations, forecasts))
5151

52-
if estimator not in crps.estimator_gufuncs:
53-
raise ValueError(
54-
f"{estimator} is not a valid estimator. "
55-
f"Must be one of {crps.estimator_gufuncs.keys()}"
56-
)
57-
5852
if axis != -1:
5953
forecasts = B.moveaxis(forecasts, axis, -1)
6054

@@ -67,6 +61,11 @@ def crps_ensemble(
6761
forecasts = B.sort(forecasts, axis=-1)
6862

6963
if backend == "numba":
64+
if estimator not in crps.estimator_gufuncs:
65+
raise ValueError(
66+
f"{estimator} is not a valid estimator. "
67+
f"Must be one of {crps.estimator_gufuncs.keys()}"
68+
)
7069
return crps.estimator_gufuncs[estimator](observations, forecasts)
7170

7271
return crps.ensemble(observations, forecasts, estimator, backend=backend)

0 commit comments

Comments
 (0)