-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathaligners.py
More file actions
367 lines (293 loc) · 14.6 KB
/
aligners.py
File metadata and controls
367 lines (293 loc) · 14.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
from abc import ABC

import numpy as np
import scipy
import scipy.optimize
import scipy.special
import scipy.stats
from scipy import interpolate

from plotFunc import *
from aux import *
from env import *
class Aligner(ABC):
    """
    Base class for trajectory aligners.

    Subclasses provide an ``align`` method; this base class supplies the shared
    helper that shifts each biomarker trajectory in time so that a chosen
    biomarker value lands at t = 0.
    """

    def alignTrajXVal(self, ts, xsZ, xToAlign):
        """
        aligns trajectories to a fixed std dev level away from controls:
        for every biomarker, the trajectory point whose value is closest to the
        target is moved to t = 0. Assumes xsZ are z-scores of x-values.

        Parameters
        ----------
        ts - (nrPoints, nrBiomk) array of timepoints, one column per trajectory
        xsZ - (nrPoints, nrBiomk) array of biomarker z-scores at each timepoint
        xToAlign - target z-score per biomarker; either a (nrBiomk,) vector or a
                   (nrBiomk, 1) column

        Returns
        -------
        tsNew - (nrPoints, nrBiomk) shifted timepoints
        xToAlignIsOutside - (nrBiomk,) bool array, True where the target lies outside
                            the range of (non-NaN) values spanned by xsZ[:, b]
        xToAlign - the target values, returned unchanged
        """
        nrBiomk = xsZ.shape[1]
        targets = np.ravel(xToAlign)  # accept both (nrBiomk,) and (nrBiomk, 1)
        tsNew = np.zeros(ts.shape)
        xToAlignIsOutside = np.zeros(nrBiomk, bool)
        for b in range(nrBiomk):
            dists = np.abs(xsZ[:, b] - targets[b])
            # NaN points must never win the closest-point search (previously a NaN at
            # point 0 pinned the alignment to point 0); argmin keeps the first minimum
            dists = np.where(np.isnan(dists), np.inf, dists)
            closestP = np.argmin(dists)
            tsNew[:, b] = ts[:, b] - ts[closestP, b]
            xToAlignIsOutside[b] = (targets[b] < np.nanmin(xsZ[:, b])) or \
                                   (targets[b] > np.nanmax(xsZ[:, b]))
        return tsNew, xToAlignIsOutside, xToAlign
class AlignerBaseVisit(Aligner):
    """Aligns every trajectory to the anchor group's mean value at baseline."""

    def align(self, dpmObj, tsNzSect, xsZ, longData, longDiag):
        """
        Anchor each trajectory at the patients' mean baseline-visit z-score.

        The anchor diagnosis (params['anchorID']) and the control mean/std
        (muCtlData, sigmaCtlData) are read from dpmObj. The result is whatever
        alignTrajXVal returns for that target.
        """
        targetZ = calcBlDataPatMeanZ(
            tsNzSect, xsZ, longData, longDiag,
            dpmObj.params['anchorID'], dpmObj.muCtlData, dpmObj.sigmaCtlData)
        return self.alignTrajXVal(tsNzSect, xsZ, targetZ)
class AlignerBaseVisitNoise(AlignerBaseVisit):
    """Baseline-visit aligner whose alignment target can be offset by a shift."""

    def alignNoise(self, tsNzSect, xsZ, longData, longDiag, anchorID, muCtlData, sigmaCtlData, xValShift):
        """
        Align to the anchor group's mean baseline z-score plus ``xValShift``.

        This is an additional method (it does not replace ``align``) and takes all
        of its inputs explicitly. ``xValShift`` is added to the target as-is;
        presumably the caller draws it from a noise distribution to perturb the
        alignment (TODO confirm). Returns whatever alignTrajXVal returns.
        """
        shiftedTarget = calcBlDataPatMeanZ(tsNzSect, xsZ, longData, longDiag,
                                           anchorID, muCtlData, sigmaCtlData) + xValShift
        return self.alignTrajXVal(tsNzSect, xsZ, shiftedTarget)
