@@ -23,7 +23,7 @@ def vs_ensemble(
2323 obs_diff = B .abs (B .expand_dims (obs , - 2 ) - B .expand_dims (obs , - 1 )) ** p # (... D D)
2424
2525 if estimator == "nrg" :
26- vfct = B .sum (B . abs ( fct_diff ) ** p , axis = - 3 ) / M # (... D D)
26+ vfct = B .sum (fct_diff , axis = - 3 ) / M # (... D D)
2727 out = B .sum ((obs_diff - vfct ) ** 2 , axis = (- 2 , - 1 )) # (...)
2828
2929 elif estimator == "fair" :
@@ -52,29 +52,20 @@ def owvs_ensemble(
5252) -> "Array" :
5353 """Compute the Outcome-Weighted Variogram Score for a multivariate finite ensemble."""
5454 B = backends .active if backend is None else backends [backend ]
55- M : int = fct .shape [- 2 ]
56- wbar = B .mean (fw , axis = - 1 )
57-
58- fct_diff = B .expand_dims (fct , - 2 ) - B .expand_dims (fct , - 1 ) # (... M D D)
59- fct_diff = B .abs (fct_diff ) ** p # (... M D D)
55+ M = fct .shape [- 2 ]
56+ wbar = B .sum (fw , - 1 ) / M
6057
61- obs_diff = B .expand_dims (obs , - 2 ) - B .expand_dims (obs , - 1 ) # (... D D)
62- obs_diff = B .abs (obs_diff ) ** p # (... D D)
63- del obs , fct
64-
65- E_1 = (fct_diff - B .expand_dims (obs_diff , - 3 )) ** 2 # (... M D D)
66- E_1 = B .sum (E_1 , axis = (- 2 , - 1 )) # (... M)
67- E_1 = B .sum (E_1 * fw * B .expand_dims (ow , - 1 ), axis = - 1 ) / (M * wbar ) # (...)
58+ fct_diff = (
59+ B .abs (B .expand_dims (fct , - 2 ) - B .expand_dims (fct , - 1 )) ** p
60+ ) # (... M D D)
61+ obs_diff = B .abs (B .expand_dims (obs , - 2 ) - B .expand_dims (obs , - 1 )) ** p # (... D D)
6862
69- fct_diff_spread = B .expand_dims (fct_diff , - 3 ) - B .expand_dims (
70- fct_diff , - 4
71- ) # (... M M D D)
72- fw_prod = B .expand_dims (fw , - 2 ) * B .expand_dims (fw , - 1 ) # (... M M)
73- E_2 = B .sum (fct_diff_spread ** 2 , axis = (- 2 , - 1 )) # (... M M)
74- E_2 *= fw_prod * B .expand_dims (ow , (- 2 , - 1 )) # (... M M)
75- E_2 = B .sum (E_2 , axis = (- 2 , - 1 )) / (M ** 2 * wbar ** 2 ) # (...)
63+ vfct = B .sum (fct_diff * B .expand_dims (fw , (- 2 , - 1 )), axis = - 3 ) / (
64+ M * B .expand_dims (wbar , (- 2 , - 1 ))
65+ ) # (... D D)
66+ out = B .sum (((obs_diff - vfct ) ** 2 ), axis = (- 2 , - 1 )) * ow # (...)
7667
77- return E_1 - 0.5 * E_2
68+ return out
7869
7970
8071def vrvs_ensemble (
0 commit comments