Skip to content

Commit 9ad9bc9

Browse files
committed
generalise multivariate Dawid-Sebastiani score to support jax backend
1 parent a315365 commit 9ad9bc9

File tree

2 files changed

+12
-15
lines changed

2 files changed

+12
-15
lines changed

scoringrules/backend/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def det(self, x: "NDArray") -> "NDArray":
277277
return np.linalg.det(x)
278278

279279
def reshape(self, x: "NDArray", shape: int | tuple[int, ...]) -> "NDArray":
280-
return np.reshape(x, shape=shape)
280+
return np.reshape(x, shape)
281281

282282

283283
class NumbaBackend(NumpyBackend):

scoringrules/core/dss/_score.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,20 @@ def ds_score_mv(
3030
"""Compute the Dawid Sebastiani Score for a multivariate finite ensemble."""
3131
B = backends.active if backend is None else backends[backend]
3232

33-
batch_dims = fct.shape[:-2]
34-
out = B.zeros(batch_dims)
33+
batch_shape = fct.shape[:-2]
34+
M, D = fct.shape[-2:]
3535

36-
# nested loop over all batch dimensions
37-
def recursive_loop(current_index, depth):
38-
if depth == len(batch_dims):
39-
fct_i = fct[tuple(current_index)] # (M, D)
40-
obs_i = obs[tuple(current_index)] # (D)
41-
score = ds_score_mv_mat(obs_i, fct_i, bias, backend=backend) # ()
42-
out[tuple(current_index)] = score
43-
return
36+
fct_flat = B.reshape(fct, (-1, M, D)) # (... M D)
37+
obs_flat = B.reshape(obs, (-1, D)) # (... D)
4438

45-
for i in range(batch_dims[depth]):
46-
recursive_loop(current_index + [i], depth + 1)
39+
# list to collect scores for each batch
40+
scores = [
41+
ds_score_mv_mat(obs_i, fct_i, bias=bias, backend=backend)
42+
for obs_i, fct_i in zip(obs_flat, fct_flat)
43+
]
4744

48-
recursive_loop([], 0)
49-
return out
45+
# reshape to original batch shape
46+
return B.reshape(B.stack(scores), batch_shape)
5047

5148

5249
def ds_score_mv_mat(

0 commit comments

Comments
 (0)