Skip to content

Commit 02afe44

Browse files
authored
Add error spread score (#10)
* added ess * docstrings * updated docstrings * wip * update docs
1 parent 5d599a8 commit 02afe44

File tree

16 files changed

+177
-22
lines changed

16 files changed

+177
-22
lines changed

docs/api/error_spread.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Error Spread Score
2+
3+
::: scoringrules.error_spread_score

docs/api/logarithmic.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Logarithmic Score
2+
3+
::: scoringrules.logs_normal

docs/stylesheets/extra.css

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,9 @@ h6.heading {
4343
border-radius: 2pt;
4444
padding: 0pt 5pt 2pt 5pt;
4545
}
46+
47+
48+
/* Maximum space for text block */
49+
.md-grid {
50+
max-width: 90%; /* or 100%, if you want to stretch to full-width */
51+
}

mkdocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ nav:
1313
- Analytical formulations: api/crps/analytical.md
1414
- Weighted versions: api/crps/weighted.md
1515
- Logarithmic Score: api/logarithmic.md
16+
- Error Spread Score: api/error_spread.md
1617
- Energy Score:
1718
- api/energy/index.md
1819
- Ensemble-based estimation: api/energy/ensemble.md

scoringrules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
twenergy_score,
1717
vrenergy_score,
1818
)
19+
from scoringrules._error_spread import error_spread_score
1920
from scoringrules._logs import logs_normal
2021
from scoringrules._variogram import (
2122
owvariogram_score,
@@ -40,6 +41,7 @@
4041
"vrcrps_ensemble",
4142
"logs_normal",
4243
"brier_score",
44+
"error_spread_score",
4345
"energy_score",
4446
"owenergy_score",
4547
"twenergy_score",

scoringrules/_crps.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def twcrps_ensemble(
100100
Chaining function used to emphasise particular outcomes. For example, a function that
101101
only considers values above a certain threshold $t$ by projecting forecasts and observations
102102
to $\[t, \inf)$.
103-
v_funcargs: tuple
104-
Additional arguments to the chaining function.
105103
axis: int
106104
The axis corresponding to the ensemble. Default is the last axis.
107105
backend: str
@@ -158,8 +156,6 @@ def owcrps_ensemble(
158156
The observed values.
159157
w_func: tp.Callable
160158
Weight function used to emphasise particular outcomes.
161-
w_funcargs: tuple
162-
Additional arguments to the weight function.
163159
axis: int
164160
The axis corresponding to the ensemble. Default is the last axis.
165161
backend: str
@@ -235,8 +231,6 @@ def vrcrps_ensemble(
235231
The observed values.
236232
w_func: tp.Callable
237233
Weight function used to emphasise particular outcomes.
238-
w_funcargs: tuple
239-
Additional arguments to the weight function.
240234
axis: int
241235
The axis corresponding to the ensemble. Default is the last axis.
242236
backend: str

scoringrules/_energy.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,6 @@ def owenergy_score(
143143
The observed values, where the variables dimension is by default the last axis.
144144
w_func: tp.Callable
145145
Weight function used to emphasise particular outcomes.
146-
w_funcargs: tuple
147-
Additional arguments to the weight function.
148146
m_axis: int
149147
The axis corresponding to the ensemble dimension. Defaults to -2.
150148
v_axis: int or tuple(int)
@@ -211,8 +209,6 @@ def vrenergy_score(
211209
The observed values, where the variables dimension is by default the last axis.
212210
w_func: tp.Callable
213211
Weight function used to emphasise particular outcomes.
214-
w_funcargs: tuple
215-
Additional arguments to the weight function.
216212
m_axis: int
217213
The axis corresponding to the ensemble dimension. Defaults to -2.
218214
v_axis: int or tuple(int)

scoringrules/_error_spread.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import typing as tp
2+
3+
from scoringrules.backend import backends
4+
from scoringrules.core import error_spread
5+
6+
if tp.TYPE_CHECKING:
7+
from scoringrules.core.typing import Array, ArrayLike, Backend
8+
9+
10+
def error_spread_score(
11+
forecasts: "Array",
12+
observations: "ArrayLike",
13+
/,
14+
axis: int = -1,
15+
*,
16+
backend: "Backend" = None,
17+
) -> "Array":
18+
r"""Compute the error-spread score (ESS) for a finite ensemble.
19+
20+
The error spread score [(Christensen et al., 2015)](https://doi.org/10.1002/qj.2375) is given by:
21+
22+
$$ESS = \left(s^2 - e^2 - e \cdot s \cdot g\right)^2$$
23+
24+
where the mean $m$, variance $s^2$, and skewness $g$ of the ensemble forecast of size $F$ are computed as follows:
25+
26+
$$m = \frac{1}{F} \sum_{f=1}^{F} X_f, \quad s^2 = \frac{1}{F-1} \sum_{f=1}^{F} (X_f - m)^2, \quad g = \frac{F}{(F-1)(F-2)} \sum_{f=1}^{F} \left(\frac{X_f - m}{s}\right)^3$$
27+
28+
The error in the ensemble mean $e$ is calculated as $e = m - y$, where $y$ is the observed value.
29+
30+
Parameters
31+
----------
32+
forecasts: Array
33+
The predicted forecast ensemble, where the ensemble dimension is by default
34+
represented by the last axis.
35+
observations: ArrayLike
36+
The observed values.
37+
axis: int
38+
The axis corresponding to the ensemble. Default is the last axis.
39+
backend: str
40+
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.
41+
42+
Returns
43+
-------
44+
- Array
45+
An array of error spread scores for each ensemble forecast, which should be averaged to get meaningful values.
46+
"""
47+
B = backends.active if backend is None else backends[backend]
48+
forecasts, observations = map(B.asarray, (forecasts, observations))
49+
50+
if axis != -1:
51+
forecasts = B.moveaxis(forecasts, axis, -1)
52+
53+
if B.name == "numba":
54+
return error_spread._ess_gufunc(forecasts, observations)
55+
56+
return error_spread.ess(forecasts, observations, backend=backend)
57+
58+
59+
60+
# \[
61+
# ESS = \left(s^2 - e^2 - e \cdot s \cdot g\right)^2
62+
# \]

scoringrules/_variogram.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ def twvariogram_score(
9494
The order of the Variogram Score. Typical values are 0.5, 1.0 or 2.0. Defaults to 1.0.
9595
v_func: tp.Callable
9696
Chaining function used to emphasise particular outcomes.
97-
v_funcargs: tuple
98-
Additional arguments to the chaining function.
9997
m_axis: int
10098
The axis corresponding to the ensemble dimension. Defaults to -2.
10199
v_axis: int
@@ -152,8 +150,6 @@ def owvariogram_score(
152150
The order of the Variogram Score. Typical values are 0.5, 1.0 or 2.0. Defaults to 1.0.
153151
w_func: tp.Callable
154152
Weight function used to emphasise particular outcomes.
155-
w_funcargs: tuple
156-
Additional arguments to the weight function.
157153
m_axis: int
158154
The axis corresponding to the ensemble dimension. Defaults to -2.
159155
v_axis: int
@@ -223,8 +219,6 @@ def vrvariogram_score(
223219
The order of the Variogram Score. Typical values are 0.5, 1.0 or 2.0. Defaults to 1.0.
224220
w_func: tp.Callable
225221
Weight function used to emphasise particular outcomes.
226-
w_funcargs: tuple
227-
Additional arguments to the weight function.
228222
m_axis: int
229223
The axis corresponding to the ensemble dimension. Defaults to -2.
230224
v_axis: int

scoringrules/backend/registry.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import sys
21
import typing as tp
32
from importlib.util import find_spec
43

@@ -20,7 +19,7 @@
2019
}
2120

2221
try:
23-
import numba # type: ignore
22+
import numba # noqa: F401
2423

2524
_NUMBA_IMPORTED = True
2625
except ImportError:
@@ -31,8 +30,9 @@ class BackendsRegistry(dict[str, ArrayBackend]):
3130
"""A dict-like container of registered backends."""
3231

3332
def __init__(self):
34-
for backend in self.available_backends:
35-
self.register_backend(backend)
33+
self.register_backend("numpy")
34+
if _NUMBA_IMPORTED:
35+
self.register_backend("numba")
3636

3737
self._active = "numba" if _NUMBA_IMPORTED else "numpy"
3838

0 commit comments

Comments
 (0)