Skip to content

Commit 8b594ff

Browse files
authored
Merge pull request #453 from AllenInstitute/dev
Dev
2 parents b130e22 + 76b3ef6 commit 8b594ff

File tree

11 files changed

+4804
-4
lines changed

11 files changed

+4804
-4
lines changed

.github/workflows/deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
- name: Installing python
1717
run: |
1818
sudo apt-get update
19-
sudo apt install python3.12
19+
sudo apt install python3.10.11
2020
sudo apt-get install build-essential
2121
2222
- name: Upgrading pip

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Set up Python
2525
uses: actions/setup-python@v5
2626
with:
27-
python-version: '3.9.13'
27+
python-version: '3.10.11'
2828

2929
# - name: Upgrading pip
3030
# run: pip install --upgrade pip

data/images/hankel_matrix.gif

112 KB
Loading

data/images/lnp_model.png

323 KB
Loading

data/images/visual_stimuli_set.png

78.5 KB
Loading

docs/_toc.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ parts:
3434
chapters:
3535
- file: higher-order/cebra_time.ipynb
3636
- file: higher-order/tca.ipynb
37+
- file: higher-order/GLM_pynapple_nemos.ipynb
3738
- file: higher-order/glm.ipynb
3839
- file: higher-order/behavioral_state.ipynb
3940
- caption: Openscope Experimental Projects

