88
99
1010def crps_ensemble (
11- forecasts : "Array" ,
1211 observations : "ArrayLike" ,
12+ forecasts : "Array" ,
1313 / ,
1414 axis : int = - 1 ,
1515 * ,
@@ -21,11 +21,11 @@ def crps_ensemble(
2121
2222 Parameters
2323 ----------
24+ observations: ArrayLike
25+ The observed values.
2426 forecasts: ArrayLike
2527 The predicted forecast ensemble, where the ensemble dimension is by default
2628 represented by the last axis.
27- observations: ArrayLike
28- The observed values.
2929 axis: int
3030 The axis corresponding to the ensemble. Default is the last axis.
3131 sorted_ensemble: bool
@@ -47,7 +47,7 @@ def crps_ensemble(
4747 >>> crps.ensemble(pred, obs)
4848 """
4949 B = backends .active if backend is None else backends [backend ]
50- forecasts , observations = map (B .asarray , (forecasts , observations ))
50+ observations , forecasts = map (B .asarray , (observations , forecasts ))
5151
5252 if estimator not in crps .estimator_gufuncs :
5353 raise ValueError (
@@ -62,14 +62,14 @@ def crps_ensemble(
6262 forecasts = B .sort (forecasts , axis = - 1 )
6363
6464 if backend == "numba" :
65- return crps .estimator_gufuncs [estimator ](forecasts , observations )
65+ return crps .estimator_gufuncs [estimator ](observations , forecasts )
6666
67- return crps .ensemble (forecasts , observations , estimator , backend = backend )
67+ return crps .ensemble (observations , forecasts , estimator , backend = backend )
6868
6969
7070def twcrps_ensemble (
71- forecasts : "Array" ,
7271 observations : "ArrayLike" ,
72+ forecasts : "Array" ,
7373 v_func : tp .Callable [["ArrayLike" ], "ArrayLike" ],
7474 / ,
7575 axis : int = - 1 ,
@@ -91,11 +91,11 @@ def twcrps_ensemble(
9191
9292 Parameters
9393 ----------
94+ observations: ArrayLike
95+ The observed values.
9496 forecasts: ArrayLike
9597 The predicted forecast ensemble, where the ensemble dimension is by default
9698 represented by the last axis.
97- observations: ArrayLike
98- The observed values.
9999 v_func: tp.Callable
100100 Chaining function used to emphasise particular outcomes. For example, a function that
101101 only considers values above a certain threshold $t$ by projecting forecasts and observations
@@ -115,10 +115,10 @@ def twcrps_ensemble(
115115 >>> from scoringrules import crps
116116 >>> twcrps.ensemble(pred, obs)
117117 """
118- forecasts , observations = map (v_func , (forecasts , observations ))
118+ observations , forecasts = map (v_func , (observations , forecasts ))
119119 return crps_ensemble (
120- forecasts ,
121120 observations ,
121+ forecasts ,
122122 axis = axis ,
123123 sorted_ensemble = sorted_ensemble ,
124124 estimator = estimator ,
@@ -127,8 +127,8 @@ def twcrps_ensemble(
127127
128128
129129def owcrps_ensemble (
130- forecasts : "Array" ,
131130 observations : "ArrayLike" ,
131+ forecasts : "Array" ,
132132 w_func : tp .Callable [["ArrayLike" ], "ArrayLike" ],
133133 / ,
134134 axis : int = - 1 ,
@@ -149,11 +149,11 @@ def owcrps_ensemble(
149149
150150 Parameters
151151 ----------
152+ observations: ArrayLike
153+ The observed values.
152154 forecasts: ArrayLike
153155 The predicted forecast ensemble, where the ensemble dimension is by default
154156 represented by the last axis.
155- observations: ArrayLike
156- The observed values.
157157 w_func: tp.Callable
158158 Weight function used to emphasise particular outcomes.
159159 axis: int
@@ -181,24 +181,24 @@ def owcrps_ensemble(
181181 if axis != - 1 :
182182 forecasts = B .moveaxis (forecasts , axis , - 1 )
183183
184- fcts_weights , obs_weights = map (w_func , (forecasts , observations ))
184+ obs_weights , fct_weights = map (w_func , (observations , forecasts ))
185185
186186 if backend == "numba" :
187187 return crps .estimator_gufuncs ["ow" + estimator ](
188- forecasts , observations , fcts_weights , obs_weights
188+ observations , forecasts , obs_weights , fct_weights
189189 )
190190
191- forecasts , observations , fcts_weights , obs_weights = map (
192- B .asarray , (forecasts , observations , fcts_weights , obs_weights )
191+ observations , forecasts , obs_weights , fct_weights = map (
192+ B .asarray , (observations , forecasts , obs_weights , fct_weights )
193193 )
194194 return crps .ow_ensemble (
195- forecasts , observations , fcts_weights , obs_weights , backend = backend
195+ observations , forecasts , obs_weights , fct_weights , backend = backend
196196 )
197197
198198
199199def vrcrps_ensemble (
200- forecasts : "Array" ,
201200 observations : "ArrayLike" ,
201+ forecasts : "Array" ,
202202 w_func : tp .Callable [["ArrayLike" ], "ArrayLike" ],
203203 / ,
204204 axis : int = - 1 ,
@@ -224,11 +224,11 @@ def vrcrps_ensemble(
224224
225225 Parameters
226226 ----------
227+ observations: ArrayLike
228+ The observed values.
227229 forecasts: ArrayLike
228230 The predicted forecast ensemble, where the ensemble dimension is by default
229231 represented by the last axis.
230- observations: ArrayLike
231- The observed values.
232232 w_func: tp.Callable
233233 Weight function used to emphasise particular outcomes.
234234 axis: int
@@ -256,25 +256,25 @@ def vrcrps_ensemble(
256256 if axis != - 1 :
257257 forecasts = B .moveaxis (forecasts , axis , - 1 )
258258
259- fcts_weights , obs_weights = map (w_func , (forecasts , observations ))
259+ obs_weights , fct_weights = map (w_func , (observations , forecasts ))
260260
261261 if backend == "numba" :
262262 return crps .estimator_gufuncs ["vr" + estimator ](
263- forecasts , observations , fcts_weights , obs_weights
263+ observations , forecasts , obs_weights , fct_weights
264264 )
265265
266- forecasts , observations , fcts_weights , obs_weights = map (
267- B .asarray , (forecasts , observations , fcts_weights , obs_weights )
266+ observations , forecasts , obs_weights , fct_weights = map (
267+ B .asarray , (observations , forecasts , obs_weights , fct_weights )
268268 )
269269 return crps .vr_ensemble (
270- forecasts , observations , fcts_weights , obs_weights , backend = backend
270+ observations , forecasts , obs_weights , fct_weights , backend = backend
271271 )
272272
273273
274274def crps_normal (
275+ observation : "ArrayLike" ,
275276 mu : "ArrayLike" ,
276277 sigma : "ArrayLike" ,
277- observation : "ArrayLike" ,
278278 / ,
279279 * ,
280280 backend : "Backend" = None ,
@@ -291,12 +291,12 @@ def crps_normal(
291291
292292 Parameters
293293 ----------
294+ observations: ArrayLike
295+ The observed values.
294296 mu: ArrayLike
295297 Mean of the forecast normal distribution.
296298 sigma: ArrayLike
297299 Standard deviation of the forecast normal distribution.
298- observation: ArrayLike
299- The observed values.
300300
301301 Returns
302302 -------
@@ -308,13 +308,13 @@ def crps_normal(
308308 >>> from scoringrules import crps
309309 >>> crps.normal(0.1, 0.4, 0.0)
310310 """
311- return crps .normal (mu , sigma , observation , backend = backend )
311+ return crps .normal (observation , mu , sigma , backend = backend )
312312
313313
314314def crps_lognormal (
315+ observation : "ArrayLike" ,
315316 mulog : "ArrayLike" ,
316317 sigmalog : "ArrayLike" ,
317- observation : "ArrayLike" ,
318318 backend : "Backend" = None ,
319319) -> "ArrayLike" :
320320 r"""Compute the closed form of the CRPS for the lognormal distribution.
@@ -334,6 +334,8 @@ def crps_lognormal(
334334
335335 Parameters
336336 ----------
337+ observations: ArrayLike
338+ The observed values.
337339 mulog: ArrayLike
338340 Mean of the normal underlying distribution.
339341 sigmalog: ArrayLike
@@ -349,13 +351,13 @@ def crps_lognormal(
349351 >>> from scoringrules import crps
350352 >>> crps.lognormal(0.1, 0.4, 0.0)
351353 """
352- return crps .lognormal (mulog , sigmalog , observation , backend = backend )
354+ return crps .lognormal (observation , mulog , sigmalog , backend = backend )
353355
354356
355357def crps_logistic (
358+ observation : "ArrayLike" ,
356359 mu : "ArrayLike" ,
357360 sigma : "ArrayLike" ,
358- observation : "ArrayLike" ,
359361 / ,
360362 * ,
361363 backend : "Backend" = None ,
@@ -372,12 +374,12 @@ def crps_logistic(
372374
373375 Parameters
374376 ----------
377+ observations: ArrayLike
378+ Observed values.
375379 mu: ArrayLike
376380 Location parameter of the forecast logistic distribution.
377381 sigma: ArrayLike
378382 Scale parameter of the forecast logistic distribution.
379- observation: ArrayLike
380- Observed values.
381383
382384 Returns
383385 -------
@@ -389,7 +391,7 @@ def crps_logistic(
389391 >>> from scoringrules import crps
390392 >>> crps.logistic(0.1, 0.4, 0.0)
391393 """
392- return crps .logistic (mu , sigma , observation , backend = backend )
394+ return crps .logistic (observation , mu , sigma , backend = backend )
393395
394396
395397__all__ = [
0 commit comments