Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scoringrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
twes_ensemble,
vres_ensemble,
)
from scoringrules._dss import dssuv_ensemble, dssmv_ensemble
from scoringrules._error_spread import error_spread_score
from scoringrules._interval import interval_score, weighted_interval_score
from scoringrules._logs import (
Expand Down Expand Up @@ -196,6 +197,8 @@ def wrapper(*args, **kwargs):
"rps_score",
"log_score",
"rls_score",
"dssuv_ensemble",
"dssmv_ensemble",
"error_spread_score",
"es_ensemble",
"owes_ensemble",
Expand Down
108 changes: 108 additions & 0 deletions scoringrules/_dss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import typing as tp

from scoringrules.backend import backends
from scoringrules.core import dss
from scoringrules.core.utils import multivariate_array_check

if tp.TYPE_CHECKING:
from scoringrules.core.typing import Array, Backend


def dssuv_ensemble(
obs: "Array",
fct: "Array",
/,
m_axis: int = -1,
*,
bias: bool = False,
backend: "Backend" = None,
) -> "Array":
r"""Compute the Dawid-Sebastiani-Score for a finite univariate ensemble.

The Dawid-Sebastiani Score for an ensemble forecast is defined as

.. math::
\text{DSS}(F_{ens}, y)= \frac{(y - \bar{x)^2}{\sigma^2} + 2 \log \sigma

where :math:`\bar{x}` and :math:`\sigma` are the mean and standard deviation of the ensemble members.

Parameters
----------
obs : array_like
The observed values.
fct : array_like, shape (..., m)
The predicted forecast ensemble, where the ensemble dimension is by default
represented by the last axis.
m_axis : int
The axis corresponding to the ensemble. Default is the last axis.
bias : bool
Logical specifying whether the biased or unbiased estimator of the standard deviation
should be used to calculate the score. Default is the unbiased estimator (`bias=False`).
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Returns
-------
score: Array
The computed Dawid-Sebastiani Score.
"""
B = backends.active if backend is None else backends[backend]
obs, fct = map(B.asarray, (obs, fct))

if m_axis != -1:
fct = B.moveaxis(fct, m_axis, -1)

if backend == "numba":
return dss._dss_uv_gufunc(obs, fct, bias)

return dss.ds_score_uv(obs, fct, bias, backend=backend)


def dssmv_ensemble(
obs: "Array",
fct: "Array",
/,
m_axis: int = -2,
v_axis: int = -1,
*,
bias: bool = False,
backend: "Backend" = None,
) -> "Array":
r"""Compute the Dawid-Sebastiani-Score for a finite multivariate ensemble.

The Dawid-Sebastiani Score for an ensemble forecast is defined as

.. math::
\text{DSS}(F_{ens}, \mathbf{y})= (\mathbf{y} - \bar{mathbf{x}})^{\top} \Sigma^-1 (\mathbf{y} - \bar{mathbf{x}}) + \log \det(\Sigma)

where :math:`\bar{mathbf{x}}` is the mean of the ensemble members (along each dimension),
and :math:`\Sigma` is the sample covariance matrix estimated from the ensemble members.

Parameters
----------
obs : array_like
The observed values, where the variables dimension is by default the last axis.
fct : array_like
The predicted forecast ensemble, where the ensemble dimension is by default
represented by the second last axis and the variables dimension by the last axis.
m_axis : int
The axis corresponding to the ensemble dimension. Defaults to -2.
v_axis : int or tuple of int
The axis corresponding to the variables dimension. Defaults to -1.
bias : bool
Logical specifying whether the biased or unbiased estimator of the covariance matrix
should be used to calculate the score. Default is the unbiased estimator (`bias=False`).
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Returns
-------
score: Array
The computed Dawid-Sebastiani Score.
"""
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)

if backend == "numba":
return dss._dss_mv_gufunc(obs, fct, bias)

return dss.ds_score_mv(obs, fct, bias, backend=backend)
17 changes: 17 additions & 0 deletions scoringrules/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def std(
/,
*,
axis: int | tuple[int, ...] | None = None,
bias: bool = False,
keepdims: bool = False,
) -> "Array":
"""Calculate the standard deviation of the input array ``x``."""
Expand Down Expand Up @@ -297,3 +298,19 @@ def size(self, x: "Array") -> int:
@abc.abstractmethod
def indices(self, x: "Array") -> int:
"""Return an array representing the indices of a grid."""

@abc.abstractmethod
def inv(self, x: "Array") -> "Array":
"""Return the inverse of a matrix."""

@abc.abstractmethod
def cov(self, x: "Array", rowvar: bool, bias: bool) -> "Array":
"""Return the covariance matrix from a sample."""

