1+ from time import perf_counter
2+
3+ import jax
4+
5+ import nemos as nmo
6+ import pynapple as nap
7+ import numpy as np
8+ from scipy .optimize import minimize
9+
10+ jax .config .update ("jax_enable_x64" , True )
11+
def neg_log_lik_lnp(theta, X, y, Cinv):
    """Penalized negative Poisson log-likelihood for an LNP model.

    Parameters
    ----------
    theta :
        Weight vector; by convention the first entry is the constant term.
    X :
        Design matrix, one row per time bin.
    y :
        Observed spike counts, one entry per time bin.
    Cinv :
        Penalty matrix; the quadratic form ``theta.T @ Cinv @ theta`` is
        subtracted from the log-likelihood (ridge-style regularization).

    Returns
    -------
    The negative penalized log-likelihood (to be minimized).
    """
    # Conditional intensity under the exponential nonlinearity.
    rate = np.exp(X @ theta)
    # Poisson log-likelihood up to the constant log(y!) term, which is
    # independent of theta and therefore dropped.
    data_term = y @ np.log(rate) - rate.sum()
    penalty = theta.T @ Cinv @ theta
    return penalty - data_term
19+
def fit_lnp(X, y, lam=0):
    """Fit LNP model weights by penalized maximum likelihood.

    Parameters
    ----------
    X :
        Design matrix (n_samples, n_params); column 0 is the constant.
    y :
        Spike counts, one per sample.
    lam :
        Ridge penalty strength (0 = unregularized). The constant weight
        is never penalized.

    Returns
    -------
    The optimized weight vector.
    """
    n_weights = X.shape[1]

    # Penalty matrix: scaled identity with the constant-term entry zeroed
    # out so the intercept is not regularized.
    Cinv = lam * np.identity(n_weights)
    Cinv[0, 0] = 0

    # Random starting point for the optimizer (mean 0, sd .2).
    x0 = np.random.normal(0, .2, n_weights)
    print("y:", y.shape, "X:", X.shape, "x0:", x0.shape)

    # Minimize the penalized negative log-likelihood and return the argmin.
    return minimize(neg_log_lik_lnp, x0, args=(X, y, Cinv))["x"]
34+
def predict(X, weights, constant):
    """Predicted firing rate: exponential of the linear drive plus offset."""
    linear_drive = X @ weights + constant
    return np.exp(linear_drive)
38+
def predict_spikes(X, weights, constant):
    """Sample Poisson spike counts from the model's predicted rate.

    Parameters
    ----------
    X :
        Design matrix, one row per time bin.
    weights :
        Fitted filter weights.
    constant :
        Intercept (log baseline rate).

    Returns
    -------
    Integer array of Poisson-sampled spike counts, one per time bin.
    """
    rate = predict(X, weights, constant)
    # Fix: the original called the deprecated np.matrix's transpose as an
    # unbound method on a plain ndarray; np.transpose is the supported
    # equivalent (and a no-op for a 1-D rate vector).
    spks = np.random.poisson(np.transpose(rate))
    return spks
43+
def retrieve_stim_info(color_code, features, flashes):
    """Retrieve stimulus information based on color code.

    Parameters
    ----------
    color_code :
        The color label (e.g., '-1.0' for black, '1.0' for white) to identify
        the stimulus.
    features :
        An array indicating which flash interval each timestamp belongs to
        (NaN for timestamps outside every interval).
    flashes :
        Flash-presentation interval table with a "color" column; assumed to
        expose a pandas-like ``.index`` and ``["color"]`` lookup — TODO confirm
        against the pynapple IntervalSet API.

    Returns
    -------
    color_feature :
        A binary array where 1 indicates the timestamp falls within a flash
        interval of the given color_code, and 0 otherwise.
    """
    # Get the indices of flash intervals where the color matches the given color_code
    intervals = flashes.index[flashes["color"] == color_code]
    # Initialize an array of zeros with the same length as the features array
    color_feature = np.zeros(len(features))
    # Create a boolean mask for entries in 'features' that match the target flash intervals
    mask = np.isin(features, intervals)
    # Mark the matching timestamps with 1
    color_feature[mask] = 1

    return color_feature
70+
71+
72+
# DANDI identifiers for the Allen Visual Coding Neuropixels session used below.
dandiset_id = "000021"
dandi_filepath = "sub-726298249/sub-726298249_ses-754829445.nwb"
download_loc = "."

# Local copy of the NWB file; the download path below is kept for reference.
path = "docs/higher-order/sub-726298249_ses-754829445.nwb"
# t0 = perf_counter()
# io = nmo.fetch.download_dandi_data(dandiset_id, dandi_filepath)
# print(perf_counter() - t0)
# nap_nwb = nap.NWBFile(io.read(), lazy_loading=True)
#

