forked from nansencenter/DAPPER
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstats.py
More file actions
331 lines (275 loc) · 9.91 KB
/
stats.py
File metadata and controls
331 lines (275 loc) · 9.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
from common import *
class Stats(MLR_Print):
    """
    Contains and computes statistics of the DA methods.
    """

    # Adjust this to omit heavy computations
    # (SVD/eigen-decomposition and rank histogram).
    comp_threshold_3 = 51

    # Used by MLR_Print
    excluded = MLR_Print.excluded + ['setup','config','xx','yy']
    precision = 3
    ordr_by_linenum = -1

    def __init__(self,config,setup,xx,yy):
        """
        Init the default statistics.

        config: DA method configuration. If it has an attribute 'N',
                ensemble statistics are allocated; its 'store_u' and
                'liveplotting' attributes are read elsewhere in this class.
        setup : the experiment setup; supplies the dimensions
                (setup.f.m, setup.h.m) and the time sequence (setup.t).
        xx    : truth trajectory, with shape (K+1, m).
        yy    : (presumably) the observations, with shape (KObs+1, p).

        Note: you may well allocate & compute individual stats elsewhere,
        and simply assign them as an attribute to the stats instance.
        """
        self.config = config
        self.setup  = setup
        self.xx     = xx
        self.yy     = yy

        m    = setup.f.m    ; assert m   ==xx.shape[1]
        K    = setup.t.K    ; assert K   ==xx.shape[0]-1
        p    = setup.h.m    ; assert p   ==yy.shape[1]
        KObs = setup.t.KObs ; assert KObs==yy.shape[0]-1

        # time-series constructor alias
        fs = self.new_FAU_series

        self.mu     = fs(m) # Mean
        self.var    = fs(m) # Variances
        self.mad    = fs(m) # Mean abs deviations
        self.err    = fs(m) # Error (mu-truth)
        self.logp_m = fs(1) # Marginal, Gaussian Log score
        self.skew   = fs(1) # Skewness
        self.kurt   = fs(1) # Kurtosis
        self.rmv    = fs(1) # Root-mean variance
        self.rmse   = fs(1) # Root-mean square error

        if hasattr(config,'N'):
            # Ensemble-only init
            self._had_0v = False
            self._is_ens = True
            N            = config.N
            m_Nm         = min(m,N)
            self.w       = fs(N)           # Importance weights
            self.rh      = fs(m,dtype=int) # Rank histogram
            #self.N      = N               # Use w.shape[1] instead
        else:
            # Linear-Gaussian assessment
            self._is_ens = False
            m_Nm         = m

        self.svals = fs(m_Nm) # Principal component (SVD) scores
        self.umisf = fs(m_Nm) # Error in component directions

        # Other.
        self.trHK = np.full(KObs+1, nan)
        self.infl = np.full(KObs+1, nan)

    def assess(self,k,kObs=None,f_a_u=None,
               E=None,w=None,mu=None,Cov=None):
        """
        Common interface for both assess_ens and assess_ext.

        k    : time index.
        kObs : observation index, or None at non-observation times.
        f_a_u: One or more of ['f','a','u'], indicating
               that the result should be stored in (respectively)
               the forecast/analysis/universal attribute.
               Defaults: 'u' if kObs is None, else 'au'.
               An explicit 'fau' is reduced to 'u' if kObs is None.
        E, w : ensemble and (optional) weights -- for ensemble methods.
        mu,Cov: mean and covariance -- for non-ensemble methods.

        With liveplotting enabled, the LivePlot is created at k==0,
        and updated thereafter whenever 'u' is in f_a_u.

        Raises (only checked at k==0):
          KeyError : if kObs is given at the initial time.
          TypeError: if the supplied inputs (E vs. mu/Cov) do not match
                     the method type (ensemble or not).
        """

        def rze(expected,name,negation):
            # Uniform complaint about mismatched input kinds.
            raise TypeError("Expected "+expected+" input, but "+name+
                            " is "+negation+"None")

        # Initial consistency checks.
        if k==0:
            if kObs is not None:
                raise KeyError("Should not have any obs at initial time. "+
                    "This very easily leads to bugs, and not 'DA convention'.")
            if self._is_ens:
                if E is None:      rze("ensemble","E","")
                if mu is not None: rze("ensemble","mu/Cov","not ")
            else:
                if E is not None:  rze("mu/Cov","E","not ")
                if mu is None:     rze("mu/Cov","mu","")

        # Defaults for f_a_u
        if f_a_u is None:
            f_a_u = 'u' if kObs is None else 'au'
        elif f_a_u == 'fau' and kObs is None:
            f_a_u = 'u'

        key = (k,kObs,f_a_u)

        LP      = self.config.liveplotting
        store_u = self.config.store_u

        # With neither liveplotting nor storage of the 'u' series,
        # there is nothing to do at non-observation times.
        if not (LP or store_u) and kObs is None:
            return # Skip assessment

        # Prepare assessment call and arguments
        if self._is_ens:
            # Ensemble assessment
            alias      = self.assess_ens
            state_prms = {'E':E,'w':w}
        else:
            # Linear-Gaussian assessment
            alias      = self.assess_ext
            state_prms = {'mu':mu,'P':Cov}

        # Call assessment
        with np.errstate(divide='ignore',invalid='ignore'):
            alias(key,**state_prms)

        # In case of degeneracy, variance might be 0,
        # causing warnings in computing skew/kurt/MGLS
        # (which all normalize by variance).
        # This should and will yield nan's, but we don't want
        # the diagnostics computations to cause too many warnings,
        # so we turned them off above. But we'll manually warn ONCE here.
        if not getattr(self,'_had_0v',False) \
                and np.allclose(sqrt(self.var[key]),0):
            self._had_0v = True
            warnings.warn("Sample variance was 0 at (k,kObs,fau) = " + str(key))

        # LivePlot
        if LP:
            if k==0:
                self.lplot = LivePlot(self,**state_prms,only=LP)
            elif 'u' in f_a_u:
                self.lplot.update(k,kObs,**state_prms)

    def assess_ens(self,k,E,w=None):
        """
        Ensemble and Particle filter (weighted/importance) assessment.

        k: the (k,kObs,f_a_u) key composed by assess(); k[0] indexes xx.
        E: ensemble, with shape (N,m).
        w: importance weights (length N), a non-zero scalar,
           or None (meaning uniform weights 1/N).

        Calls raise_AFE (project helper) if the weights do not sum to one,
        or if E contains non-finite or non-real values.
        """
        # Unpack
        N,m = E.shape
        x   = self.xx[k[0]]

        # Process weights
        if w is None:
            self._has_w = False
            w           = 1/N
        else:
            self._has_w = True
        if np.isscalar(w):
            assert w != 0
            w = w*ones(N)

        if abs(w.sum()-1) > 1e-5:      raise_AFE("Weights did not sum to one.",k)
        if not np.all(np.isfinite(E)): raise_AFE("Ensemble not finite.",k)
        if not np.all(np.isreal(E)):   raise_AFE("Ensemble not Real.",k)

        self.w[k]  = w
        self.mu[k] = w @ E
        A          = E - self.mu[k]

        # While A**2 is approx as fast as A*A,
        # A**3 is 10x slower than A**2 (or A**2.0).
        # => Use A2 = A**2, A3 = A*A2, A4=A*A3.
        # But, to save memory, only use A_pow.
        A_pow       = A**2

        self.var[k] = w @ A_pow
        self.mad[k] = w @ abs(A)  # Mean abs deviations

        ub          = unbias_var(w,avoid_pathological=True)
        self.var[k] *= ub

        # For simplicity, use naive (biased) formulae, derived
        # from "empirical measure". See doc/unbiased_skew_kurt.jpg.
        # Normalize by var. Compute "excess" kurt, which is 0 for Gaussians.
        A_pow        *= A
        self.skew[k]  = mean( w @ A_pow / self.var[k]**(3/2) )
        A_pow        *= A # idem.
        self.kurt[k]  = mean( w @ A_pow / self.var[k]**2 - 3 )

        self.derivative_stats(k,x)

        if sqrt(m*N) <= Stats.comp_threshold_3:
            if N<=m:
                _,s,UT         = svd( (sqrt(w)*A.T).T, full_matrices=False)
                s             *= sqrt(ub) # Makes s^2 unbiased
                self.svals[k]  = s
                self.umisf[k]  = UT @ self.err[k]
            else:
                P              = (A.T * w) @ A
                s2,U           = eigh(P)
                s2            *= ub
                # eigh returns ascending eigenvalues => reverse to descending.
                self.svals[k]  = sqrt(s2.clip(0))[::-1]
                self.umisf[k]  = U.T[::-1] @ self.err[k]

            # For each state dim [i], compute rank of truth (x) among the
            # ensemble (E): the number of members strictly below x[i].
            # This equals the first position of x[i] in sort([E;x]) (as
            # previously computed by sorting), but costs O(N*m) instead of
            # O(N*m*log N) plus a Python-level loop over the m dims.
            self.rh[k] = (E < x).sum(axis=0)

    def assess_ext(self,k,mu,P):
        """
        Kalman filter (Gaussian) assessment.

        k : the (k,kObs,f_a_u) key composed by assess(); k[0] indexes xx.
        mu: mean estimate (length m).
        P : covariance, either a CovMat or a square (m,m) array.

        Calls raise_AFE (project helper) if mu is non-finite or non-real.
        """
        isFinite = np.all(np.isfinite(mu)) # Do not check covariance
        isReal   = np.all(np.isreal(mu))   # (because it might not be explicitly available)
        if not isFinite: raise_AFE("Estimates not finite.",k)
        if not isReal:   raise_AFE("Estimates not Real.",k)

        m = len(mu)
        x = self.xx[k[0]]

        self.mu[k]  = mu
        self.var[k] = P.diag if isinstance(P,CovMat) else diag(P)
        self.mad[k] = sqrt(self.var[k])*sqrt(2/pi)
        # ... because sqrt(2/pi) = ratio MAD/STD for Gaussians

        self.derivative_stats(k,x)

        if m <= Stats.comp_threshold_3:
            P             = P.full if isinstance(P,CovMat) else P
            s2,U          = nla.eigh(P)
            # eigh returns ascending eigenvalues => reverse to descending.
            self.svals[k] = sqrt(np.maximum(s2,0.0))[::-1]
            self.umisf[k] = (U.T @ self.err[k])[::-1]

    def derivative_stats(self,k,x):
        """
        Stats that apply for both the ensemble and the _ext paradigms,
        and derive from the other (already computed) stats at key k.

        x: the truth at this time.
        """
        self.err[k]  = self.mu[k] - x
        self.rmv[k]  = sqrt(mean(self.var[k]))
        self.rmse[k] = sqrt(mean(self.err[k]**2))
        self.MGLS(k)

    def MGLS(self,k):
        """
        Marginal Gaussian Log Score.

        Computes (sum_i err_i^2/var_i + sum_i log var_i) / m,
        i.e. without the factor 1/2 and the log(2*pi) constant.
        """
        m       = len(self.err[k])
        ldet    = log(self.var[k]).sum()
        nmisf   = self.var[k]**(-1/2) * self.err[k]
        logp_m  = (nmisf**2).sum() + ldet
        self.logp_m[k] = logp_m/m

    def average_in_time(self):
        """
        Average all univariate (scalar) time series.

        Attributes starting with '_' are skipped, as are those of
        unsupported types (including multi-dimensional arrays).
        FAU_series contribute one entry per sub-field (key+'_'+sub).
        1D arrays must have the length of t.kkObs or t.kk
        (else ValueError), and are averaged over the corresponding
        index selection (t.maskObs_BI or t.kk_BI).
        """
        avrg = AlignedDict()
        for key,series in vars(self).items():
            if key.startswith('_'):
                continue
            try:
                # FAU_series
                if isinstance(series,FAU_series):
                    # Compute
                    f_a_u = series.average()
                    # Add the sub-fields as sub-scripted fields
                    for sub in f_a_u:
                        avrg[key+'_'+sub] = f_a_u[sub]
                # Array
                elif isinstance(series,np.ndarray):
                    if series.ndim > 1:
                        raise NotImplementedError
                    t = self.setup.t
                    if len(series) == len(t.kkObs):
                        inds = t.maskObs_BI
                    elif len(series) == len(t.kk):
                        inds = t.kk_BI
                    else:
                        raise ValueError
                    # Compute
                    avrg[key] = series_mean_with_conf(series[inds])
                # Scalars
                elif np.isscalar(series):
                    avrg[key] = series
                else:
                    raise NotImplementedError
            except NotImplementedError:
                # Unsupported attribute type: deliberately not averaged.
                pass
        return avrg

    def new_FAU_series(self,m,**kwargs):
        "Convenience FAU_series constructor."
        store_u = self.config.store_u
        return FAU_series(self.setup.t, m, store_u=store_u, **kwargs)

    # TODO: Provide frontend initializer
    # Better to initialize manually (np.full...)
    # def new_array(self,f_a_u,m,**kwargs):
    #   "Convenience array constructor."
    #   t = self.setup.t
    #   # Convert int-len to shape-tuple
    #   if is_int(m):
    #     if m==1: m = ()
    #     else:    m = (m,)
    #   # Set length
    #   if f_a_u=='a':
    #     K = t.KObs
    #   elif f_a_u=='u':
    #     K = t.K
    #   #
    #   return np.full((K+1,)+m,**kwargs)
def average_each_field(ss,axis=None):
    """
    Average, field-by-field, a 2D array of dicts of val_with_conf-like items.

    ss  : 2D (object) array; each element is a dict mapping field names
          to objects exposing .val and .conf (fields taken from ss[0][0]).
    axis: If 0, average along axis 0 (one result per column of ss);
          otherwise, average along axis 1 (one result per row of ss).

    Returns a 1D object array of dicts, each mapping field name to the
    val_with_conf of the mean val and the mean conf divided by sqrt(count).
    """
    assert ss.ndim == 2
    table = np.transpose(ss) if axis == 0 else ss
    nRows, nCols = table.shape
    field_names = table[0][0].keys()

    result = np.empty(nRows, dict)
    for i in range(nRows):
        cells = table[i]
        averaged = {}
        for name in field_names:
            vals  = [cell[name].val  for cell in cells]
            confs = [cell[name].conf for cell in cells]
            # NB: This is a rudimentary averaging of confidence intervals.
            # Should be checked against the variance of the averaged vals.
            averaged[name] = val_with_conf(
                val  = mean(vals),
                conf = mean(confs)/sqrt(nCols))
        result[i] = averaged
    return result