docs/embargoed/cell_matching.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
},
6969
{
7070
"cell_type": "code",
71-
"execution_count": 3,
71+
"execution_count": null,
7272
"id": "e4511768",
7373
"metadata": {},
7474
"outputs": [],
@@ -82,7 +82,7 @@
8282
"# the subject ids should be the same, but the session ids should be different\n",
8383
"input_dandi_filepaths = [dandi_filepath_1, dandi_filepath_2]\n",
8484
"\n",
85-
"dandi_api_key = \"f9459a77200783c455ec6f3cb0b6cd92fc9fe106\""
85+
"dandi_api_key = os.environ['DANDI_API_KEY']"
8686
]
8787
},
8888
{

docs/higher-order/GLM_pynapple_nemos.ipynb

Lines changed: 4568 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from time import perf_counter
2+
3+
import jax
4+
5+
import nemos as nmo
6+
import pynapple as nap
7+
import numpy as np
8+
from scipy.optimize import minimize
9+
10+
jax.config.update("jax_enable_x64", True)
11+
12+
def neg_log_lik_lnp(theta, X, y, Cinv):
13+
# Compute the Poisson log likelihood
14+
rate = np.exp(X @ theta)
15+
log_lik = y @ np.log(rate) - rate.sum()
16+
log_lik -= theta.T @ Cinv @ theta
17+
18+
return -log_lik
19+
20+
def fit_lnp(X, y, lam=0):
21+
filt_len = X.shape[1]
22+
Imat = np.identity(filt_len) # identity matrix of size of filter + const
23+
Imat[0,0] = 0
24+
Cinv = lam*Imat
25+
26+
# Use a random vector of weights to start (mean 0, sd .2)
27+
x0 = np.random.normal(0, .2, filt_len)
28+
print("y:",y.shape,"X:",X.shape,"x0:",x0.shape)
29+
30+
# Find parameters that minimize the negative log likelihood function
31+
res = minimize(neg_log_lik_lnp, x0, args=(X, y, Cinv))
32+
33+
return res["x"]
34+
35+
def predict(X, weights, constant):
36+
y = np.exp(X @ weights + constant)
37+
return y
38+
39+
def predict_spikes(X, weights, constant):
40+
rate = predict(X, weights, constant)
41+
spks = np.random.poisson(np.matrix.transpose(rate))
42+
return spks
43+
44+
def retrieve_stim_info(color_code, features, flashes):
45+
"""Retrieve stimulus information based on color code.
46+
47+
Parameters
48+
----------
49+
color_code :
50+
The color label (e.g., '-1.0' for black, '1.0 for white) to identify the stimulus.
51+
features :
52+
An array indicating which flash interval each timestamp belongs to.
53+
54+
Returns
55+
----------
56+
color_feature:
57+
A binary array where 1 indicates the timestamp falls within a flash
58+
interval of the given color_code, and 0 otherwise.
59+
"""
60+
# Get the indices of flash intervals where the color matches the given color_code
61+
intervals = flashes.index[flashes["color"] == color_code]
62+
# Initialize an array of zeros with the same length as the features array
63+
color_feature = np.zeros(len(features))
64+
# Create a boolean mask for entries in 'features' that match the target flash intervals
65+
mask = np.isin(features, intervals)
66+
# Mark the matching timestamps with 1
67+
color_feature[mask] = 1
68+
69+
return color_feature
70+
71+
72+
73+
dandiset_id = "000021"
74+
dandi_filepath = "sub-726298249/sub-726298249_ses-754829445.nwb"
75+
download_loc = "."
76+
77+
path = "docs/higher-order/sub-726298249_ses-754829445.nwb"
78+
# t0 = perf_counter()
79+
# io = nmo.fetch.download_dandi_data(dandiset_id, dandi_filepath)
80+
# print(perf_counter() - t0)
81+
# nap_nwb = nap.NWBFile(io.read(), lazy_loading=True)
82+
#
83+
84+
nap_nwb = nap.load_file(path)
85+
nwb = nap_nwb.nwb
86+
channel_probes = {}
87+
88+
electrodes = nwb.electrodes
89+
for i in range(len(electrodes)):
90+
channel_id = electrodes["id"][i]
91+
location = electrodes["location"][i]
92+
channel_probes[channel_id] = location
93+
94+
# function aligns location information from electrodes table with channel id from the units table
95+
def get_unit_location(unit_id):
96+
return channel_probes[int(units[unit_id].peak_channel_id)]
97+
98+
units = nap_nwb["units"]
99+
units.brain_area = [channel_probes[int(ch_id)] for ch_id in units.peak_channel_id]
100+
101+
units = units[(units.quality == "good") & (units.brain_area == "VISp") & (units.firing_rate > 2.)]
102+
103+
104+
105+
flashes = nap_nwb["flashes_presentations"]
106+
107+
# Set start, end and bin size
108+
start = nap_nwb["flashes_presentations"].start.min()
109+
end = nap_nwb["flashes_presentations"].end.max()
110+
bin_sz = 0.05
111+
112+
counts = units.count(bin_sz, ep=nap.IntervalSet(start, end))
113+
114+
# Create Tsd with timestamps corresponding to the desired time bins and bins sizes
115+
uniform = nap.Tsd(t=counts.t, d=np.ones(counts.t.shape[0]))
116+
117+
# For each desired timestamp, find the index of the flash interval it falls into.
118+
# Returns NaN for timestamps outside all intervals, and an index for those within.
119+
features = flashes.in_interval(uniform)
120+
121+
white_stimuli = retrieve_stim_info("1.0", features, flashes)
122+
black_stimuli = retrieve_stim_info("-1.0", features, flashes)
123+
124+
history_size = int(0.25 / bin_sz)
125+
bas = nmo.basis.HistoryConv(history_size, label="w") + nmo.basis.HistoryConv(history_size, label="b")
126+
127+
X = bas.compute_features(white_stimuli, black_stimuli)
128+
129+
model = nmo.glm.GLM()
130+
model.fit(X, counts[:, 0])
131+
rate = model.predict(X)
132+
133+
# first 5 of rate are nans (conv in mode valid + padding)
134+
intercept_plus_X = np.hstack((np.ones((X.shape[0], 1)), X))
135+
intercept_plus_coeff = np.hstack((model.intercept_, model.coef_))
136+
ll = model.observation_model._negative_log_likelihood(counts[5:,0], rate[5:], aggregate_sample_scores=np.sum)
137+
ll2 = neg_log_lik_lnp(
138+
intercept_plus_coeff, intercept_plus_X[5:], counts[5:, 0], np.zeros((11, 11))
139+
)
140+
141+
# if this is 0 or close it means that that we are computing the same un-regularized likelihood (modulo using the mean
142+
# instead of sum, so our likelihood used in the fit is theirs divided by the number of samples, which doesn't make a
143+
# difference).
144+
print(ll - ll2)
145+
146+
# add regularization
147+
148+
lam = 2**5 # their value for the regulariser
149+
150+
# our penalized loss is loss - 0.5 * lam * coef @ coeff, so 2**5 * 2 == 2**6
151+
model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength= lam * 2 / (X.shape[0] - 5))
152+
reg = model.regularizer
153+
154+
155+
# instead of fitting attach coeff
156+
model.coef_ = intercept_plus_coeff[1:]
157+
model.intercept_ = intercept_plus_coeff[:1]
158+
# get the loss + penalty in nemos
159+
loss_with_penalty = model.regularizer.penalized_loss(
160+
model._predict_and_compute_loss, model.regularizer_strength
161+
)
162+
ll_penalised = (X.shape[0] - 5) * loss_with_penalty((model.coef_, model.intercept_), X[5:], counts[5:,0].d)
163+
164+
Imat = np.identity(11)
165+
Imat[0, 0] = 0 # they do not regularise the intercept (same as as, good)
166+
Cinv = lam * Imat
167+
168+
169+
ll2_penalised = neg_log_lik_lnp(
170+
intercept_plus_coeff, intercept_plus_X[5:], counts[5:, 0], Cinv
171+
)
172+
print(ll_penalised - ll2_penalised)

