@@ -30,23 +30,20 @@ def ds_score_mv(
3030 """Compute the Dawid Sebastiani Score for a multivariate finite ensemble."""
3131 B = backends .active if backend is None else backends [backend ]
3232
33- batch_dims = fct .shape [:- 2 ]
34- out = B . zeros ( batch_dims )
33+ batch_shape = fct .shape [:- 2 ]
34+ M , D = fct . shape [ - 2 :]
3535
36- # nested loop over all batch dimensions
37- def recursive_loop (current_index , depth ):
38- if depth == len (batch_dims ):
39- fct_i = fct [tuple (current_index )] # (M, D)
40- obs_i = obs [tuple (current_index )] # (D)
41- score = ds_score_mv_mat (obs_i , fct_i , bias , backend = backend ) # ()
42- out [tuple (current_index )] = score
43- return
36+ fct_flat = B .reshape (fct , (- 1 , M , D )) # (... M D)
37+ obs_flat = B .reshape (obs , (- 1 , D )) # (... D)
4438
45- for i in range (batch_dims [depth ]):
46- recursive_loop (current_index + [i ], depth + 1 )
39+ # list to collect scores for each batch
40+ scores = [
41+ ds_score_mv_mat (obs_i , fct_i , bias = bias , backend = backend )
42+ for obs_i , fct_i in zip (obs_flat , fct_flat )
43+ ]
4744
48- recursive_loop ([], 0 )
49- return out
45+ # reshape to original batch shape
46+ return B . reshape ( B . stack ( scores ), batch_shape )
5047
5148
5249def ds_score_mv_mat (
0 commit comments