def calcBlDataPatMeanZ(tsNzSect, xsZ, longData, longDiag, anchorID, muCtlData, sigmaCtlData):
    """
    estimate the avg value at the baseline visit of the anchor group, as a z-score
    relative to controls; trajectories are later aligned to this value.

    Parameters
    ----------
    tsNzSect - unused; kept for call-site compatibility
    xsZ - unused; kept for call-site compatibility
    longData - list with one (nrVisits, nrBiomk) array per subject; row 0 is used
               as the baseline visit
    longDiag - diagnosis label per subject, same length as longData
    anchorID - diagnosis label of the subjects to average over
    muCtlData, sigmaCtlData - per-biomarker control mean and std for the z-score

    Returns
    -------
    blDataPatMeanZ - (nrBiomk,) array of z-scores (NaN data values are ignored by the mean)

    Raises
    ------
    ValueError - if longData is empty or no subject has diagnosis anchorID
    """
    if len(longData) == 0:
        raise ValueError('longData is empty, cannot compute baseline means')
    nrBiomk = longData[0].shape[1]
    # first visit of every subject -> (nrSubj, nrBiomk)
    blData = np.stack([subjData[0, :] for subjData in longData])
    anchorMask = np.asarray(longDiag) == anchorID
    if not np.any(anchorMask):
        # an empty selection would silently yield an all-NaN target
        raise ValueError('no subjects with diagnosis %s found in longDiag' % str(anchorID))
    blDataPatMean = np.nanmean(blData[anchorMask, :], axis=0)
    blDataPatMeanZ = (blDataPatMean - muCtlData) / sigmaCtlData  # convert to z-score
    assert len(blDataPatMeanZ) == nrBiomk
    return blDataPatMeanZ
class AlignerEM(Aligner):
    """
    Align trajectories by estimating one time shift per biomarker with EM.

    Starts from the baseline-visit alignment, then alternates between staging the
    subjects under the current shifts (E-step) and re-estimating the shifts on a
    fixed grid plus the per-biomarker noise variances (M-step).
    """

    def align(self, dpmObj, tsNzSect, xsZ, longData, longDiag):
        """
        Parameters
        ----------
        dpmObj - model object; reads params['anchorID'], params['excludeXvalidID'],
                 params['tsStages'], muCtlData, sigmaCtlData, data, diag, estimNoiseZ,
                 getDataZ() and stageSubjectsData(); writes ts and covMatNoiseZ
        tsNzSect - (nrPoints, nrBiomk) trajectory timepoints
        xsZ - (nrPoints, nrBiomk) trajectory values as z-scores
        longData, longDiag - longitudinal data and diagnoses for the initial
                             baseline-visit alignment

        Returns
        -------
        ts - final aligned timepoints (also stored in dpmObj.ts)
        res - dict with 'biomkShifts' (nrBiomk, nrIterations) and 'logL' (nrIterations,)
        """
        anchorID = dpmObj.params['anchorID']
        blDataPatMean = calcBlDataPatMeanZ(tsNzSect, xsZ, longData, longDiag,
                                           anchorID, dpmObj.muCtlData, dpmObj.sigmaCtlData)
        # alignTrajXVal returns (tsNew, xToAlignIsOutside, xToAlign); only tsNew is needed
        tsPatMeanAligned = self.alignTrajXVal(tsNzSect, xsZ, blDataPatMean)[0]
        dpmObj.ts = tsPatMeanAligned

        # drop excluded diagnoses and subjects with missing biomarkers: a single NaN
        # would turn every grid cost into NaN and break the shift search
        keepMask = np.logical_and(
            np.logical_not(np.isin(dpmObj.diag, dpmObj.params['excludeXvalidID'])),
            np.sum(np.isnan(dpmObj.data), 1) == 0)
        dataNonZ = dpmObj.data[keepMask, :]
        data = dpmObj.getDataZ(dataNonZ)  # convert data to z-scores as this is how the traj work
        nrSubj, nrBiomk = dataNonZ.shape
        nrPoints = tsPatMeanAligned.shape[0]

        minBiomkS = -20
        maxBiomkS = 20
        nrIterations = 10
        nrBiomkShifts = 100
        biomkShiftsSet = np.linspace(minBiomkS, maxBiomkS, num=nrBiomkShifts)  # candidate shifts
        patStagesSet = dpmObj.params['tsStages']

        biomkShifts = np.zeros((nrBiomk, nrIterations))
        sigmaSqs = np.zeros((nrBiomk, nrIterations))
        sigmaSqs[:, 0] = np.power(np.mean(dpmObj.estimNoiseZ, axis=0), 2)
        maxLikStages = np.zeros((nrSubj, nrIterations))

        fs = [interpolate.interp1d(tsPatMeanAligned[:, b], xsZ[:, b], kind='linear', fill_value='extrapolate')
              for b in range(nrBiomk)]

        logL = np.zeros(nrIterations)  # incomplete-data log-likelihood per iteration
        logL[0] = calcIncompleteDataLogL(data, fs, patStagesSet, biomkShifts[:, 0], sigmaSqs[:, 0])
        print('itNr %d ' % 0, 'logL', logL[0])
        for itNr in range(1, nrIterations):
            # E-step - staging probabilities of subjects given the current alignment,
            # i.e. E_{p(Z|X,theta_old)}
            dpmObj.ts = tsPatMeanAligned + np.tile(biomkShifts[:, itNr - 1], (nrPoints, 1))
            dpmObj.covMatNoiseZ = np.diag(sigmaSqs[:, itNr - 1])  # already sigma squares
            maxLikStages[:, itNr], _, stagingProb, _, _ = dpmObj.stageSubjectsData(dataNonZ)

            # M-step - with diagonal Gaussian noise, maximising Q(theta, theta_old) over
            # each shift reduces to minimising the expected squared error (independent
            # of sigma), so pick the best grid shift first ...
            QsumB = self._expectedSqErrGrid(stagingProb, data, fs, patStagesSet, biomkShiftsSet)
            biomkShifts[:, itNr] = biomkShiftsSet[np.argmin(QsumB, axis=1)]
            # ... then the noise variances for those new shifts
            sigmaSqs[:, itNr] = estimateSigmasEMupdate(stagingProb, data, fs, patStagesSet,
                                                       biomkShifts[:, itNr])

            QsumSum = np.sum(np.min(QsumB, axis=1))
            Qtheta = calcQthetaThetaOld(stagingProb, data, fs, patStagesSet, biomkShifts[:, itNr],
                                        sigmaSqs[:, itNr])
            logL[itNr] = calcIncompleteDataLogL(data, fs, patStagesSet, biomkShifts[:, itNr],
                                                sigmaSqs[:, itNr])
            print('itNr %d QsumSum %f ' % (itNr, QsumSum), 'Qtheta', Qtheta, 'logL', logL[itNr])
            print(' biomkShifts[:, itNr]', biomkShifts[:, itNr])

        dpmObj.ts = tsPatMeanAligned + np.tile(biomkShifts[:, -1], (nrPoints, 1))
        print('Biomk shifts', biomkShifts[:, -1])
        res = {'biomkShifts': biomkShifts, 'logL': logL}
        return dpmObj.ts, res

    @staticmethod
    def _expectedSqErrGrid(stagingProb, data, fs, patStagesSet, biomkShiftsSet):
        """Expected squared error sum_i sum_s p_is (x_ib - f_b(s - shift))^2 for every (biomk, shift)."""
        # assumes stagingProb has one column per entry of patStagesSet
        stages = np.asarray(patStagesSet, dtype=float)
        nrBiomk = data.shape[1]
        QsumB = np.zeros((nrBiomk, len(biomkShiftsSet)))
        for b in range(nrBiomk):
            dataCol = data[:, b][:, np.newaxis]  # (nrSubj, 1)
            for bs, biomkShiftCurr in enumerate(biomkShiftsSet):
                meansAtStages = np.asarray(fs[b](stages - biomkShiftCurr))  # (nrStages,)
                QsumB[b, bs] = np.sum(stagingProb * np.power(dataCol - meansAtStages[np.newaxis, :], 2))
        return QsumB
