@@ -11,7 +11,7 @@ def crps_ensemble(
1111 obs : "ArrayLike" ,
1212 fct : "Array" ,
1313 / ,
14- axis : int = - 1 ,
14+ m_axis : int = - 1 ,
1515 * ,
1616 sorted_ensemble : bool = False ,
1717 estimator : str = "pwm" ,
@@ -48,7 +48,7 @@ def crps_ensemble(
4848 fct : array_like, shape (..., m)
4949 The predicted forecast ensemble, where the ensemble dimension is by default
5050 represented by the last axis.
51- axis : int
51+ m_axis : int
5252 The axis corresponding to the ensemble. Default is the last axis.
5353 sorted_ensemble : bool
5454 Boolean indicating whether the ensemble members are already in ascending order.
@@ -99,16 +99,16 @@ def crps_ensemble(
9999 B = backends .active if backend is None else backends [backend ]
100100 obs , fct = map (B .asarray , (obs , fct ))
101101
102- if axis != - 1 :
103- fct = B .moveaxis (fct , axis , - 1 )
102+ if m_axis != - 1 :
103+ fct = B .moveaxis (fct , m_axis , - 1 )
104104
105105 if not sorted_ensemble and estimator not in [
106106 "nrg" ,
107107 "akr" ,
108108 "akr_circperm" ,
109109 "fair" ,
110110 ]:
111- fct = B .sort (fct , axis = - 1 )
111+ fct = B .sort (fct , m_axis = - 1 )
112112
113113 if backend == "numba" :
114114 if estimator not in crps .estimator_gufuncs :
@@ -126,7 +126,7 @@ def twcrps_ensemble(
126126 fct : "Array" ,
127127 v_func : tp .Callable [["ArrayLike" ], "ArrayLike" ],
128128 / ,
129- axis : int = - 1 ,
129+ m_axis : int = - 1 ,
130130 * ,
131131 estimator : str = "pwm" ,
132132 sorted_ensemble : bool = False ,
@@ -156,7 +156,7 @@ def twcrps_ensemble(
156156 Chaining function used to emphasise particular outcomes. For example, a function that
157157 only considers values above a certain threshold :math:`t` by projecting forecasts and observations
158158 to :math:`[t, \inf)`.
159- axis : int
159+ m_axis : int
160160 The axis corresponding to the ensemble. Default is the last axis.
161161 backend : str, optional
162162 The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -200,7 +200,7 @@ def twcrps_ensemble(
200200 return crps_ensemble (
201201 obs ,
202202 fct ,
203- axis = axis ,
203+ m_axis = m_axis ,
204204 sorted_ensemble = sorted_ensemble ,
205205 estimator = estimator ,
206206 backend = backend ,
@@ -212,7 +212,7 @@ def owcrps_ensemble(
212212 fct : "Array" ,
213213 w_func : tp .Callable [["ArrayLike" ], "ArrayLike" ],
214214 / ,
215- axis : int = - 1 ,
215+ m_axis : int = - 1 ,
216216 * ,
217217 estimator : tp .Literal ["nrg" ] = "nrg" ,
218218 backend : "Backend" = None ,
@@ -245,7 +245,7 @@ def owcrps_ensemble(
245245 represented by the last axis.
246246 w_func : callable, array_like -> array_like
247247 Weight function used to emphasise particular outcomes.
248- axis : int
248+ m_axis : int
249249 The axis corresponding to the ensemble. Default is the last axis.
250250 backend : str, optional
251251 The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -288,8 +288,8 @@ def owcrps_ensemble(
288288 "Only the energy form of the estimator is available "
289289 "for the outcome-weighted CRPS."
290290 )
291- if axis != - 1 :
292- fct = B .moveaxis (fct , axis , - 1 )
291+ if m_axis != - 1 :
292+ fct = B .moveaxis (fct , m_axis , - 1 )
293293
294294 obs_weights , fct_weights = map (w_func , (obs , fct ))
295295
@@ -309,7 +309,7 @@ def vrcrps_ensemble(
309309 fct : "Array" ,
310310 w_func : tp .Callable [["ArrayLike" ], "ArrayLike" ],
311311 / ,
312- axis : int = - 1 ,
312+ m_axis : int = - 1 ,
313313 * ,
314314 estimator : tp .Literal ["nrg" ] = "nrg" ,
315315 backend : "Backend" = None ,
@@ -340,7 +340,7 @@ def vrcrps_ensemble(
340340 represented by the last axis.
341341 w_func : callable, array_like -> array_like
342342 Weight function used to emphasise particular outcomes.
343- axis : int
343+ m_axis : int
344344 The axis corresponding to the ensemble. Default is the last axis.
345345 backend : str, optional
346346 The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -383,8 +383,8 @@ def vrcrps_ensemble(
383383 "Only the energy form of the estimator is available "
384384 "for the outcome-weighted CRPS."
385385 )
386- if axis != - 1 :
387- fct = B .moveaxis (fct , axis , - 1 )
386+ if m_axis != - 1 :
387+ fct = B .moveaxis (fct , m_axis , - 1 )
388388
389389 obs_weights , fct_weights = map (w_func , (obs , fct ))
390390
@@ -404,7 +404,7 @@ def crps_quantile(
404404 fct : "Array" ,
405405 alpha : "Array" ,
406406 / ,
407- axis : int = - 1 ,
407+ m_axis : int = - 1 ,
408408 * ,
409409 backend : "Backend" = None ,
410410) -> "Array" :
@@ -435,7 +435,7 @@ def crps_quantile(
435435 represented by the last axis.
436436 alpha : array_like
437437 The percentile levels. We expect the quantile array to match the axis (see below) of the forecast array.
438- axis : int
438+ m_axis : int
439439 The axis corresponding to the ensemble. Default is the last axis.
440440 backend : str, optional
441441 The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -456,8 +456,8 @@ def crps_quantile(
456456 B = backends .active if backend is None else backends [backend ]
457457 obs , fct , alpha = map (B .asarray , (obs , fct , alpha ))
458458
459- if axis != - 1 :
460- fct = B .moveaxis (fct , axis , - 1 )
459+ if m_axis != - 1 :
460+ fct = B .moveaxis (fct , m_axis , - 1 )
461461
462462 if not fct .shape [- 1 ] == alpha .shape [- 1 ]:
463463 raise ValueError ("Expected matching length of `fct` and `alpha` values." )
@@ -1839,7 +1839,7 @@ def crps_mixnorm(
18391839 s : "ArrayLike" ,
18401840 / ,
18411841 w : "ArrayLike" = None ,
1842- axis : "ArrayLike" = - 1 ,
1842+ m_axis : "ArrayLike" = - 1 ,
18431843 * ,
18441844 backend : "Backend" = None ,
18451845) -> "ArrayLike" :
@@ -1863,7 +1863,7 @@ def crps_mixnorm(
18631863 Standard deviations of the component normal distributions.
18641864 w: array_like
18651865 Non-negative weights assigned to each component.
1866- axis : int
1866+ m_axis : int
18671867 The axis corresponding to the mixture components. Default is the last axis.
18681868 backend : str, optional
18691869 The name of the backend used for computations. Defaults to ``numba`` if available, else ``numpy``.
@@ -1890,15 +1890,15 @@ def crps_mixnorm(
18901890 obs , m , s = map (B .asarray , (obs , m , s ))
18911891
18921892 if w is None :
1893- M : int = m .shape [axis ]
1893+ M : int = m .shape [m_axis ]
18941894 w = B .zeros (m .shape ) + 1 / M
18951895 else :
18961896 w = B .asarray (w )
18971897
1898- if axis != - 1 :
1899- m = B .moveaxis (m , axis , - 1 )
1900- s = B .moveaxis (s , axis , - 1 )
1901- w = B .moveaxis (w , axis , - 1 )
1898+ if m_axis != - 1 :
1899+ m = B .moveaxis (m , m_axis , - 1 )
1900+ s = B .moveaxis (s , m_axis , - 1 )
1901+ w = B .moveaxis (w , m_axis , - 1 )
19021902
19031903 return crps .mixnorm (obs , m , s , w , backend = backend )
19041904
0 commit comments