Skip to content

Commit a3cc01f

Browse files
committed
add logs test
1 parent 2a047a2 commit a3cc01f

File tree

7 files changed

+26
-12
lines changed

7 files changed

+26
-12
lines changed

scoringrules/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from scoringrules._brier import brier_score
44
from scoringrules._crps import crps_ensemble, crps_lognormal, crps_normal
55
from scoringrules._energy import energy_score
6-
from scoringrules._logs import logscore_normal
6+
from scoringrules._logs import logs_normal
77
from scoringrules._variogram import variogram_score
88
from scoringrules.backend import register_backend
99

@@ -15,7 +15,7 @@
1515
"crps_ensemble",
1616
"crps_normal",
1717
"crps_lognormal",
18-
"logscore_normal",
18+
"logs_normal",
1919
"brier_score",
2020
"energy_score",
2121
"variogram_score",

scoringrules/_logs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
ArrayLike = tp.TypeVar("ArrayLike", Array, float)
77

88

9-
def logscore_normal(
9+
def logs_normal(
1010
mu: ArrayLike,
1111
sigma: ArrayLike,
1212
observation: ArrayLike,
@@ -37,7 +37,7 @@ def logscore_normal(
3737
Examples
3838
--------
3939
>>> import scoringrules as sr
40-
>>> sr.logscore_normal(0.1, 0.4, 0.0)
40+
>>> sr.logs_normal(0.1, 0.4, 0.0)
4141
>>> 0.033898
4242
"""
43-
return srb[backend].logscore_normal(mu, sigma, observation)
43+
return srb[backend].logs_normal(mu, sigma, observation)

scoringrules/backend/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def variogram_score(
155155
out = self.np.sum((obs_diff - vfcts) ** 2, axis=(-2, -1))
156156
return out
157157

158-
def logscore_normal(
158+
def logs_normal(
159159
self,
160160
mu: ArrayLike,
161161
sigma: ArrayLike,

scoringrules/backend/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def crps_lognormal(
3333
"""Compute the CRPS for the log normal distribution."""
3434

3535
@abstractmethod
36-
def logscore_normal(
36+
def logs_normal(
3737
self,
3838
mu: ArrayLike,
3939
sigma: ArrayLike,

scoringrules/backend/gufuncs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _crps_lognormal_ufunc(mulog: float, sigmalog: float, observation: float) ->
4242

4343

4444
@vectorize(["float32(float32, float32, float32)", "float64(float64, float64, float64)"])
45-
def _logscore_normal_ufunc(mu: np.ndarray, sigma: np.ndarray, observation: np.ndarray):
45+
def _logs_normal_ufunc(mu: np.ndarray, sigma: np.ndarray, observation: np.ndarray):
4646
ω = (observation - mu) / sigma
4747
return -np.log(_norm_pdf(ω) / sigma)
4848

@@ -324,7 +324,7 @@ def _variogram_score_gufunc(forecasts, observation, p, out):
324324
"_crps_ensemble_qd_gufunc",
325325
"_crps_normal_ufunc",
326326
"_crps_lognormal_ufunc",
327-
"_logscore_normal_ufunc",
327+
"_logs_normal_ufunc",
328328
"_energy_score_gufunc",
329329
"_brier_score_ufunc",
330330
"_variogram_score_gufunc",

scoringrules/backend/numba.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
_crps_lognormal_ufunc,
1717
_crps_normal_ufunc,
1818
_energy_score_gufunc,
19-
_logscore_normal_ufunc,
19+
_logs_normal_ufunc,
2020
_variogram_score_gufunc,
2121
)
2222

@@ -87,13 +87,13 @@ def crps_lognormal(
8787
return _crps_lognormal_ufunc(mulog, sigmalog, observation)
8888

8989
@staticmethod
90-
def logscore_normal(
90+
def logs_normal(
9191
mu: ArrayLike,
9292
sigma: ArrayLike,
9393
observation: ArrayLike,
9494
) -> Array:
9595
"""Compute the logarithmic score for a normal distribution."""
96-
return _logscore_normal_ufunc(mu, sigma, observation)
96+
return _logs_normal_ufunc(mu, sigma, observation)
9797

9898
@staticmethod
9999
def energy_score(

tests/test_logs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import numpy as np
2+
import pytest
3+
from scoringrules import _logs
4+
5+
ENSEMBLE_SIZE = 51
6+
N = 100
7+
8+
BACKENDS = ["numpy", "numba", "jax"]
9+
10+
11+
@pytest.mark.parametrize("backend", BACKENDS)
12+
def test_normal(backend):
13+
res = _logs.logs_normal(0.1, 0.1, 0.0, backend=backend)
14+
assert np.isclose(res, -0.8836466, rtol=1e-5)

0 commit comments

Comments
 (0)