docs/references.bib

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,56 @@
1+
@techreport{NeuropixelsWhitePaper2019,
2+
title = {Allen Brain Observatory - Neuropixels Visual Coding - Technical White Paper},
3+
author = {{Allen Institute for Brain Science}},
4+
institution = {{Allen Institute for Brain Science}},
5+
url = {https://brainmapportal-live-4cc80a57cd6e400d854-f7fdcae.divio-media.net/filer_public/80/75/8075a100-ca64-429a-b39a-569121b612b2/neuropixels_visual_coding_-_white_paper_v10.pdf},
6+
year = {2019},
7+
month = {October},
8+
}
9+
10+
@misc{PillowCosyneTutorial,
11+
title = { Jonathan Pillow - Tutorial: Statistical models for neural data - Part 1 (Cosyne 2018)},
12+
author = {Jonathan Pillow},
13+
howpublished = {YouTube},
14+
url = {https://www.youtube.com/watch?v=NFeGW5ljUoI},
15+
year = {2018},
16+
month = {March},
17+
}
18+
19+
@article{pillowPredictionDecodingRetinal2005,
20+
title = {Prediction and {{Decoding}} of {{Retinal Ganglion Cell Responses}} with a {{Probabilistic Spiking Model}}},
21+
author = {Pillow, Jonathan W. and Paninski, Liam and Uzzell, Valerie J. and Simoncelli, Eero P. and Chichilnisky, E. J.},
22+
year = {2005},
23+
journal = {Journal of Neuroscience},
24+
volume = {25},
25+
number = {47},
26+
pages = {11003--11013},
27+
publisher = {Society for Neuroscience},
28+
issn = {0270-6474, 1529-2401},
29+
doi = {10.1523/JNEUROSCI.3305-05.2005},
30+
urldate = {2025-07-06},
31+
chapter = {Behavioral/Systems/Cognitive},
32+
copyright = {Copyright {\copyright} 2005 Society for Neuroscience 0270-6474/05/2511003-11.00/0},
33+
langid = {english},
34+
pmid = {16306413},
35+
keywords = {computational model,decoding,integrate and fire,neural coding,precision,retinal ganglion cell,spike timing,spike trains,variability},
36+
}
37+
38+
@article{pillowSpatiotemporalCorrelationsVisual2008,
39+
title = {Spatio-Temporal Correlations and Visual Signalling in a Complete Neuronal Population},
40+
author = {Pillow, Jonathan W. and Shlens, Jonathon and Paninski, Liam and Sher, Alexander and Litke, Alan M. and Chichilnisky, E. J. and Simoncelli, Eero P.},
41+
year = {2008},
42+
journal = {Nature},
43+
volume = {454},
44+
number = {7207},
45+
pages = {995--999},
46+
publisher = {Nature Publishing Group},
47+
issn = {1476-4687},
48+
doi = {10.1038/nature07140},
49+
urldate = {2025-07-06},
50+
copyright = {2008 Macmillan Publishers Limited. All rights reserved},
51+
langid = {english},
52+
keywords = {Humanities and Social Sciences,multidisciplinary,Science},
53+
}
154
@article{Rubel2022
255
, title = {The Neurodata Without Borders ecosystem for neurophysiological data science}
356
, author = {Oliver Rübel and Andrew Tritt and Ryan Ly and Benjamin K Dichter and Satrajit Ghosh and Lawrence Niu and Pamela Baker and Ivan Soltesz and Lydia Ng and Karel Svoboda and Loren Frank and Kristofer E Bouchard}

0 commit comments

Comments
 (0)