Skip to content

Commit c62e141

Browse files
committed
Optimize log-space KDE computation and add parity tests
Introduces GEMM-based log-space computation for joint mark intensity estimation, with optional position tiling for memory efficiency. Updates relevant functions to support these optimizations and adds tests to verify parity between tiled and non-tiled implementations.
1 parent 6094b63 commit c62e141

File tree

2 files changed

+295
-21
lines changed

2 files changed

+295
-21
lines changed

src/non_local_detector/likelihoods/clusterless_kde_log.py

Lines changed: 214 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
block_kde,
1313
gaussian_pdf,
1414
get_position_at_time,
15+
log_gaussian_pdf,
1516
)
1617

1718

@@ -62,13 +63,109 @@ def kde_distance(
6263
return distance
6364

6465

66+
@jax.jit
67+
def log_kde_distance(
68+
eval_points: jnp.ndarray, samples: jnp.ndarray, std: jnp.ndarray
69+
) -> jnp.ndarray:
70+
"""Log-distance (log kernel product) between eval points and samples using Gaussian kernels.
71+
72+
Computes:
73+
log_distance[i, j] = sum_d log N(eval_points[j, d] | samples[i, d], std[d])
74+
75+
Parameters
76+
----------
77+
eval_points : jnp.ndarray, shape (n_eval_points, n_dims)
78+
Evaluation points.
79+
samples : jnp.ndarray, shape (n_samples, n_dims)
80+
Training samples.
81+
std : jnp.ndarray, shape (n_dims,)
82+
Per-dimension kernel std.
83+
84+
Returns
85+
-------
86+
log_distance : jnp.ndarray, shape (n_samples, n_eval_points)
87+
Log of the product of per-dimension Gaussian kernels.
88+
"""
89+
log_dist = jnp.zeros((samples.shape[0], eval_points.shape[0]))
90+
for dim_eval, dim_samp, dim_std in zip(eval_points.T, samples.T, std, strict=False):
91+
log_dist += log_gaussian_pdf(
92+
jnp.expand_dims(dim_eval, axis=0), # (1, n_eval)
93+
jnp.expand_dims(dim_samp, axis=1), # (n_samples, 1)
94+
dim_std,
95+
)
96+
return log_dist
97+
98+
99+
def _compute_log_mark_kernel_gemm(
100+
decoding_features: jnp.ndarray,
101+
encoding_features: jnp.ndarray,
102+
waveform_stds: jnp.ndarray,
103+
) -> jnp.ndarray:
104+
"""Compute log mark kernel using GEMM (matrix multiplication) instead of per-dimension loop.
105+
106+
This is mathematically equivalent to the loop-based approach but much faster for
107+
multi-dimensional features. The Gaussian kernel in log-space:
108+
109+
log K(x, y) = -0.5 * sum_d [(x_d - y_d)^2 / sigma_d^2] - log_norm_const
110+
= -0.5 * sum_d [(x_d/sigma_d)^2 + (y_d/sigma_d)^2 - 2*(x_d/sigma_d)*(y_d/sigma_d)] - log_norm_const
111+
= -0.5 * (||x_scaled||^2 + ||y_scaled||^2 - 2 * x_scaled @ y_scaled^T) - log_norm_const
112+
113+
The cross term x_scaled @ y_scaled^T is a single matrix multiply (GEMM).
114+
115+
Parameters
116+
----------
117+
decoding_features : jnp.ndarray, shape (n_decoding_spikes, n_features)
118+
Waveform features for decoding spikes.
119+
encoding_features : jnp.ndarray, shape (n_encoding_spikes, n_features)
120+
Waveform features for encoding spikes.
121+
waveform_stds : jnp.ndarray, shape (n_features,)
122+
Standard deviations for each feature dimension.
123+
124+
Returns
125+
-------
126+
logK_mark : jnp.ndarray, shape (n_encoding_spikes, n_decoding_spikes)
127+
Log kernel matrix K[i, j] = log(Gaussian kernel between encoding spike i and decoding spike j).
128+
"""
129+
n_features = waveform_stds.shape[0]
130+
131+
# Precompute inverse standard deviations and normalization constant
132+
inv_sigma = 1.0 / waveform_stds # (n_features,)
133+
134+
# Log normalization constant: -0.5 * (D * log(2π) + 2 * sum(log(sigma)))
135+
# Factor of 2 because we have sum of log(sigma), not log(sigma^2)
136+
log_norm_const = -0.5 * (
137+
n_features * jnp.log(2.0 * jnp.pi) + 2.0 * jnp.sum(jnp.log(waveform_stds))
138+
)
139+
140+
# Scale features by inverse standard deviations
141+
Y = encoding_features * inv_sigma[None, :] # (n_enc, n_features)
142+
X = decoding_features * inv_sigma[None, :] # (n_dec, n_features)
143+
144+
# Compute squared norms
145+
y2 = jnp.sum(Y**2, axis=1) # (n_enc,)
146+
x2 = jnp.sum(X**2, axis=1) # (n_dec,)
147+
148+
# GEMM: compute cross terms X @ Y^T = (n_dec, n_features) @ (n_features, n_enc)
149+
cross_term = X @ Y.T # (n_dec, n_enc)
150+
151+
# Combine: log K[i,j] = -0.5 * (y2[i] + x2[j] - 2*cross_term[j,i]) + log_norm_const
152+
# Note: We need (n_enc, n_dec) output, so transpose the cross term
153+
logK_mark = log_norm_const - 0.5 * (
154+
y2[:, None] + x2[None, :] - 2.0 * cross_term.T
155+
) # (n_enc, n_dec)
156+
157+
return logK_mark
158+
159+
65160
def estimate_log_joint_mark_intensity(
66161
decoding_spike_waveform_features: jnp.ndarray,
67162
encoding_spike_waveform_features: jnp.ndarray,
68163
waveform_stds: jnp.ndarray,
69164
occupancy: jnp.ndarray,
70165
mean_rate: float,
71166
position_distance: jnp.ndarray,
167+
use_gemm: bool = True,
168+
pos_tile_size: int | None = None,
72169
) -> jnp.ndarray:
73170
"""Estimate the log joint mark intensity of decoding spikes and spike waveforms.
74171
@@ -80,26 +177,109 @@ def estimate_log_joint_mark_intensity(
80177
occupancy : jnp.ndarray, shape (n_position_bins,)
81178
mean_rate : float
82179
position_distance : jnp.ndarray, shape (n_encoding_spikes, n_position_bins)
180+
use_gemm : bool, optional
181+
If True (default), use GEMM-based log-space computation (faster for multi-dimensional features).
182+
If False, use linear-space computation (matches reference exactly).
183+
pos_tile_size : int | None, optional
184+
If provided, tile computation over position dimension in chunks (only for use_gemm=True).
83185
84186
Returns
85187
-------
86188
log_joint_mark_intensity : jnp.ndarray, shape (n_decoding_spikes, n_position_bins)
87189
88190
"""
89-
spike_waveform_feature_distance = kde_distance(
191+
n_encoding_spikes = encoding_spike_waveform_features.shape[0]
192+
193+
if not use_gemm:
194+
# Linear-space computation (matches reference exactly)
195+
spike_waveform_feature_distance = kde_distance(
196+
decoding_spike_waveform_features,
197+
encoding_spike_waveform_features,
198+
waveform_stds,
199+
) # shape (n_encoding_spikes, n_decoding_spikes)
200+
201+
marginal_density = (
202+
spike_waveform_feature_distance.T @ position_distance / n_encoding_spikes
203+
) # shape (n_decoding_spikes, n_position_bins)
204+
return jnp.log(
205+
mean_rate * jnp.where(occupancy > 0.0, marginal_density / occupancy, 0.0)
206+
)
207+
208+
# Log-space computation with GEMM optimization
209+
# Build log-kernel matrix for marks: (n_enc, n_dec)
210+
logK_mark = _compute_log_mark_kernel_gemm(
90211
decoding_spike_waveform_features,
91212
encoding_spike_waveform_features,
92213
waveform_stds,
93-
) # shape (n_encoding_spikes, n_decoding_spikes)
214+
)
94215

95-
n_encoding_spikes = encoding_spike_waveform_features.shape[0]
96-
marginal_density = (
97-
spike_waveform_feature_distance.T @ position_distance / n_encoding_spikes
98-
) # shape (n_decoding_spikes, n_position_bins)
99-
return jnp.log(
100-
mean_rate * jnp.where(occupancy > 0.0, marginal_density / occupancy, 0.0)
216+
# Convert position_distance to log-space
217+
log_position_distance = jnp.log(position_distance)
218+
219+
# Uniform weights: log(1/n) for each encoding spike
220+
log_w = -jnp.log(float(n_encoding_spikes))
221+
222+
# Use scan to avoid materializing (n_enc × n_dec × n_pos) array
223+
n_pos = log_position_distance.shape[1]
224+
n_dec = logK_mark.shape[1]
225+
226+
if pos_tile_size is None or pos_tile_size >= n_pos:
227+
# No tiling: process all positions at once (default, fastest)
228+
def scan_over_dec(carry, y_col: jnp.ndarray) -> tuple[None, jnp.ndarray]:
229+
# y_col: (n_enc,), the column of logK_mark for one decoding spike
230+
# returns: (n_pos,), logsumexp over enc dimension
231+
result = jax.nn.logsumexp(
232+
log_w + log_position_distance + y_col[:, None], axis=0
233+
)
234+
return None, result
235+
236+
# scan over decoding spikes' columns -> (n_dec, n_pos)
237+
_, log_marginal = jax.lax.scan(scan_over_dec, None, logK_mark.T)
238+
else:
239+
# Tiled: process positions in chunks to reduce peak memory
240+
log_marginal = jnp.zeros((n_dec, n_pos))
241+
242+
for pos_start in range(0, n_pos, pos_tile_size):
243+
pos_end = min(pos_start + pos_tile_size, n_pos)
244+
pos_slice = slice(pos_start, pos_end)
245+
246+
# Tile: slice of log_position_distance for this chunk of positions
247+
log_pos_tile = log_position_distance[:, pos_slice] # (n_enc, tile_size)
248+
249+
# Create closure to capture log_pos_tile properly
250+
def make_scan_fn(tile):
251+
def scan_over_dec_tile(
252+
carry, y_col: jnp.ndarray
253+
) -> tuple[None, jnp.ndarray]:
254+
# y_col: (n_enc,)
255+
# returns: (tile_size,), logsumexp over enc dimension
256+
result = jax.nn.logsumexp(log_w + tile + y_col[:, None], axis=0)
257+
return None, result
258+
259+
return scan_over_dec_tile
260+
261+
# scan over decoding spikes for this position tile -> (n_dec, tile_size)
262+
_, log_marginal_tile = jax.lax.scan(
263+
make_scan_fn(log_pos_tile), None, logK_mark.T
264+
)
265+
266+
# Update output with this tile
267+
log_marginal = log_marginal.at[:, pos_slice].set(log_marginal_tile)
268+
269+
# Add mean rate and subtract occupancy (in log)
270+
log_mean_rate = jnp.log(mean_rate)
271+
log_occ = jnp.log(jnp.where(occupancy > 0.0, occupancy, 1.0)) # avoid log(0)
272+
273+
# Result: log(mean_rate * marginal / occupancy)
274+
# Use where to handle occupancy = 0 cases
275+
log_joint = jnp.where(
276+
occupancy[None, :] > 0.0,
277+
log_mean_rate + log_marginal - log_occ[None, :],
278+
jnp.log(0.0), # -inf for zero occupancy
101279
)
102280

281+
return log_joint
282+
103283

104284
def block_estimate_log_joint_mark_intensity(
105285
decoding_spike_waveform_features: jnp.ndarray,
@@ -109,6 +289,8 @@ def block_estimate_log_joint_mark_intensity(
109289
mean_rate: float,
110290
position_distance: jnp.ndarray,
111291
block_size: int = 100,
292+
use_gemm: bool = True,
293+
pos_tile_size: int | None = None,
112294
) -> jnp.ndarray:
113295
"""Estimate the log joint mark intensity of decoding spikes and spike waveforms.
114296
@@ -121,6 +303,10 @@ def block_estimate_log_joint_mark_intensity(
121303
mean_rate : float
122304
position_distance : jnp.ndarray, shape (n_encoding_spikes, n_position_bins)
123305
block_size : int, optional
306+
use_gemm : bool, optional
307+
If True (default), use GEMM-based log-space computation.
308+
pos_tile_size : int | None, optional
309+
If provided, tile computation over position dimension.
124310
125311
Returns
126312
-------
@@ -130,24 +316,31 @@ def block_estimate_log_joint_mark_intensity(
130316
n_decoding_spikes = decoding_spike_waveform_features.shape[0]
131317
n_position_bins = occupancy.shape[0]
132318

133-
log_joint_mark_intensity = jnp.zeros((n_decoding_spikes, n_position_bins))
319+
if n_decoding_spikes == 0:
320+
return jnp.full((0, n_position_bins), LOG_EPS)
134321

322+
# Use JIT-compiled update with buffer donation for memory efficiency
323+
# Donate the accumulator buffer (arg 0) so it can be reused in-place
324+
@jax.jit
325+
def _update_block(out_array, block_result, start_idx):
326+
return jax.lax.dynamic_update_slice(out_array, block_result, (start_idx, 0))
327+
328+
out = jnp.zeros((n_decoding_spikes, n_position_bins))
135329
for start_ind in range(0, n_decoding_spikes, block_size):
136330
block_inds = slice(start_ind, start_ind + block_size)
137-
log_joint_mark_intensity = jax.lax.dynamic_update_slice(
138-
log_joint_mark_intensity,
139-
estimate_log_joint_mark_intensity(
140-
decoding_spike_waveform_features[block_inds],
141-
encoding_spike_waveform_features,
142-
waveform_stds,
143-
occupancy,
144-
mean_rate,
145-
position_distance,
146-
),
147-
(start_ind, 0),
331+
block_result = estimate_log_joint_mark_intensity(
332+
decoding_spike_waveform_features[block_inds],
333+
encoding_spike_waveform_features,
334+
waveform_stds,
335+
occupancy,
336+
mean_rate,
337+
position_distance,
338+
use_gemm=use_gemm,
339+
pos_tile_size=pos_tile_size,
148340
)
341+
out = _update_block(out, block_result, start_ind)
149342

150-
return jnp.clip(log_joint_mark_intensity, a_min=LOG_EPS, a_max=None)
343+
return jnp.clip(out, a_min=LOG_EPS, a_max=None)
151344

152345

153346
def fit_clusterless_kde_encoding_model(
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Test that optimized log-space version uses all optimizations correctly."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from non_local_detector.likelihoods.clusterless_kde_log import (
7+
block_estimate_log_joint_mark_intensity,
8+
fit_clusterless_kde_encoding_model,
9+
)
10+
11+
12+
@pytest.mark.parametrize("pos_tile_size", [None, 10, 50])
13+
def test_pos_tiling_matches_no_tiling(simple_1d_environment, pos_tile_size):
14+
"""Test that position tiling produces same results as no tiling."""
15+
env = simple_1d_environment
16+
t_pos = np.linspace(0.0, 10.0, 101)
17+
pos = np.linspace(0.0, 10.0, 101)[:, None]
18+
19+
enc_times = [np.array([2.0, 5.0, 7.5])]
20+
enc_feats = [np.array([[0.0, 0.0], [1.0, -1.0], [0.5, 0.5]], dtype=float)]
21+
22+
encoding = fit_clusterless_kde_encoding_model(
23+
position_time=t_pos,
24+
position=pos,
25+
spike_times=enc_times,
26+
spike_waveform_features=enc_feats,
27+
environment=env,
28+
sampling_frequency=10,
29+
position_std=np.sqrt(1.0),
30+
waveform_std=1.0,
31+
block_size=8,
32+
disable_progress_bar=True,
33+
)
34+
35+
dec_feats = np.array([[0.1, 0.05], [1.1, -0.9]], dtype=float)
36+
37+
is_track_interior = env.is_track_interior_.ravel()
38+
interior_place_bin_centers = env.place_bin_centers_[is_track_interior]
39+
40+
from non_local_detector.likelihoods.clusterless_kde_log import kde_distance
41+
42+
electrode_encoding_positions = encoding["encoding_positions"][0]
43+
electrode_encoding_features = encoding["encoding_spike_waveform_features"][0]
44+
45+
position_distance = kde_distance(
46+
interior_place_bin_centers,
47+
electrode_encoding_positions,
48+
std=encoding["position_std"],
49+
)
50+
51+
# Baseline: no tiling
52+
result_no_tile = block_estimate_log_joint_mark_intensity(
53+
dec_feats,
54+
electrode_encoding_features,
55+
encoding["waveform_std"],
56+
encoding["occupancy"],
57+
encoding["mean_rates"][0],
58+
position_distance,
59+
block_size=8,
60+
use_gemm=True,
61+
pos_tile_size=None,
62+
)
63+
64+
# With tiling
65+
result_tiled = block_estimate_log_joint_mark_intensity(
66+
dec_feats,
67+
electrode_encoding_features,
68+
encoding["waveform_std"],
69+
encoding["occupancy"],
70+
encoding["mean_rates"][0],
71+
position_distance,
72+
block_size=8,
73+
use_gemm=True,
74+
pos_tile_size=pos_tile_size,
75+
)
76+
77+
# Should match exactly
78+
assert result_no_tile.shape == result_tiled.shape
79+
assert np.allclose(
80+
np.asarray(result_no_tile), np.asarray(result_tiled), rtol=1e-12, atol=1e-14
81+
)

0 commit comments

Comments
 (0)