class AlignerLogLOpt(Aligner):
    """
    Align trajectories by directly maximising the incomplete-data log-likelihood
    over one time shift per biomarker.
    """

    def align(self, dpmObj, tsNzSect, xsZ, longData, longDiag):
        """
        aligns the trajectories using direct optimisation on the incomplete data logL
        from calcIncompleteDataLogL (the marginal of logL over the stages).

        Parameters
        ----------
        dpmObj - model object; reads params['anchorID'], params['excludeXvalidID'],
                 params['tsStages'], muCtlData, sigmaCtlData, data, diag, estimNoiseZ,
                 getDataZ() and stageSubjectsData(); writes ts and covMatNoiseZ
        tsNzSect - (nrPoints, nrBiomk) trajectory timepoints
        xsZ - (nrPoints, nrBiomk) trajectory values as z-scores
        longData, longDiag - longitudinal data and diagnoses for the initial
                             baseline-visit alignment

        Returns
        -------
        ts - aligned timepoints, centred on the threshold that best separates CTL
             from the anchor group (also stored in dpmObj.ts)
        res - dict with 'biomkShifts' (nrBiomk, nrIterations) and 'logL' (nrIterations,);
              columns after an early stop stay zero

        Raises
        ------
        Exception - if the CTL/anchor subset does not contain exactly 2 groups
        RuntimeError - if the Powell optimisation of the shifts fails
        """
        anchorID = dpmObj.params['anchorID']
        blDataPatMean = calcBlDataPatMeanZ(tsNzSect, xsZ, longData, longDiag,
                                           anchorID, dpmObj.muCtlData, dpmObj.sigmaCtlData)
        # alignTrajXVal returns (tsNew, xToAlignIsOutside, xToAlign); only tsNew is needed
        tsPatMeanAligned = self.alignTrajXVal(tsNzSect, xsZ, blDataPatMean)[0]
        dpmObj.ts = tsPatMeanAligned

        # keep non-excluded subjects without missing biomarkers
        patMask = np.logical_not(np.isin(dpmObj.diag, dpmObj.params['excludeXvalidID']))
        dataIndicesNN = np.logical_and(patMask, np.sum(np.isnan(dpmObj.data), 1) == 0)
        dataNonZ = dpmObj.data[dataIndicesNN, :]
        data = dpmObj.getDataZ(dataNonZ)  # convert data to z-scores as this is how the traj work
        diag = dpmObj.diag[dataIndicesNN]
        nrSubj, nrBiomk = dataNonZ.shape
        nrPoints = tsPatMeanAligned.shape[0]

        # check the final thresholding precondition before the expensive optimisation
        threshMask = np.isin(diag, [CTL, anchorID])
        if len(np.unique(diag[threshMask])) != 2:
            raise Exception("there should be exactly 2 groups for finding the ideal threshold")

        nrIterations = 30
        biomkShifts = np.zeros((nrBiomk, nrIterations))
        sigmaSqs = np.zeros((nrBiomk, nrIterations))
        sigmaSqs[:, 0] = np.power(np.mean(dpmObj.estimNoiseZ, axis=0), 2)
        maxLikStages = np.zeros((nrSubj, nrIterations))
        patStagesSet = dpmObj.params['tsStages']

        fs = [interpolate.interp1d(tsPatMeanAligned[:, b], xsZ[:, b], kind='linear', fill_value='extrapolate')
              for b in range(nrBiomk)]

        # a global time offset is not identifiable, so the last shift is fixed by the
        # others such that all shifts sum to zero
        addDOF = lambda x: np.append(x, -np.sum(x))

        logL = np.zeros(nrIterations)  # incomplete-data log-likelihood per iteration
        logL[0] = calcIncompleteDataLogL(data, fs, patStagesSet, biomkShifts[:, 0], sigmaSqs[:, 0])
        print('itNr %d ' % 0, 'logL', logL[0])
        itNr = 0
        for itNr in range(1, nrIterations):
            # stage the subjects under the previous shifts / noise levels
            dpmObj.ts = tsPatMeanAligned + np.tile(biomkShifts[:, itNr - 1], (nrPoints, 1))
            dpmObj.covMatNoiseZ = np.diag(sigmaSqs[:, itNr - 1])  # already sigma squares
            maxLikStages[:, itNr], _, stagingProb, _, _ = dpmObj.stageSubjectsData(dataNonZ)
            # noise variances for the shifts the subjects were just staged with
            sigmaSqs[:, itNr] = estimateSigmasEMupdate(stagingProb, data, fs, patStagesSet,
                                                       biomkShifts[:, itNr - 1])
            currSigmaSqs = sigmaSqs[:, itNr]
            negLogL = lambda x: -calcIncompleteDataLogL(data, fs, patStagesSet, addDOF(x), currSigmaSqs)

            x0 = biomkShifts[:-1, itNr - 1]
            if x0.size == 0:
                # a single biomarker has no free shift under the sum-to-zero constraint
                biomkShifts[:, itNr] = addDOF(x0)
            else:
                optRes = scipy.optimize.minimize(fun=negLogL, x0=x0, method='Powell')
                if not optRes.success:
                    raise RuntimeError('Powell optimisation of the biomarker shifts failed: %s'
                                       % optRes.message)
                biomkShifts[:, itNr] = addDOF(optRes.x)

            logL[itNr] = calcIncompleteDataLogL(data, fs, patStagesSet, biomkShifts[:, itNr],
                                                sigmaSqs[:, itNr])
            print('itNr %d ' % itNr, 'logL', logL[itNr])
            print('sigmaSqs[:, itNr]', sigmaSqs[:, itNr])
            print(' biomkShifts[:, itNr]', biomkShifts[:, itNr])
            if logL[itNr] < logL[itNr - 1]:
                # break the loop as the logL is decreasing
                break

        # dpmObj.ts and stagingProb both correspond to the shifts of iteration itNr - 1
        print('Biomk shifts', biomkShifts[:, itNr - 1])
        # shift all traj with one constant so that 0-axis is optimally separating CTL from patients
        idealThresh = findIdealThresh(stagingProb[threshMask, :], diag[threshMask])
        dpmObj.ts -= patStagesSet[idealThresh]  # center the traj around the ideal separating threshold
        res = {'biomkShifts': biomkShifts, 'logL': logL}
        return dpmObj.ts, res
