-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_gpfa.py
More file actions
289 lines (259 loc) · 11.7 KB
/
test_gpfa.py
File metadata and controls
289 lines (259 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# ...
# Copyright 2021 Brooks M. Musangu and Jan Drugowitsch.
# license Modified BSD, see LICENSE.txt for details.
# ...
"""
GPFA Unittests
"""
import unittest
import numpy as np
from scipy import linalg
from gpfa import GPFA
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel
class TestGPFA(unittest.TestCase):
    """
    Unit tests for the GPFA analysis.

    ``setUp`` builds a small synthetic data set of square-rooted Poisson
    counts, fits three GPFA models that differ only in their ``gp_kernel``
    argument, and caches the inferred latents and log-likelihoods. The
    individual tests compare those cached results against an independent
    reference computation or against previously recorded values.
    """

    def setUp(self):
        """
        Set up synthetic data, initial parameters to help with the
        functions to be tested.

        Runs before every test method: the random generator is re-seeded
        so each test sees exactly the same data and fitted models.
        """
        np.random.seed(10)
        self.bin_size = 0.02  # [s]
        self.n_iters = 10
        self.z_dim = 2
        # neurons per rate set; there are two rate sets, hence 4 neurons
        self.n_neurons = 2

        # ==================================================
        # generate data
        # ==================================================
        rates_a = (2, 10, 2, 2)
        rates_b = (2, 2, 10, 2)
        trial_lens = [8, 10]
        self.X = self._gen_test_data(trial_lens, rates_a, rates_b)
        # get the number of time steps for each trial
        self.T = np.array([X_n.shape[1] for X_n in self.X])
        self.t_half = int(np.ceil(self.T[0] / 2.0))

        # ==================================================
        # initialize GPFA
        # ==================================================
        # single kernel whose white-noise level keeps its default bounds
        multi_params_kernel = self._make_kernel()
        # two independently constructed, identical kernels whose
        # white-noise level is fixed
        seq_kernel = [
            self._make_kernel(noise_level_bounds='fixed'),
            self._make_kernel(noise_level_bounds='fixed'),
        ]

        self.gpfa = GPFA(
            bin_size=self.bin_size, z_dim=self.z_dim,
            em_max_iters=self.n_iters
        )
        self.gpfa_with_seq_kernel = GPFA(
            bin_size=self.bin_size, z_dim=self.z_dim,
            gp_kernel=seq_kernel,
            em_max_iters=self.n_iters
        )
        self.gpfa_with_multi_params_kernel = GPFA(
            bin_size=self.bin_size, z_dim=self.z_dim,
            gp_kernel=multi_params_kernel,
            em_max_iters=self.n_iters
        )

        # fit the models (order kept fixed so results stay reproducible)
        self.gpfa.fit(self.X)
        self.gpfa_with_multi_params_kernel.fit(self.X)
        self.gpfa_with_seq_kernel.fit(self.X)

        self.results, _ = self.gpfa.predict(
            returned_data=['Z_mu', 'Z_mu_orth'])

        # get latents sequence and data log_likelihood
        self.latent_seqs, self.ll = self.gpfa._infer_latents(self.X)
        self.latent_seqs_multiparamskern, self.ll_multiparams_kernel = \
            self.gpfa_with_multi_params_kernel._infer_latents(self.X)
        self.latent_seqs_seqkernel, self.ll_seq_kernel = \
            self.gpfa_with_seq_kernel._infer_latents(self.X)

    @staticmethod
    def _make_kernel(**white_kernel_kwargs):
        """
        Build the GP kernel shared by the non-default GPFA instances:
        ``(1 - 0.001) * RBF(0.1) + 0.001 * WhiteKernel(1)`` with both
        constant factors fixed.

        Args:
            **white_kernel_kwargs : extra keyword arguments forwarded to
                                    ``WhiteKernel`` (e.g.
                                    ``noise_level_bounds='fixed'``)
        Returns:
            kernel : a newly constructed kernel object
        """
        smooth_part = ConstantKernel(
            1 - 0.001, constant_value_bounds='fixed'
        ) * RBF(length_scale=0.1)
        noise_part = ConstantKernel(
            0.001, constant_value_bounds='fixed'
        ) * WhiteKernel(noise_level=1, **white_kernel_kwargs)
        return smooth_part + noise_part

    def _gen_test_data(self, trial_lens, rates_a, rates_b, use_sqrt=True):
        """
        Generate test data.

        Every trial contains ``2 * self.n_neurons`` rows -- the first
        ``self.n_neurons`` rows are drawn with the rates from set a, and
        the remaining rows with the rates from set b. Each trial is split
        into four epochs with the same number of bins, one per rate.

        Args:
            trial_lens : list of durations of each trial in [s]
                         len(trial_lens) corresponds with num of trials
            rates_a    : four rates, one for each time epoch; each value
                         is passed to ``np.random.poisson`` as the mean
                         count of a single bin
            rates_b    : four rates, one for each time epoch,
                         ordered differently from rates_a
            use_sqrt   : boolean
                         if true, take square root of binned spike counts;
                         otherwise the raw counts are returned
        Returns:
            seqs       : object array holding one
                         (2 * n_neurons, 4 * bins per epoch) array of
                         binned spike counts per trial
        Raises:
            ValueError : if rates_a or rates_b does not have 4 elements
        """
        # check the length of rates_a and rates_b must both be equal to 4
        if len(rates_a) != 4:
            raise ValueError("'rates_a' must have 4 elements in it")
        if len(rates_b) != 4:
            raise ValueError("'rates_b' must have 4 elements in it")

        seqs = np.empty(len(trial_lens), object)
        # generate data where num trials is len(trial_lens)
        for n, t_l in enumerate(trial_lens):
            # each epoch is a quarter of the total length (rounded up to
            # whole seconds) and holds this many bins
            epoch_len = int(np.ceil(t_l / len(rates_a)))
            nbins_per_epoch = int(epoch_len / self.bin_size)
            epoch_shape = (self.n_neurons, nbins_per_epoch)

            # Draw set a before set b within every epoch: the order of
            # the random draws determines the data, and the recorded
            # reference values in the tests depend on it.
            epochs = []
            for rate_a, rate_b in zip(rates_a, rates_b):
                spk_rates_a = np.random.poisson(rate_a, epoch_shape)
                spk_rates_b = np.random.poisson(rate_b, epoch_shape)
                epochs.append(np.concatenate([spk_rates_a, spk_rates_b]))
            # join the epochs along the time axis
            binned_spikecount = np.concatenate(epochs, axis=1)

            # take square root of the binned spike counts
            # if `use_sqrt` is True (see paper for motivation)
            if use_sqrt:
                seqs[n] = np.sqrt(binned_spikecount)
            else:
                seqs[n] = binned_spikecount
        return seqs

    def create_mu_and_cov(self, gpfa_inst):
        """
        Create the GPFA posterior mean and covariance directly from
        equation A5 of the Yu et al. (2009) paper, as an independent
        reference for the (differently implemented) GPFA code.

        The mean is (K_inv + C'R_invC)^-1 * C'R_inv * (y - d)
        and the covariance is (K_inv + C'R_invC)^-1, where R_inv is built
        from the diagonal of ``gpfa_inst.R_`` only.

        Parameters:
            gpfa_inst : GPFA instance
                Fitted instance providing ``C_``, ``R_``, ``d_`` and
                ``_make_k_big``.
        Returns:
            test_latent_seqs : numpy.ndarray
                Structured array with one entry per trial; ``'Z_mu'`` is
                a (z_dim, T) array and ``'Z_cov'`` a (z_dim, z_dim, T)
                array of the per-time-step diagonal blocks of the
                covariance.
        """
        test_latent_seqs = np.empty(
            len(self.X), dtype=[('Z_mu', object), ('Z_cov', object)])

        # quantities that do not depend on the trial
        rinv = np.diag(1.0 / np.diag(gpfa_inst.R_))
        c_rinv = gpfa_inst.C_.T.dot(rinv)
        # C'R_invC
        c_rinv_c = c_rinv.dot(gpfa_inst.C_)

        for n, t in enumerate(self.T):
            # get the kernel as defined in GPFA
            k_big = gpfa_inst._make_k_big(n_timesteps=t)
            k_big_inv = linalg.inv(k_big)

            # subtract mean from activities (y - d)
            dif = self.X[n] - gpfa_inst.d_[:, np.newaxis]
            # C'R_inv * (y - d), stacked column by column into one vector
            term1_mat = c_rinv.dot(dif).reshape(
                (self.z_dim * t, -1), order='F')

            # make a c_rinv_c big and block diagonal
            c_rinv_c_big = linalg.block_diag(
                *(c_rinv_c for _ in range(t)))  # (z_dim*T) x (z_dim*T)

            # (K_inv + C'R_invC)^-1, inverted once per trial and reused
            # for both the mean and every covariance block
            post_cov = linalg.inv(k_big_inv + c_rinv_c_big)

            # (K_inv + C'R_invC)^-1 * C'R_inv * (y - d)
            test_latent_seqs[n]['Z_mu'] = post_cov.dot(term1_mat).reshape(
                (self.z_dim, t), order='F')

            # covariance: keep the z_dim x z_dim diagonal block of every
            # time step
            cov = np.full((self.z_dim, self.z_dim, t), np.nan)
            for i in range(t):
                start = i * self.z_dim
                stop = start + self.z_dim
                cov[:, :, i] = post_cov[start:stop, start:stop]
            test_latent_seqs[n]['Z_cov'] = cov

        return test_latent_seqs

    def test_infer_latents(self):
        """
        Test the mean and cov for different GPFA instances.

        For each fitted instance the first trial's inferred ``Z_mu`` and
        ``Z_cov`` must match the reference from ``create_mu_and_cov``.
        """
        cases = (
            ('default kernel',
             self.gpfa, self.latent_seqs),
            ('sequence of kernels',
             self.gpfa_with_seq_kernel, self.latent_seqs_seqkernel),
            ('multi-parameter kernel',
             self.gpfa_with_multi_params_kernel,
             self.latent_seqs_multiparamskern),
        )
        for label, gpfa_inst, inferred in cases:
            with self.subTest(kernel=label):
                expected = self.create_mu_and_cov(gpfa_inst)
                for field in ('Z_mu', 'Z_cov'):
                    # Assert
                    self.assertTrue(
                        np.allclose(inferred[field][0], expected[field][0]),
                        msg="{} of trial 0 differs from the reference "
                            "({})".format(field, label))

    def test_data_loglikelihood(self):
        """
        Test the data log_likelihood against a recorded reference value.

        The default and the sequence-of-kernels model must reproduce the
        reference; the multi-parameter-kernel model must exceed it.
        """
        test_ll = -4092.076117337763
        # Assert
        self.assertAlmostEqual(test_ll, self.ll)
        self.assertAlmostEqual(test_ll, self.ll_seq_kernel)
        self.assertGreater(self.ll_multiparams_kernel, test_ll)

    def test_orthonormalized_transform(self):
        """
        Test GPFA orthonormalization transform of the parameter `C`.

        ``Corth_`` must equal the basis ``scipy.linalg.orth`` returns
        for ``C_``.
        """
        corth = self.gpfa.Corth_
        c_orth = linalg.orth(self.gpfa.C_)
        # Assert
        self.assertTrue(np.allclose(c_orth, corth))

    def test_orthonormalized_latents(self):
        """
        Test GPFA orthonormalization functions applied in `gpfa.predict`.

        The predicted ``Z_mu_orth`` of the first trial must equal
        ``OrthTrans_`` applied to the predicted ``Z_mu``.
        """
        Z_mu = self.results['Z_mu'][0]
        Z_mu_orth = self.results['Z_mu_orth'][0]
        test_Z_mu_orth = np.dot(self.gpfa.OrthTrans_, Z_mu)
        # Assert
        self.assertTrue(np.allclose(Z_mu_orth, test_Z_mu_orth))

    def test_variance_explained(self):
        """
        Test GPFA explained_variance against a recorded reference value.
        """
        test_r2_score = 0.6648115733320232
        r2_t1 = self.gpfa.variance_explained()[0]
        # Assert
        self.assertAlmostEqual(test_r2_score, r2_t1)