nap_nwb = nap.load_file(path)
nwb = nap_nwb.nwb

# Map each electrode channel id to its recorded anatomical location so units
# can later be assigned a brain area via their peak channel.
channel_probes = {}

electrodes = nwb.electrodes
for i in range(len(electrodes)):
    channel_id = electrodes["id"][i]
    location = electrodes["location"][i]
    channel_probes[channel_id] = location
93+
# function aligns location information from electrodes table with channel id from the units table
def get_unit_location(unit_id):
    """Return the brain area of a unit via its peak channel's electrode location.

    NOTE(review): reads the module-level ``channel_probes`` dict and the
    ``units`` table (assigned later in this script), so it is only valid
    after both exist — verify call sites.
    """
    return channel_probes[int(units[unit_id].peak_channel_id)]
97+
units = nap_nwb["units"]
# Attach a brain-area label to each unit by looking up its peak channel in
# the channel -> location map built from the electrodes table.
units.brain_area = [channel_probes[int(ch_id)] for ch_id in units.peak_channel_id]

# Keep only well-isolated primary-visual-cortex units firing above 2 Hz.
units = units[(units.quality == "good") & (units.brain_area == "VISp") & (units.firing_rate > 2.)]



flashes = nap_nwb["flashes_presentations"]

# Set start, end and bin size
start = nap_nwb["flashes_presentations"].start.min()
end = nap_nwb["flashes_presentations"].end.max()
bin_sz = 0.05

# Bin each unit's spike times over the full flash-presentation epoch.
counts = units.count(bin_sz, ep=nap.IntervalSet(start, end))

# Create Tsd with timestamps corresponding to the desired time bins and bin sizes
uniform = nap.Tsd(t=counts.t, d=np.ones(counts.t.shape[0]))

# For each desired timestamp, find the index of the flash interval it falls into.
# Returns NaN for timestamps outside all intervals, and an index for those within.
features = flashes.in_interval(uniform)

# Binary regressors: 1 where a bin falls inside a white / black flash interval.
white_stimuli = retrieve_stim_info("1.0", features, flashes)
black_stimuli = retrieve_stim_info("-1.0", features, flashes)
123+
# 0.25 s of stimulus history per color, in bins of bin_sz (-> 5 lags each).
history_size = int(0.25 / bin_sz)
bas = nmo.basis.HistoryConv(history_size, label="w") + nmo.basis.HistoryConv(history_size, label="b")

# Design matrix: lagged copies of the white and black stimulus trains.
X = bas.compute_features(white_stimuli, black_stimuli)

# Unregularized Poisson GLM fit on the first selected unit's counts.
model = nmo.glm.GLM()
model.fit(X, counts[:, 0])
rate = model.predict(X)

# first 5 of rate are nans (conv in mode valid + padding)
intercept_plus_X = np.hstack((np.ones((X.shape[0], 1)), X))
intercept_plus_coeff = np.hstack((model.intercept_, model.coef_))
ll = model.observation_model._negative_log_likelihood(counts[5:, 0], rate[5:], aggregate_sample_scores=np.sum)
# Reference implementation with a zero penalty matrix
# (11 = 1 intercept + 2 colors * 5 history lags).
ll2 = neg_log_lik_lnp(
    intercept_plus_coeff, intercept_plus_X[5:], counts[5:, 0], np.zeros((11, 11))
)

# if this is 0 or close it means that we are computing the same un-regularized likelihood (modulo using the mean
# instead of sum, so our likelihood used in the fit is theirs divided by the number of samples, which doesn't make a
# difference).
print(ll - ll2)
145+
# add regularization

lam = 2 ** 5  # their value for the regulariser

# our penalized loss is loss - 0.5 * lam * coef @ coeff, so 2**5 * 2 == 2**6;
# divide by the sample count because nemos uses the mean (not summed) loss.
model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=lam * 2 / (X.shape[0] - 5))
reg = model.regularizer


# instead of fitting attach coeff
model.coef_ = intercept_plus_coeff[1:]
model.intercept_ = intercept_plus_coeff[:1]
# get the loss + penalty in nemos
loss_with_penalty = model.regularizer.penalized_loss(
    model._predict_and_compute_loss, model.regularizer_strength
)
# Rescale the per-sample mean loss back to a summed log-likelihood.
ll_penalised = (X.shape[0] - 5) * loss_with_penalty((model.coef_, model.intercept_), X[5:], counts[5:, 0].d)

Imat = np.identity(11)
Imat[0, 0] = 0  # they do not regularise the intercept (same as us, good)
Cinv = lam * Imat


ll2_penalised = neg_log_lik_lnp(
    intercept_plus_coeff, intercept_plus_X[5:], counts[5:, 0], Cinv
)
# Should print ~0 if the two penalized likelihoods agree.
print(ll_penalised - ll2_penalised)