def estimateSigmasEMupdate(stagingProb, data, fs, patStagesSet, biomkShifts):
    """
    EM update of the per-biomarker noise variances:

        sigma_b^2 = (1/nrSubj) * sum_i sum_s stagingProb[i, s] * (data[i, b] - fs[b](stage_s - biomkShifts[b]))^2

    Parameters
    ----------
    stagingProb - (nrSubj, nrStages) weight of each subject at each stage
    data - (nrSubj, nrBiomk) subject data (z-scores)
    fs - list of nrBiomk trajectory callables; each is evaluated on the whole
         vector of shifted stages at once, so it must accept arrays
    patStagesSet - sequence of nrStages candidate stages
    biomkShifts - (nrBiomk,) time shift of each trajectory

    Returns
    -------
    sigmaSqs - (nrBiomk,) array of noise variances
    """
    nrSubj, nrBiomk = data.shape
    stages = np.asarray(patStagesSet, dtype=float)
    sigmaSqs = np.zeros(nrBiomk)
    for b in range(nrBiomk):
        # trajectory value at every stage, broadcast against every subject
        meansAtStages = np.asarray(fs[b](stages - biomkShifts[b]))  # (nrStages,)
        sqErr = np.power(data[:, b][:, np.newaxis] - meansAtStages[np.newaxis, :], 2)
        sigmaSqs[b] = np.sum(stagingProb * sqErr) / nrSubj
    return sigmaSqs
