Skip to content

Commit 8c31568

Browse files
authored
Change order of positional arguments (obs first) (#18)
* change public api signatures * change core functions' signatures * change docs and readme files * change tests' args order * change order in energy score functions
1 parent 3d07128 commit 8c31568

29 files changed

+407
-437
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ import scoringrules as sr
3636
import numpy as np
3737

3838
obs = np.random.randn(100)
39-
fcts = obs[:,None] + np.random.randn(100, 21) * 0.1
40-
sr.crps_ensemble(fcts, obs)
39+
fct = obs[:,None] + np.random.randn(100, 21) * 0.1
40+
sr.crps_ensemble(obs, fct)
4141
```
4242

4343
## Metrics

docs/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ import scoringrules as sr
3232
import numpy as np
3333

3434
obs = np.random.randn(100)
35-
fcts = obs[:,None] + np.random.randn(100, 21) * 0.1
36-
sr.crps_ensemble(fcts, obs)
35+
fct = obs[:,None] + np.random.randn(100, 21) * 0.1
36+
sr.crps_ensemble(obs, fct)
3737
```
3838

3939
## Metrics

docs/user_guide.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@ sr.crps_normal(0.1, 1.2, 0.3)
1717

1818
# on arrays
1919
sr.brier_score(np.random.uniform(0, 1, 100), np.random.binomial(1, 0.5, 100))
20-
sr.crps_lognormal(np.random.randn(100), np.random.uniform(0.5, 1.5, 100), np.random.lognormal(0, 1, 100))
20+
sr.crps_lognormal(np.random.lognormal(0, 1, 100), np.random.randn(100), np.random.uniform(0.5, 1.5, 100))
2121

2222
# ensemble metrics
2323
obs = np.random.randn(100)
24-
fcts = obs[:,None] + np.random.randn(100, 21) * 0.1
24+
fct = obs[:,None] + np.random.randn(100, 21) * 0.1
2525

26-
sr.crps_ensemble(fcts, obs)
27-
sr.error_spread_score(fcts, obs)
26+
sr.crps_ensemble(obs, fct)
27+
sr.error_spread_score(obs, fct)
2828

2929
# multivariate ensemble metrics
3030
obs = np.random.randn(100,3)
31-
fcts = obs[:,None] + np.random.randn(100, 21, 3) * 0.1
31+
fct = obs[:,None] + np.random.randn(100, 21, 3) * 0.1
3232

33-
sr.energy_score(fcts, obs)
34-
sr.variogram_score(fcts, obs)
33+
sr.energy_score(obs, fct)
34+
sr.variogram_score(obs, fct)
3535
```
3636

3737
For the univariate ensemble metrics, the ensemble dimension is on the last axis unless you specify otherwise with the `axis` argument. For the multivariate ensemble metrics, the ensemble dimension and the variable dimension are on the second last and last axis respectively, unless specified otherwise with `m_axis` and `v_axis`.

scoringrules/_brier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88

99
def brier_score(
10-
forecasts: "ArrayLike",
1110
observations: "ArrayLike",
11+
forecasts: "ArrayLike",
1212
/,
1313
*,
1414
backend: "Backend" = None,
@@ -24,10 +24,10 @@ def brier_score(
2424
2525
Parameters
2626
----------
27-
forecasts : NDArray
28-
Forecasted probabilities between 0 and 1.
2927
observations: NDArray
3028
Observed outcome, either 0 or 1.
29+
forecasts : NDArray
30+
Forecasted probabilities between 0 and 1.
3131
backend: str
3232
The name of the backend used for computations. Defaults to 'numpy'.
3333

scoringrules/_crps.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
def crps_ensemble(
11-
forecasts: "Array",
1211
observations: "ArrayLike",
12+
forecasts: "Array",
1313
/,
1414
axis: int = -1,
1515
*,
@@ -21,11 +21,11 @@ def crps_ensemble(
2121
2222
Parameters
2323
----------
24+
observations: ArrayLike
25+
The observed values.
2426
forecasts: ArrayLike
2527
The predicted forecast ensemble, where the ensemble dimension is by default
2628
represented by the last axis.
27-
observations: ArrayLike
28-
The observed values.
2929
axis: int
3030
The axis corresponding to the ensemble. Default is the last axis.
3131
sorted_ensemble: bool
@@ -47,7 +47,7 @@ def crps_ensemble(
4747
>>> crps.ensemble(pred, obs)
4848
"""
4949
B = backends.active if backend is None else backends[backend]
50-
forecasts, observations = map(B.asarray, (forecasts, observations))
50+
observations, forecasts = map(B.asarray, (observations, forecasts))
5151

5252
if estimator not in crps.estimator_gufuncs:
5353
raise ValueError(
@@ -62,14 +62,14 @@ def crps_ensemble(
6262
forecasts = B.sort(forecasts, axis=-1)
6363

6464
if backend == "numba":
65-
return crps.estimator_gufuncs[estimator](forecasts, observations)
65+
return crps.estimator_gufuncs[estimator](observations, forecasts)
6666

67-
return crps.ensemble(forecasts, observations, estimator, backend=backend)
67+
return crps.ensemble(observations, forecasts, estimator, backend=backend)
6868

6969

7070
def twcrps_ensemble(
71-
forecasts: "Array",
7271
observations: "ArrayLike",
72+
forecasts: "Array",
7373
v_func: tp.Callable[["ArrayLike"], "ArrayLike"],
7474
/,
7575
axis: int = -1,
@@ -91,11 +91,11 @@ def twcrps_ensemble(
9191
9292
Parameters
9393
----------
94+
observations: ArrayLike
95+
The observed values.
9496
forecasts: ArrayLike
9597
The predicted forecast ensemble, where the ensemble dimension is by default
9698
represented by the last axis.
97-
observations: ArrayLike
98-
The observed values.
9999
v_func: tp.Callable
100100
Chaining function used to emphasise particular outcomes. For example, a function that
101101
only considers values above a certain threshold $t$ by projecting forecasts and observations
@@ -115,10 +115,10 @@ def twcrps_ensemble(
115115
>>> from scoringrules import crps
116116
>>> twcrps.ensemble(pred, obs)
117117
"""
118-
forecasts, observations = map(v_func, (forecasts, observations))
118+
observations, forecasts = map(v_func, (observations, forecasts))
119119
return crps_ensemble(
120-
forecasts,
121120
observations,
121+
forecasts,
122122
axis=axis,
123123
sorted_ensemble=sorted_ensemble,
124124
estimator=estimator,
@@ -127,8 +127,8 @@ def twcrps_ensemble(
127127

128128

129129
def owcrps_ensemble(
130-
forecasts: "Array",
131130
observations: "ArrayLike",
131+
forecasts: "Array",
132132
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
133133
/,
134134
axis: int = -1,
@@ -149,11 +149,11 @@ def owcrps_ensemble(
149149
150150
Parameters
151151
----------
152+
observations: ArrayLike
153+
The observed values.
152154
forecasts: ArrayLike
153155
The predicted forecast ensemble, where the ensemble dimension is by default
154156
represented by the last axis.
155-
observations: ArrayLike
156-
The observed values.
157157
w_func: tp.Callable
158158
Weight function used to emphasise particular outcomes.
159159
axis: int
@@ -181,24 +181,24 @@ def owcrps_ensemble(
181181
if axis != -1:
182182
forecasts = B.moveaxis(forecasts, axis, -1)
183183

184-
fcts_weights, obs_weights = map(w_func, (forecasts, observations))
184+
obs_weights, fct_weights = map(w_func, (observations, forecasts))
185185

186186
if backend == "numba":
187187
return crps.estimator_gufuncs["ow" + estimator](
188-
forecasts, observations, fcts_weights, obs_weights
188+
observations, forecasts, obs_weights, fct_weights
189189
)
190190

191-
forecasts, observations, fcts_weights, obs_weights = map(
192-
B.asarray, (forecasts, observations, fcts_weights, obs_weights)
191+
observations, forecasts, obs_weights, fct_weights = map(
192+
B.asarray, (observations, forecasts, obs_weights, fct_weights)
193193
)
194194
return crps.ow_ensemble(
195-
forecasts, observations, fcts_weights, obs_weights, backend=backend
195+
observations, forecasts, obs_weights, fct_weights, backend=backend
196196
)
197197

198198

199199
def vrcrps_ensemble(
200-
forecasts: "Array",
201200
observations: "ArrayLike",
201+
forecasts: "Array",
202202
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
203203
/,
204204
axis: int = -1,
@@ -224,11 +224,11 @@ def vrcrps_ensemble(
224224
225225
Parameters
226226
----------
227+
observations: ArrayLike
228+
The observed values.
227229
forecasts: ArrayLike
228230
The predicted forecast ensemble, where the ensemble dimension is by default
229231
represented by the last axis.
230-
observations: ArrayLike
231-
The observed values.
232232
w_func: tp.Callable
233233
Weight function used to emphasise particular outcomes.
234234
axis: int
@@ -256,25 +256,25 @@ def vrcrps_ensemble(
256256
if axis != -1:
257257
forecasts = B.moveaxis(forecasts, axis, -1)
258258

259-
fcts_weights, obs_weights = map(w_func, (forecasts, observations))
259+
obs_weights, fct_weights = map(w_func, (observations, forecasts))
260260

261261
if backend == "numba":
262262
return crps.estimator_gufuncs["vr" + estimator](
263-
forecasts, observations, fcts_weights, obs_weights
263+
observations, forecasts, obs_weights, fct_weights
264264
)
265265

266-
forecasts, observations, fcts_weights, obs_weights = map(
267-
B.asarray, (forecasts, observations, fcts_weights, obs_weights)
266+
observations, forecasts, obs_weights, fct_weights = map(
267+
B.asarray, (observations, forecasts, obs_weights, fct_weights)
268268
)
269269
return crps.vr_ensemble(
270-
forecasts, observations, fcts_weights, obs_weights, backend=backend
270+
observations, forecasts, obs_weights, fct_weights, backend=backend
271271
)
272272

273273

274274
def crps_normal(
275+
observation: "ArrayLike",
275276
mu: "ArrayLike",
276277
sigma: "ArrayLike",
277-
observation: "ArrayLike",
278278
/,
279279
*,
280280
backend: "Backend" = None,
@@ -291,12 +291,12 @@ def crps_normal(
291291
292292
Parameters
293293
----------
294+
observations: ArrayLike
295+
The observed values.
294296
mu: ArrayLike
295297
Mean of the forecast normal distribution.
296298
sigma: ArrayLike
297299
Standard deviation of the forecast normal distribution.
298-
observation: ArrayLike
299-
The observed values.
300300
301301
Returns
302302
-------
@@ -308,13 +308,13 @@ def crps_normal(
308308
>>> from scoringrules import crps
309309
>>> crps.normal(0.1, 0.4, 0.0)
310310
"""
311-
return crps.normal(mu, sigma, observation, backend=backend)
311+
return crps.normal(observation, mu, sigma, backend=backend)
312312

313313

314314
def crps_lognormal(
315+
observation: "ArrayLike",
315316
mulog: "ArrayLike",
316317
sigmalog: "ArrayLike",
317-
observation: "ArrayLike",
318318
backend: "Backend" = None,
319319
) -> "ArrayLike":
320320
r"""Compute the closed form of the CRPS for the lognormal distribution.
@@ -334,6 +334,8 @@ def crps_lognormal(
334334
335335
Parameters
336336
----------
337+
observations: ArrayLike
338+
The observed values.
337339
mulog: ArrayLike
338340
Mean of the normal underlying distribution.
339341
sigmalog: ArrayLike
@@ -349,13 +351,13 @@ def crps_lognormal(
349351
>>> from scoringrules import crps
350352
>>> crps.lognormal(0.1, 0.4, 0.0)
351353
"""
352-
return crps.lognormal(mulog, sigmalog, observation, backend=backend)
354+
return crps.lognormal(observation, mulog, sigmalog, backend=backend)
353355

354356

355357
def crps_logistic(
358+
observation: "ArrayLike",
356359
mu: "ArrayLike",
357360
sigma: "ArrayLike",
358-
observation: "ArrayLike",
359361
/,
360362
*,
361363
backend: "Backend" = None,
@@ -372,12 +374,12 @@ def crps_logistic(
372374
373375
Parameters
374376
----------
377+
observations: ArrayLike
378+
Observed values.
375379
mu: ArrayLike
376380
Location parameter of the forecast logistic distribution.
377381
sigma: ArrayLike
378382
Scale parameter of the forecast logistic distribution.
379-
observation: ArrayLike
380-
Observed values.
381383
382384
Returns
383385
-------
@@ -389,7 +391,7 @@ def crps_logistic(
389391
>>> from scoringrules import crps
390392
>>> crps.logistic(0.1, 0.4, 0.0)
391393
"""
392-
return crps.logistic(mu, sigma, observation, backend=backend)
394+
return crps.logistic(observation, mu, sigma, backend=backend)
393395

394396

395397
__all__ = [

0 commit comments

Comments
 (0)