Skip to content

Commit 7f4603f

Browse files
update axis naming args (#95)
* change ens member axis to m_axis * change categorical axis to k_axis * update mixture components axis name
1 parent b922b16 commit 7f4603f

File tree

10 files changed

+93
-91
lines changed

10 files changed

+93
-91
lines changed

scoringrules/_brier.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def rps_score(
4646
obs: "ArrayLike",
4747
fct: "ArrayLike",
4848
/,
49-
axis: int = -1,
49+
k_axis: int = -1,
5050
*,
5151
backend: "Backend" = None,
5252
) -> "Array":
@@ -70,7 +70,7 @@ def rps_score(
7070
Array of 0's and 1's corresponding to unobserved and observed categories
7171
forecasts :
7272
Array of forecast probabilities for each category.
73-
axis: int
73+
k_axis: int
7474
The axis corresponding to the categories. Default is the last axis.
7575
backend : str
7676
The name of the backend used for computations. Defaults to 'numpy'.
@@ -84,8 +84,8 @@ def rps_score(
8484
B = backends.active if backend is None else backends[backend]
8585
fct = B.asarray(fct)
8686

87-
if axis != -1:
88-
fct = B.moveaxis(fct, axis, -1)
87+
if k_axis != -1:
88+
fct = B.moveaxis(fct, k_axis, -1)
8989

9090
return brier.rps_score(obs=obs, fct=fct, backend=backend)
9191

@@ -129,7 +129,7 @@ def rls_score(
129129
obs: "ArrayLike",
130130
fct: "ArrayLike",
131131
/,
132-
axis: int = -1,
132+
k_axis: int = -1,
133133
*,
134134
backend: "Backend" = None,
135135
) -> "Array":
@@ -153,6 +153,8 @@ def rls_score(
153153
Observed outcome, either 0 or 1.
154154
fct : array_like
155155
Forecasted probabilities between 0 and 1.
156+
k_axis: int
157+
The axis corresponding to the categories. Default is the last axis.
156158
backend : str
157159
The name of the backend used for computations. Defaults to 'numpy'.
158160
@@ -165,8 +167,8 @@ def rls_score(
165167
B = backends.active if backend is None else backends[backend]
166168
fct = B.asarray(fct)
167169

168-
if axis != -1:
169-
fct = B.moveaxis(fct, axis, -1)
170+
if k_axis != -1:
171+
fct = B.moveaxis(fct, k_axis, -1)
170172

171173
return brier.rls_score(obs=obs, fct=fct, backend=backend)
172174

scoringrules/_crps.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def crps_ensemble(
1111
obs: "ArrayLike",
1212
fct: "Array",
1313
/,
14-
axis: int = -1,
14+
m_axis: int = -1,
1515
*,
1616
sorted_ensemble: bool = False,
1717
estimator: str = "pwm",
@@ -48,7 +48,7 @@ def crps_ensemble(
4848
fct : array_like, shape (..., m)
4949
The predicted forecast ensemble, where the ensemble dimension is by default
5050
represented by the last axis.
51-
axis : int
51+
m_axis : int
5252
The axis corresponding to the ensemble. Default is the last axis.
5353
sorted_ensemble : bool
5454
Boolean indicating whether the ensemble members are already in ascending order.
@@ -99,16 +99,16 @@ def crps_ensemble(
9999
B = backends.active if backend is None else backends[backend]
100100
obs, fct = map(B.asarray, (obs, fct))
101101

102-
if axis != -1:
103-
fct = B.moveaxis(fct, axis, -1)
102+
if m_axis != -1:
103+
fct = B.moveaxis(fct, m_axis, -1)
104104

105105
if not sorted_ensemble and estimator not in [
106106
"nrg",
107107
"akr",
108108
"akr_circperm",
109109
"fair",
110110
]:
111-
fct = B.sort(fct, axis=-1)
111+
fct = B.sort(fct, m_axis=-1)
112112

113113
if backend == "numba":
114114
if estimator not in crps.estimator_gufuncs:
@@ -126,7 +126,7 @@ def twcrps_ensemble(
126126
fct: "Array",
127127
v_func: tp.Callable[["ArrayLike"], "ArrayLike"],
128128
/,
129-
axis: int = -1,
129+
m_axis: int = -1,
130130
*,
131131
estimator: str = "pwm",
132132
sorted_ensemble: bool = False,
@@ -156,7 +156,7 @@ def twcrps_ensemble(
156156
Chaining function used to emphasise particular outcomes. For example, a function that
157157
only considers values above a certain threshold :math:`t` by projecting forecasts and observations
158158
to :math:`[t, \inf)`.
159-
axis : int
159+
m_axis : int
160160
The axis corresponding to the ensemble. Default is the last axis.
161161
backend : str, optional
162162
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -200,7 +200,7 @@ def twcrps_ensemble(
200200
return crps_ensemble(
201201
obs,
202202
fct,
203-
axis=axis,
203+
m_axis=m_axis,
204204
sorted_ensemble=sorted_ensemble,
205205
estimator=estimator,
206206
backend=backend,
@@ -212,7 +212,7 @@ def owcrps_ensemble(
212212
fct: "Array",
213213
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
214214
/,
215-
axis: int = -1,
215+
m_axis: int = -1,
216216
*,
217217
estimator: tp.Literal["nrg"] = "nrg",
218218
backend: "Backend" = None,
@@ -245,7 +245,7 @@ def owcrps_ensemble(
245245
represented by the last axis.
246246
w_func : callable, array_like -> array_like
247247
Weight function used to emphasise particular outcomes.
248-
axis : int
248+
m_axis : int
249249
The axis corresponding to the ensemble. Default is the last axis.
250250
backend : str, optional
251251
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -288,8 +288,8 @@ def owcrps_ensemble(
288288
"Only the energy form of the estimator is available "
289289
"for the outcome-weighted CRPS."
290290
)
291-
if axis != -1:
292-
fct = B.moveaxis(fct, axis, -1)
291+
if m_axis != -1:
292+
fct = B.moveaxis(fct, m_axis, -1)
293293

294294
obs_weights, fct_weights = map(w_func, (obs, fct))
295295

@@ -309,7 +309,7 @@ def vrcrps_ensemble(
309309
fct: "Array",
310310
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
311311
/,
312-
axis: int = -1,
312+
m_axis: int = -1,
313313
*,
314314
estimator: tp.Literal["nrg"] = "nrg",
315315
backend: "Backend" = None,
@@ -340,7 +340,7 @@ def vrcrps_ensemble(
340340
represented by the last axis.
341341
w_func : callable, array_like -> array_like
342342
Weight function used to emphasise particular outcomes.
343-
axis : int
343+
m_axis : int
344344
The axis corresponding to the ensemble. Default is the last axis.
345345
backend : str, optional
346346
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -383,8 +383,8 @@ def vrcrps_ensemble(
383383
"Only the energy form of the estimator is available "
384384
"for the outcome-weighted CRPS."
385385
)
386-
if axis != -1:
387-
fct = B.moveaxis(fct, axis, -1)
386+
if m_axis != -1:
387+
fct = B.moveaxis(fct, m_axis, -1)
388388

389389
obs_weights, fct_weights = map(w_func, (obs, fct))
390390

@@ -404,7 +404,7 @@ def crps_quantile(
404404
fct: "Array",
405405
alpha: "Array",
406406
/,
407-
axis: int = -1,
407+
m_axis: int = -1,
408408
*,
409409
backend: "Backend" = None,
410410
) -> "Array":
@@ -435,7 +435,7 @@ def crps_quantile(
435435
represented by the last axis.
436436
alpha : array_like
437437
The percentile levels. We expect the quantile array to match the axis (see below) of the forecast array.
438-
axis : int
438+
m_axis : int
439439
The axis corresponding to the ensemble. Default is the last axis.
440440
backend : str, optional
441441
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -456,8 +456,8 @@ def crps_quantile(
456456
B = backends.active if backend is None else backends[backend]
457457
obs, fct, alpha = map(B.asarray, (obs, fct, alpha))
458458

459-
if axis != -1:
460-
fct = B.moveaxis(fct, axis, -1)
459+
if m_axis != -1:
460+
fct = B.moveaxis(fct, m_axis, -1)
461461

462462
if not fct.shape[-1] == alpha.shape[-1]:
463463
raise ValueError("Expected matching length of `fct` and `alpha` values.")
@@ -1839,7 +1839,7 @@ def crps_mixnorm(
18391839
s: "ArrayLike",
18401840
/,
18411841
w: "ArrayLike" = None,
1842-
axis: "ArrayLike" = -1,
1842+
m_axis: "ArrayLike" = -1,
18431843
*,
18441844
backend: "Backend" = None,
18451845
) -> "ArrayLike":
@@ -1863,7 +1863,7 @@ def crps_mixnorm(
18631863
Standard deviations of the component normal distributions.
18641864
w: array_like
18651865
Non-negative weights assigned to each component.
1866-
axis : int
1866+
m_axis : int
18671867
The axis corresponding to the mixture components. Default is the last axis.
18681868
backend : str, optional
18691869
The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -1890,15 +1890,15 @@ def crps_mixnorm(
18901890
obs, m, s = map(B.asarray, (obs, m, s))
18911891

18921892
if w is None:
1893-
M: int = m.shape[axis]
1893+
M: int = m.shape[m_axis]
18941894
w = B.zeros(m.shape) + 1 / M
18951895
else:
18961896
w = B.asarray(w)
18971897

1898-
if axis != -1:
1899-
m = B.moveaxis(m, axis, -1)
1900-
s = B.moveaxis(s, axis, -1)
1901-
w = B.moveaxis(w, axis, -1)
1898+
if m_axis != -1:
1899+
m = B.moveaxis(m, m_axis, -1)
1900+
s = B.moveaxis(s, m_axis, -1)
1901+
w = B.moveaxis(w, m_axis, -1)
19021902

19031903
return crps.mixnorm(obs, m, s, w, backend=backend)
19041904

scoringrules/_error_spread.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def error_spread_score(
1111
observations: "ArrayLike",
1212
forecasts: "Array",
1313
/,
14-
axis: int = -1,
14+
m_axis: int = -1,
1515
*,
1616
backend: "Backend" = None,
1717
) -> "Array":
@@ -24,7 +24,7 @@ def error_spread_score(
2424
forecasts: Array
2525
The predicted forecast ensemble, where the ensemble dimension is by default
2626
represented by the last axis.
27-
axis: int
27+
m_axis: int
2828
The axis corresponding to the ensemble. Default is the last axis.
2929
backend: str
3030
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
@@ -37,8 +37,8 @@ def error_spread_score(
3737
B = backends.active if backend is None else backends[backend]
3838
observations, forecasts = map(B.asarray, (observations, forecasts))
3939

40-
if axis != -1:
41-
forecasts = B.moveaxis(forecasts, axis, -1)
40+
if m_axis != -1:
41+
forecasts = B.moveaxis(forecasts, m_axis, -1)
4242

4343
if B.name == "numba":
4444
return error_spread._ess_gufunc(observations, forecasts)

scoringrules/_kernels.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def gksuv_ensemble(
1212
obs: "ArrayLike",
1313
fct: "Array",
1414
/,
15-
axis: int = -1,
15+
m_axis: int = -1,
1616
*,
1717
estimator: str = "nrg",
1818
backend: "Backend" = None,
@@ -40,7 +40,7 @@ def gksuv_ensemble(
4040
fct : array_like
4141
The predicted forecast ensemble, where the ensemble dimension is by default
4242
represented by the last axis.
43-
axis : int
43+
m_axis : int
4444
The axis corresponding to the ensemble. Default is the last axis.
4545
estimator : str
4646
Indicates the estimator to be used.
@@ -73,8 +73,8 @@ def gksuv_ensemble(
7373
f"Must be one of ['fair', 'nrg']"
7474
)
7575

76-
if axis != -1:
77-
fct = B.moveaxis(fct, axis, -1)
76+
if m_axis != -1:
77+
fct = B.moveaxis(fct, m_axis, -1)
7878

7979
if backend == "numba":
8080
return kernels.estimator_gufuncs[estimator](obs, fct)
@@ -87,7 +87,7 @@ def twgksuv_ensemble(
8787
fct: "Array",
8888
v_func: tp.Callable[["ArrayLike"], "ArrayLike"],
8989
/,
90-
axis: int = -1,
90+
m_axis: int = -1,
9191
*,
9292
estimator: str = "nrg",
9393
backend: "Backend" = None,
@@ -120,7 +120,7 @@ def twgksuv_ensemble(
120120
Chaining function used to emphasise particular outcomes. For example, a function that
121121
only considers values above a certain threshold :math:`t` by projecting forecasts and observations
122122
to :math:`[t, \inf)`.
123-
axis : int
123+
m_axis : int
124124
The axis corresponding to the ensemble. Default is the last axis.
125125
backend : str
126126
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
@@ -144,7 +144,7 @@ def twgksuv_ensemble(
144144
return gksuv_ensemble(
145145
obs,
146146
fct,
147-
axis=axis,
147+
m_axis=m_axis,
148148
estimator=estimator,
149149
backend=backend,
150150
)
@@ -155,7 +155,7 @@ def owgksuv_ensemble(
155155
fct: "Array",
156156
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
157157
/,
158-
axis: int = -1,
158+
m_axis: int = -1,
159159
*,
160160
backend: "Backend" = None,
161161
) -> "Array":
@@ -186,7 +186,7 @@ def owgksuv_ensemble(
186186
represented by the last axis.
187187
w_func : callable, array_like -> array_like
188188
Weight function used to emphasise particular outcomes.
189-
axis : int
189+
m_axis : int
190190
The axis corresponding to the ensemble. Default is the last axis.
191191
backend : str
192192
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
@@ -210,8 +210,8 @@ def owgksuv_ensemble(
210210

211211
obs, fct = map(B.asarray, (obs, fct))
212212

213-
if axis != -1:
214-
fct = B.moveaxis(fct, axis, -1)
213+
if m_axis != -1:
214+
fct = B.moveaxis(fct, m_axis, -1)
215215

216216
obs_weights, fct_weights = map(w_func, (obs, fct))
217217

@@ -229,7 +229,7 @@ def vrgksuv_ensemble(
229229
fct: "Array",
230230
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
231231
/,
232-
axis: int = -1,
232+
m_axis: int = -1,
233233
*,
234234
backend: "Backend" = None,
235235
) -> "Array":
@@ -259,7 +259,7 @@ def vrgksuv_ensemble(
259259
represented by the last axis.
260260
w_func : callable, array_like -> array_like
261261
Weight function used to emphasise particular outcomes.
262-
axis : int
262+
m_axis : int
263263
The axis corresponding to the ensemble. Default is the last axis.
264264
backend : str
265265
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
@@ -283,8 +283,8 @@ def vrgksuv_ensemble(
283283

284284
obs, fct = map(B.asarray, (obs, fct))
285285

286-
if axis != -1:
287-
fct = B.moveaxis(fct, axis, -1)
286+
if m_axis != -1:
287+
fct = B.moveaxis(fct, m_axis, -1)
288288

289289
obs_weights, fct_weights = map(w_func, (obs, fct))
290290

0 commit comments

Comments
 (0)