1212 block_kde ,
1313 gaussian_pdf ,
1414 get_position_at_time ,
15+ log_gaussian_pdf ,
1516)
1617
1718
@@ -62,13 +63,109 @@ def kde_distance(
6263 return distance
6364
6465
66+ @jax .jit
67+ def log_kde_distance (
68+ eval_points : jnp .ndarray , samples : jnp .ndarray , std : jnp .ndarray
69+ ) -> jnp .ndarray :
70+ """Log-distance (log kernel product) between eval points and samples using Gaussian kernels.
71+
72+ Computes:
73+ log_distance[i, j] = sum_d log N(eval_points[j, d] | samples[i, d], std[d])
74+
75+ Parameters
76+ ----------
77+ eval_points : jnp.ndarray, shape (n_eval_points, n_dims)
78+ Evaluation points.
79+ samples : jnp.ndarray, shape (n_samples, n_dims)
80+ Training samples.
81+ std : jnp.ndarray, shape (n_dims,)
82+ Per-dimension kernel std.
83+
84+ Returns
85+ -------
86+ log_distance : jnp.ndarray, shape (n_samples, n_eval_points)
87+ Log of the product of per-dimension Gaussian kernels.
88+ """
89+ log_dist = jnp .zeros ((samples .shape [0 ], eval_points .shape [0 ]))
90+ for dim_eval , dim_samp , dim_std in zip (eval_points .T , samples .T , std , strict = False ):
91+ log_dist += log_gaussian_pdf (
92+ jnp .expand_dims (dim_eval , axis = 0 ), # (1, n_eval)
93+ jnp .expand_dims (dim_samp , axis = 1 ), # (n_samples, 1)
94+ dim_std ,
95+ )
96+ return log_dist
97+
98+
99+ def _compute_log_mark_kernel_gemm (
100+ decoding_features : jnp .ndarray ,
101+ encoding_features : jnp .ndarray ,
102+ waveform_stds : jnp .ndarray ,
103+ ) -> jnp .ndarray :
104+ """Compute log mark kernel using GEMM (matrix multiplication) instead of per-dimension loop.
105+
106+ This is mathematically equivalent to the loop-based approach but much faster for
107+ multi-dimensional features. The Gaussian kernel in log-space:
108+
109+ log K(x, y) = -0.5 * sum_d [(x_d - y_d)^2 / sigma_d^2] - log_norm_const
110+ = -0.5 * sum_d [(x_d/sigma_d)^2 + (y_d/sigma_d)^2 - 2*(x_d/sigma_d)*(y_d/sigma_d)] - log_norm_const
111+ = -0.5 * (||x_scaled||^2 + ||y_scaled||^2 - 2 * x_scaled @ y_scaled^T) - log_norm_const
112+
113+ The cross term x_scaled @ y_scaled^T is a single matrix multiply (GEMM).
114+
115+ Parameters
116+ ----------
117+ decoding_features : jnp.ndarray, shape (n_decoding_spikes, n_features)
118+ Waveform features for decoding spikes.
119+ encoding_features : jnp.ndarray, shape (n_encoding_spikes, n_features)
120+ Waveform features for encoding spikes.
121+ waveform_stds : jnp.ndarray, shape (n_features,)
122+ Standard deviations for each feature dimension.
123+
124+ Returns
125+ -------
126+ logK_mark : jnp.ndarray, shape (n_encoding_spikes, n_decoding_spikes)
127+ Log kernel matrix K[i, j] = log(Gaussian kernel between encoding spike i and decoding spike j).
128+ """
129+ n_features = waveform_stds .shape [0 ]
130+
131+ # Precompute inverse standard deviations and normalization constant
132+ inv_sigma = 1.0 / waveform_stds # (n_features,)
133+
134+ # Log normalization constant: -0.5 * (D * log(2π) + 2 * sum(log(sigma)))
135+ # Factor of 2 because we have sum of log(sigma), not log(sigma^2)
136+ log_norm_const = - 0.5 * (
137+ n_features * jnp .log (2.0 * jnp .pi ) + 2.0 * jnp .sum (jnp .log (waveform_stds ))
138+ )
139+
140+ # Scale features by inverse standard deviations
141+ Y = encoding_features * inv_sigma [None , :] # (n_enc, n_features)
142+ X = decoding_features * inv_sigma [None , :] # (n_dec, n_features)
143+
144+ # Compute squared norms
145+ y2 = jnp .sum (Y ** 2 , axis = 1 ) # (n_enc,)
146+ x2 = jnp .sum (X ** 2 , axis = 1 ) # (n_dec,)
147+
148+ # GEMM: compute cross terms X @ Y^T = (n_dec, n_features) @ (n_features, n_enc)
149+ cross_term = X @ Y .T # (n_dec, n_enc)
150+
151+ # Combine: log K[i,j] = -0.5 * (y2[i] + x2[j] - 2*cross_term[j,i]) + log_norm_const
152+ # Note: We need (n_enc, n_dec) output, so transpose the cross term
153+ logK_mark = log_norm_const - 0.5 * (
154+ y2 [:, None ] + x2 [None , :] - 2.0 * cross_term .T
155+ ) # (n_enc, n_dec)
156+
157+ return logK_mark
158+
159+
65160def estimate_log_joint_mark_intensity (
66161 decoding_spike_waveform_features : jnp .ndarray ,
67162 encoding_spike_waveform_features : jnp .ndarray ,
68163 waveform_stds : jnp .ndarray ,
69164 occupancy : jnp .ndarray ,
70165 mean_rate : float ,
71166 position_distance : jnp .ndarray ,
167+ use_gemm : bool = True ,
168+ pos_tile_size : int | None = None ,
72169) -> jnp .ndarray :
73170 """Estimate the log joint mark intensity of decoding spikes and spike waveforms.
74171
@@ -80,26 +177,109 @@ def estimate_log_joint_mark_intensity(
80177 occupancy : jnp.ndarray, shape (n_position_bins,)
81178 mean_rate : float
82179 position_distance : jnp.ndarray, shape (n_encoding_spikes, n_position_bins)
180+ use_gemm : bool, optional
181+ If True (default), use GEMM-based log-space computation (faster for multi-dimensional features).
182+ If False, use linear-space computation (matches reference exactly).
183+ pos_tile_size : int | None, optional
184+ If provided, tile computation over position dimension in chunks (only for use_gemm=True).
83185
84186 Returns
85187 -------
86188 log_joint_mark_intensity : jnp.ndarray, shape (n_decoding_spikes, n_position_bins)
87189
88190 """
89- spike_waveform_feature_distance = kde_distance (
191+ n_encoding_spikes = encoding_spike_waveform_features .shape [0 ]
192+
193+ if not use_gemm :
194+ # Linear-space computation (matches reference exactly)
195+ spike_waveform_feature_distance = kde_distance (
196+ decoding_spike_waveform_features ,
197+ encoding_spike_waveform_features ,
198+ waveform_stds ,
199+ ) # shape (n_encoding_spikes, n_decoding_spikes)
200+
201+ marginal_density = (
202+ spike_waveform_feature_distance .T @ position_distance / n_encoding_spikes
203+ ) # shape (n_decoding_spikes, n_position_bins)
204+ return jnp .log (
205+ mean_rate * jnp .where (occupancy > 0.0 , marginal_density / occupancy , 0.0 )
206+ )
207+
208+ # Log-space computation with GEMM optimization
209+ # Build log-kernel matrix for marks: (n_enc, n_dec)
210+ logK_mark = _compute_log_mark_kernel_gemm (
90211 decoding_spike_waveform_features ,
91212 encoding_spike_waveform_features ,
92213 waveform_stds ,
93- ) # shape (n_encoding_spikes, n_decoding_spikes)
214+ )
94215
95- n_encoding_spikes = encoding_spike_waveform_features .shape [0 ]
96- marginal_density = (
97- spike_waveform_feature_distance .T @ position_distance / n_encoding_spikes
98- ) # shape (n_decoding_spikes, n_position_bins)
99- return jnp .log (
100- mean_rate * jnp .where (occupancy > 0.0 , marginal_density / occupancy , 0.0 )
216+ # Convert position_distance to log-space
217+ log_position_distance = jnp .log (position_distance )
218+
219+ # Uniform weights: log(1/n) for each encoding spike
220+ log_w = - jnp .log (float (n_encoding_spikes ))
221+
222+ # Use scan to avoid materializing (n_enc × n_dec × n_pos) array
223+ n_pos = log_position_distance .shape [1 ]
224+ n_dec = logK_mark .shape [1 ]
225+
226+ if pos_tile_size is None or pos_tile_size >= n_pos :
227+ # No tiling: process all positions at once (default, fastest)
228+ def scan_over_dec (carry , y_col : jnp .ndarray ) -> tuple [None , jnp .ndarray ]:
229+ # y_col: (n_enc,), the column of logK_mark for one decoding spike
230+ # returns: (n_pos,), logsumexp over enc dimension
231+ result = jax .nn .logsumexp (
232+ log_w + log_position_distance + y_col [:, None ], axis = 0
233+ )
234+ return None , result
235+
236+ # scan over decoding spikes' columns -> (n_dec, n_pos)
237+ _ , log_marginal = jax .lax .scan (scan_over_dec , None , logK_mark .T )
238+ else :
239+ # Tiled: process positions in chunks to reduce peak memory
240+ log_marginal = jnp .zeros ((n_dec , n_pos ))
241+
242+ for pos_start in range (0 , n_pos , pos_tile_size ):
243+ pos_end = min (pos_start + pos_tile_size , n_pos )
244+ pos_slice = slice (pos_start , pos_end )
245+
246+ # Tile: slice of log_position_distance for this chunk of positions
247+ log_pos_tile = log_position_distance [:, pos_slice ] # (n_enc, tile_size)
248+
249+ # Create closure to capture log_pos_tile properly
250+ def make_scan_fn (tile ):
251+ def scan_over_dec_tile (
252+ carry , y_col : jnp .ndarray
253+ ) -> tuple [None , jnp .ndarray ]:
254+ # y_col: (n_enc,)
255+ # returns: (tile_size,), logsumexp over enc dimension
256+ result = jax .nn .logsumexp (log_w + tile + y_col [:, None ], axis = 0 )
257+ return None , result
258+
259+ return scan_over_dec_tile
260+
261+ # scan over decoding spikes for this position tile -> (n_dec, tile_size)
262+ _ , log_marginal_tile = jax .lax .scan (
263+ make_scan_fn (log_pos_tile ), None , logK_mark .T
264+ )
265+
266+ # Update output with this tile
267+ log_marginal = log_marginal .at [:, pos_slice ].set (log_marginal_tile )
268+
269+ # Add mean rate and subtract occupancy (in log)
270+ log_mean_rate = jnp .log (mean_rate )
271+ log_occ = jnp .log (jnp .where (occupancy > 0.0 , occupancy , 1.0 )) # avoid log(0)
272+
273+ # Result: log(mean_rate * marginal / occupancy)
274+ # Use where to handle occupancy = 0 cases
275+ log_joint = jnp .where (
276+ occupancy [None , :] > 0.0 ,
277+ log_mean_rate + log_marginal - log_occ [None , :],
278+ jnp .log (0.0 ), # -inf for zero occupancy
101279 )
102280
281+ return log_joint
282+
103283
104284def block_estimate_log_joint_mark_intensity (
105285 decoding_spike_waveform_features : jnp .ndarray ,
@@ -109,6 +289,8 @@ def block_estimate_log_joint_mark_intensity(
109289 mean_rate : float ,
110290 position_distance : jnp .ndarray ,
111291 block_size : int = 100 ,
292+ use_gemm : bool = True ,
293+ pos_tile_size : int | None = None ,
112294) -> jnp .ndarray :
113295 """Estimate the log joint mark intensity of decoding spikes and spike waveforms.
114296
@@ -121,6 +303,10 @@ def block_estimate_log_joint_mark_intensity(
121303 mean_rate : float
122304 position_distance : jnp.ndarray, shape (n_encoding_spikes, n_position_bins)
123305 block_size : int, optional
306+ use_gemm : bool, optional
307+ If True (default), use GEMM-based log-space computation.
308+ pos_tile_size : int | None, optional
309+ If provided, tile computation over position dimension.
124310
125311 Returns
126312 -------
@@ -130,24 +316,31 @@ def block_estimate_log_joint_mark_intensity(
130316 n_decoding_spikes = decoding_spike_waveform_features .shape [0 ]
131317 n_position_bins = occupancy .shape [0 ]
132318
133- log_joint_mark_intensity = jnp .zeros ((n_decoding_spikes , n_position_bins ))
319+ if n_decoding_spikes == 0 :
320+ return jnp .full ((0 , n_position_bins ), LOG_EPS )
134321
322+ # Use JIT-compiled update with buffer donation for memory efficiency
323+ # Donate the accumulator buffer (arg 0) so it can be reused in-place
324+ @jax .jit
325+ def _update_block (out_array , block_result , start_idx ):
326+ return jax .lax .dynamic_update_slice (out_array , block_result , (start_idx , 0 ))
327+
328+ out = jnp .zeros ((n_decoding_spikes , n_position_bins ))
135329 for start_ind in range (0 , n_decoding_spikes , block_size ):
136330 block_inds = slice (start_ind , start_ind + block_size )
137- log_joint_mark_intensity = jax .lax .dynamic_update_slice (
138- log_joint_mark_intensity ,
139- estimate_log_joint_mark_intensity (
140- decoding_spike_waveform_features [block_inds ],
141- encoding_spike_waveform_features ,
142- waveform_stds ,
143- occupancy ,
144- mean_rate ,
145- position_distance ,
146- ),
147- (start_ind , 0 ),
331+ block_result = estimate_log_joint_mark_intensity (
332+ decoding_spike_waveform_features [block_inds ],
333+ encoding_spike_waveform_features ,
334+ waveform_stds ,
335+ occupancy ,
336+ mean_rate ,
337+ position_distance ,
338+ use_gemm = use_gemm ,
339+ pos_tile_size = pos_tile_size ,
148340 )
341+ out = _update_block (out , block_result , start_ind )
149342
150- return jnp .clip (log_joint_mark_intensity , a_min = LOG_EPS , a_max = None )
343+ return jnp .clip (out , a_min = LOG_EPS , a_max = None )
151344
152345
153346def fit_clusterless_kde_encoding_model (
0 commit comments