Skip to content

Commit 80fd35d

Browse files
committed
Merge branch 'main' into docu
2 parents 66563af + 30b5c4e commit 80fd35d

File tree

4 files changed

+32
-32
lines changed

4 files changed

+32
-32
lines changed

scoringrules/core/crps/_closed.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,8 @@ def gtclogistic(
316316

317317
c = (1 - lmass - umass) / (F_u - F_l)
318318

319-
s1_u = B.where(u_inf and umass == 0.0, 0.0, u * umass**2)
320-
s1_l = B.where(l_inf and lmass == 0.0, 0.0, l * lmass**2)
319+
s1_u = B.where(u_inf & (umass == 0.0), 0.0, u * umass**2)
320+
s1_l = B.where(l_inf & (lmass == 0.0), 0.0, l * lmass**2)
321321

322322
s1 = B.abs(ω - z) + s1_u - s1_l
323323
s2 = c * z * ((1 - 2 * lmass) * F_u + (1 - 2 * umass) * F_l) / (1 - lmass - umass)
@@ -359,8 +359,8 @@ def gtcnormal(
359359

360360
u = B.where(u_inf, B.nan, u)
361361
l = B.where(l_inf, B.nan, l)
362-
s1_u = B.where(u_inf and umass == 0.0, 0.0, u * umass**2)
363-
s1_l = B.where(l_inf and lmass == 0.0, 0.0, l * lmass**2)
362+
s1_u = B.where(u_inf & (umass == 0.0), 0.0, u * umass**2)
363+
s1_l = B.where(l_inf & (lmass == 0.0), 0.0, l * lmass**2)
364364

365365
c = (1 - lmass - umass) / (F_u - F_l)
366366

@@ -410,8 +410,8 @@ def gtct(
410410
u = B.where(u_inf, B.nan, u)
411411
l = B.where(l_inf, B.nan, l)
412412

413-
s1_u = B.where(u_inf and umass == 0.0, 0.0, u * umass**2)
414-
s1_l = B.where(l_inf and lmass == 0.0, 0.0, l * lmass**2)
413+
s1_u = B.where(u_inf & (umass == 0.0), 0.0, u * umass**2)
414+
s1_l = B.where(l_inf & (lmass == 0.0), 0.0, l * lmass**2)
415415

416416
G_u = B.where(u_inf, 0.0, -f_u * (df + u**2) / (df - 1))
417417
G_l = B.where(l_inf, 0.0, -f_l * (df + l**2) / (df - 1))

scoringrules/core/crps/_gufuncs.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ def _crps_ensemble_nrg_gufunc(obs: np.ndarray, fct: np.ndarray, out: np.ndarray)
118118
e_1 = 0
119119
e_2 = 0
120120

121-
for x_i in fct:
122-
e_1 += abs(x_i - obs)
123-
for x_j in fct:
124-
e_2 += abs(x_i - x_j)
121+
for i in range(M):
122+
e_1 += abs(fct[i] - obs)
123+
for j in range(i + 1, M):
124+
e_2 += 2 * abs(fct[j] - fct[i])
125125

126126
out[0] = e_1 / M - 0.5 * e_2 / (M**2)
127127

@@ -145,10 +145,10 @@ def _crps_ensemble_fair_gufunc(obs: np.ndarray, fct: np.ndarray, out: np.ndarray
145145
e_1 = 0
146146
e_2 = 0
147147

148-
for x_i in fct:
149-
e_1 += abs(x_i - obs)
150-
for x_j in fct:
151-
e_2 += abs(x_i - x_j)
148+
for i in range(M):
149+
e_1 += abs(fct[i] - obs)
150+
for j in range(i + 1, M):
151+
e_2 += 2 * abs(fct[j] - fct[i])
152152

153153
out[0] = e_1 / M - 0.5 * e_2 / (M * (M - 1))
154154

@@ -250,10 +250,10 @@ def _owcrps_ensemble_nrg_gufunc(
250250
e_1 = 0.0
251251
e_2 = 0.0
252252

253-
for i, x_i in enumerate(fct):
254-
e_1 += abs(x_i - obs) * fw[i] * ow
255-
for j, x_j in enumerate(fct):
256-
e_2 += abs(x_i - x_j) * fw[i] * fw[j] * ow
253+
for i in range(M):
254+
e_1 += abs(fct[i] - obs) * fw[i] * ow
255+
for j in range(i + 1, M):
256+
e_2 += 2 * abs(fct[i] - fct[j]) * fw[i] * fw[j] * ow
257257

258258
wbar = np.mean(fw)
259259

@@ -286,10 +286,10 @@ def _vrcrps_ensemble_nrg_gufunc(
286286
e_1 = 0.0
287287
e_2 = 0.0
288288

289-
for i, x_i in enumerate(fct):
290-
e_1 += abs(x_i - obs) * fw[i] * ow
291-
for j, x_j in enumerate(fct):
292-
e_2 += abs(x_i - x_j) * fw[i] * fw[j]
289+
for i in range(M):
290+
e_1 += abs(fct[i] - obs) * fw[i] * ow
291+
for j in range(i + 1, M):
292+
e_2 += 2 * abs(fct[i] - fct[j]) * fw[i] * fw[j]
293293

294294
wbar = np.mean(fw)
295295
wabs_x = np.mean(np.abs(fct) * fw)

scoringrules/core/energy/_gufuncs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def _energy_score_gufunc(
2121
e_2 = 0.0
2222
for i in range(M):
2323
e_1 += float(np.linalg.norm(fct[i] - obs))
24-
for j in range(M):
25-
e_2 += float(np.linalg.norm(fct[i] - fct[j]))
24+
for j in range(i + 1, M):
25+
e_2 += 2 * float(np.linalg.norm(fct[i] - fct[j]))
2626

2727
out[0] = e_1 / M - 0.5 / (M**2) * e_2
2828

@@ -49,8 +49,8 @@ def _owenergy_score_gufunc(
4949
e_2 = 0.0
5050
for i in range(M):
5151
e_1 += float(np.linalg.norm(fct[i] - obs) * fw[i] * ow)
52-
for j in range(M):
53-
e_2 += float(np.linalg.norm(fct[i] - fct[j]) * fw[i] * fw[j] * ow)
52+
for j in range(i + 1, M):
53+
e_2 += 2 * float(np.linalg.norm(fct[i] - fct[j]) * fw[i] * fw[j] * ow)
5454

5555
wbar = np.mean(fw)
5656

@@ -81,8 +81,8 @@ def _vrenergy_score_gufunc(
8181
for i in range(M):
8282
e_1 += float(np.linalg.norm(fct[i] - obs) * fw[i] * ow)
8383
wabs_x += np.linalg.norm(fct[i]) * fw[i]
84-
for j in range(M):
85-
e_2 += float(np.linalg.norm(fct[i] - fct[j]) * fw[i] * fw[j])
84+
for j in range(i + 1, M):
85+
e_2 += 2 * float(np.linalg.norm(fct[i] - fct[j]) * fw[i] * fw[j])
8686

8787
wabs_x = wabs_x / M
8888
wbar = np.mean(fw)

scoringrules/core/variogram/_gufuncs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ def _owvariogram_score_gufunc(obs, fct, p, ow, fw, out):
4242
rho1 = abs(fct[k, i] - fct[k, j]) ** p
4343
rho2 = abs(obs[i] - obs[j]) ** p
4444
e_1 += (rho1 - rho2) ** 2 * fw[k] * ow
45-
for m in range(M):
45+
for m in range(k + 1, M):
4646
for i in range(D):
4747
for j in range(D):
4848
rho1 = abs(fct[k, i] - fct[k, j]) ** p
4949
rho2 = abs(fct[m, i] - fct[m, j]) ** p
50-
e_2 += (rho1 - rho2) ** 2 * fw[k] * fw[m] * ow
50+
e_2 += 2 * ((rho1 - rho2) ** 2) * fw[k] * fw[m] * ow
5151

5252
wbar = np.mean(fw)
5353

@@ -75,12 +75,12 @@ def _vrvariogram_score_gufunc(obs, fct, p, ow, fw, out):
7575
rho2 = abs(obs[i] - obs[j]) ** p
7676
e_1 += (rho1 - rho2) ** 2 * fw[k] * ow
7777
e_3_x += (rho1) ** 2 * fw[k]
78-
for m in range(M):
78+
for m in range(k + 1, M):
7979
for i in range(D):
8080
for j in range(D):
8181
rho1 = abs(fct[k, i] - fct[k, j]) ** p
8282
rho2 = abs(fct[m, i] - fct[m, j]) ** p
83-
e_2 += (rho1 - rho2) ** 2 * fw[k] * fw[m]
83+
e_2 += 2 * ((rho1 - rho2) ** 2) * fw[k] * fw[m]
8484

8585
e_3_x *= 1 / M
8686
wbar = np.mean(fw)

0 commit comments

Comments
 (0)