def calcQthetaThetaOld(stagingProb, data, fs, patStagesSet, biomkShifts, sigmaSqs):
    """
    Expected complete-data log-likelihood Q(theta, theta_old).

    Q = sum_i sum_s stagingProb[i, s] * log N(data[i, :]; mu_s, diag(sigmaSqs)),
    with mu_s[b] = fs[b](stage_s - biomkShifts[b]).

    Returns
    -------
    the weighted sum of the per-subject, per-stage log-densities
    """
    _, nrBiomk = data.shape
    covDiag = np.diag(sigmaSqs)  # independent noise per biomarker
    logLiks = np.zeros(stagingProb.shape)
    for stageIdx, stage in enumerate(patStagesSet):
        stageMean = [fs[biomk](stage - biomkShifts[biomk]) for biomk in range(nrBiomk)]
        logLiks[:, stageIdx] = scipy.stats.multivariate_normal.logpdf(data, stageMean, covDiag)
    return np.sum(stagingProb * logLiks)
def calcIncompleteDataLogL(data, fs, patStagesSet, biomkShifts, sigmaSqs):
    """
    computes the incomplete data log-likelihood, i.e. log p(X|theta) = log sum_Z p(X,Z|theta),
    summing over all the possible stages of each subject with a uniform stage prior.

    Parameters
    ----------
    data - (nrSubj, nrBiomk) subject data (z-scores)
    fs - list of nrBiomk trajectory callables
    patStagesSet - sequence of nrStages candidate stages
    biomkShifts - (nrBiomk,) time shift of each trajectory
    sigmaSqs - (nrBiomk,) noise variances, used as a diagonal covariance

    Returns
    -------
    logL - sum_i log( (1/nrStages) * sum_s N(data[i, :]; mu_s, diag(sigmaSqs)) ),
           with mu_s[b] = fs[b](stage_s - biomkShifts[b])
    """
    nrSubj, nrBiomk = data.shape
    nrStages = len(patStagesSet)
    covMat = np.diag(sigmaSqs)
    logLiks = np.zeros((nrSubj, nrStages), float)
    for s, stage in enumerate(patStagesSet):  # for each possible stage of the subjects
        meanCurrStage = [fs[b](stage - biomkShifts[b]) for b in range(nrBiomk)]
        logLiks[:, s] = scipy.stats.multivariate_normal.logpdf(data, meanCurrStage, covMat)
    # stay in log space: the densities themselves underflow to 0 for subjects far from
    # every stage, which made log(sum(pdf)) collapse to -inf
    logL = np.sum(scipy.special.logsumexp(logLiks, axis=1) - np.log(nrStages))
    return logL