@abc.abstractmethod
def det(self, x: "Array") -> "Array":
"""Return the determinant of a matrix."""

@abc.abstractmethod
def reshape(self, x: "Array", shape: int | tuple[int, ...]) -> "Array":
"""Reshape an array to a new ``shape``."""
16 changes: 15 additions & 1 deletion scoringrules/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def std(
/,
*,
axis: int | tuple[int, ...] | None = None,
bias: bool = False,
keepdims: bool = False,
) -> "Array":
return jnp.std(x, ddof=1, axis=axis, keepdims=keepdims)
ddof = 0 if bias else 1
return jnp.std(x, ddof=ddof, axis=axis, keepdims=keepdims)

def quantile(
self,
Expand Down Expand Up @@ -269,6 +271,18 @@ def size(self, x: "Array") -> int:
def indices(self, dimensions: tuple) -> "Array":
return jnp.indices(dimensions)

def inv(self, x: "Array") -> "Array":
return jnp.linalg.inv(x)

def cov(self, x: "Array", rowvar: bool = True, bias: bool = False) -> "Array":
return jnp.cov(x, rowvar=rowvar, bias=bias)

def det(self, x: "Array") -> "Array":
return jnp.linalg.det(x)

def reshape(self, x: "Array", shape: int | tuple[int, ...]) -> "Array":
return jnp.reshape(x, shape)


if __name__ == "__main__":
B = JaxBackend()
Expand Down
16 changes: 15 additions & 1 deletion scoringrules/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def std(
/,
*,
axis: int | tuple[int, ...] | None = None,
bias: bool = False,
keepdims: bool = False,
) -> "NDArray":
return np.std(x, ddof=1, axis=axis, keepdims=keepdims)
ddof = 0 if bias else 1
return np.std(x, ddof=ddof, axis=axis, keepdims=keepdims)

def quantile(
self,
Expand Down Expand Up @@ -265,6 +267,18 @@ def size(self, x: "NDArray") -> int:
def indices(self, dimensions: tuple) -> "NDArray":
return np.indices(dimensions)

def inv(self, x: "NDArray") -> "NDArray":
return np.linalg.inv(x)

def cov(self, x: "NDArray", rowvar: bool = True, bias: bool = False) -> "NDArray":
return np.cov(x, rowvar=rowvar, bias=bias)

def det(self, x: "NDArray") -> "NDArray":
return np.linalg.det(x)

def reshape(self, x: "NDArray", shape: int | tuple[int, ...]) -> "NDArray":
return np.reshape(x, shape)


class NumbaBackend(NumpyBackend):
"""Numba backend."""
Expand Down
22 changes: 21 additions & 1 deletion scoringrules/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ def std(
/,
*,
axis: int | tuple[int, ...] | None = None,
bias: bool = False,
keepdims: bool = False,
) -> "Tensor":
n = x.shape.num_elements() if axis is None else x.shape[axis]
resc = self.sqrt(n / (n - 1))
if not bias:
resc = self.sqrt(n / (n - 1))
return tf.math.reduce_std(x, axis=axis, keepdims=keepdims) * resc

def quantile(
Expand Down Expand Up @@ -304,6 +306,24 @@ def indices(self, dimensions: tuple) -> "Tensor":
indices = tf.stack(index_grids)
return indices

def inv(self, x: "Tensor") -> "Tensor":
return tf.linalg.inv(x)

def cov(self, x: "Tensor", rowvar: bool = True, bias: bool = False) -> "Tensor":
if not rowvar:
x = tf.transpose(x)
x = x - tf.reduce_mean(x, axis=1, keepdims=True)
correction = tf.cast(tf.shape(x)[1], x.dtype) - 1.0
if bias:
correction += 1.0
return tf.matmul(x, x, transpose_b=True) / correction

def det(self, x: "Tensor") -> "Tensor":
return tf.linalg.det(x)

def reshape(self, x: "Tensor", shape: int | tuple[int, ...]) -> "Tensor":
return tf.reshape(x, shape)


if __name__ == "__main__":
B = TensorflowBackend()
Expand Down
19 changes: 18 additions & 1 deletion scoringrules/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def std(
/,
*,
axis: int | tuple[int, ...] | None = None,
bias: bool = False,
keepdims: bool = False,
) -> "Tensor":
return torch.std(x, correction=1, axis=axis, keepdim=keepdims)
correction = 0 if bias else 1
return torch.std(x, correction=correction, axis=axis, keepdim=keepdims)

def quantile(
self,
Expand Down Expand Up @@ -286,3 +288,18 @@ def indices(self, dimensions: tuple) -> "Tensor":
index_grids = torch.meshgrid(*ranges, indexing="ij")
indices = torch.stack(index_grids)
return indices

def inv(self, x: "Tensor") -> "Tensor":
return torch.linalg.inv(x)

def cov(self, x: "Tensor", rowvar: bool = True, bias: bool = False) -> "Tensor":
correction = 0 if bias else 1
if not rowvar:
x = torch.transpose(x, -2, -1)
return torch.cov(x, correction=correction)

def det(self, x: "Tensor") -> "Tensor":
return torch.linalg.det(x)

def reshape(self, x: "Tensor", shape: int | tuple[int, ...]) -> "Tensor":
return torch.reshape(x, shape)
9 changes: 9 additions & 0 deletions scoringrules/core/dss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
try:
from ._gufuncs import _dss_mv_gufunc, _dss_uv_gufunc
except ImportError:
_dss_uv_gufunc = None
_dss_mv_gufunc = None

from ._score import ds_score_uv, ds_score_mv

__all__ = ["ds_score_uv", "_dss_uv_gufunc", "ds_score_mv", "_dss_mv_gufunc"]
31 changes: 31 additions & 0 deletions scoringrules/core/dss/_gufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
from numba import guvectorize

from scoringrules.core.utils import lazy_gufunc_wrapper_uv, lazy_gufunc_wrapper_mv


@lazy_gufunc_wrapper_uv
@guvectorize("(),(n),()->()")
def _dss_uv_gufunc(obs: np.ndarray, fct: np.ndarray, bias: bool, out: np.ndarray):
ens_mean = np.mean(fct)
fact = len(fct) - 1.0
if bias:
fact += 1.0
var = np.sum((fct - ens_mean) ** 2) / fact
sig = np.sqrt(var)
log_sig = 2 * np.log(sig)
bias_precision = ((obs - ens_mean) / sig) ** 2
out[0] = bias_precision + log_sig


@lazy_gufunc_wrapper_mv
@guvectorize("(d),(m,d),()->()")
def _dss_mv_gufunc(obs: np.ndarray, fct: np.ndarray, bias: bool, out: np.ndarray):
M = fct.shape[0]
ens_mean = np.sum(fct, axis=0) / M
obs_cent = obs - ens_mean
cov = np.cov(fct, rowvar=False, bias=bias)
prec = np.linalg.inv(cov)
log_det = np.log(np.linalg.det(cov))
bias_precision = np.transpose(obs_cent) @ prec @ obs_cent
out[0] = bias_precision + log_det
64 changes: 64 additions & 0 deletions scoringrules/core/dss/_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import typing as tp

from scoringrules.backend import backends

if tp.TYPE_CHECKING:
from scoringrules.core.typing import Array, Backend


def ds_score_uv(
obs: "Array",
fct: "Array",
bias: bool = False,
backend: "Backend" = None,
) -> "Array":
"""Compute the Dawid Sebastiani Score for a univariate finite ensemble."""
B = backends.active if backend is None else backends[backend]
ens_mean = B.mean(fct, axis=-1)
sig = B.std(fct, axis=-1, bias=bias)
bias_precision = ((obs - ens_mean) / sig) ** 2
log_sig = 2 * B.log(sig)
return bias_precision + log_sig


def ds_score_mv(
obs: "Array", # (... D)
fct: "Array", # (... M D)
bias: bool = False,
backend: "Backend" = None,
) -> "Array":
"""Compute the Dawid Sebastiani Score for a multivariate finite ensemble."""
B = backends.active if backend is None else backends[backend]

batch_shape = fct.shape[:-2]
M, D = fct.shape[-2:]

fct_flat = B.reshape(fct, (-1, M, D)) # (... M D)
obs_flat = B.reshape(obs, (-1, D)) # (... D)

# list to collect scores for each batch
scores = [
ds_score_mv_mat(obs_i, fct_i, bias=bias, backend=backend)
for obs_i, fct_i in zip(obs_flat, fct_flat)
]

# reshape to original batch shape
return B.reshape(B.stack(scores), batch_shape)


def ds_score_mv_mat(
obs: "Array", # (D)
fct: "Array", # (M D)
bias: bool = False,
backend: "Backend" = None,
) -> "Array":
"""Compute the Dawid Sebastiani Score for one multivariate finite ensemble."""
B = backends.active if backend is None else backends[backend]
cov = B.cov(fct, rowvar=False, bias=bias) # (D D)
precision = B.inv(cov) # (D D)
log_det = B.log(B.det(cov)) # ()
obs_cent = obs - B.mean(fct, axis=-2) # (D)
bias_precision = B.squeeze(
B.expand_dims(obs_cent, axis=-2) @ precision @ B.expand_dims(obs_cent, axis=-1)
) # ()
return bias_precision + log_det
Loading
Loading