Skip to content

Commit af6239c

Browse files
committed
Refactor get_spike_time_bin_ind to common module
Moved the get_spike_time_bin_ind function to likelihoods/common.py and updated imports in clusterless_gmm.py, clusterless_kde.py, and clusterless_kde_log.py. Removed duplicate implementations from the affected files for consistency and code reuse. Minor docstring and code cleanup included.
1 parent 15f3b2b commit af6239c

File tree

4 files changed

+37
-73
lines changed

4 files changed

+37
-73
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
from track_linearization import get_linearized_position # type: ignore[import-untyped]
1515

1616
from non_local_detector.environment import Environment
17-
from non_local_detector.likelihoods.common import EPS, get_position_at_time, safe_divide
17+
from non_local_detector.likelihoods.common import (
18+
EPS,
19+
get_position_at_time,
20+
get_spike_time_bin_ind,
21+
safe_divide,
22+
)
1823
from non_local_detector.likelihoods.gmm import GaussianMixtureModel
1924

2025
# ---------------------------------------------------------------------
@@ -38,30 +43,6 @@ def _as_jnp(x) -> jnp.ndarray:
3843
return x if isinstance(x, jnp.ndarray) else jnp.asarray(x)
3944

4045

41-
def get_spike_time_bin_ind(
42-
spike_times: jnp.ndarray, time_bin_edges: jnp.ndarray
43-
) -> jnp.ndarray:
44-
"""Map spike times to decoding time-bin indices on device.
45-
46-
Parameters
47-
----------
48-
spike_times : jnp.ndarray, shape (n_spikes,)
49-
Array of spike times to map to bins.
50-
time_bin_edges : jnp.ndarray, shape (n_bins + 1,)
51-
Edges of time bins defining intervals [t0, t1), ..., [t_{n-1}, tn].
52-
53-
Returns
54-
-------
55-
bin_indices : jnp.ndarray, shape (n_spikes,)
56-
Index of time bin for each spike (0 to n_bins-1).
57-
"""
58-
# Right-closed bins [t_i, t_{i+1}), except the last edge which is included
59-
# Use JAX ops to keep everything on device
60-
inds = jnp.searchsorted(time_bin_edges, spike_times, side="right") - 1
61-
last = jnp.isclose(spike_times, time_bin_edges[-1])
62-
return jnp.where(last, time_bin_edges.shape[0] - 2, inds).astype(jnp.int32)
63-
64-
6546
def _gmm_logp(gmm: GaussianMixtureModel, X: jnp.ndarray) -> jnp.ndarray:
6647
"""Log density under a fitted GMM.
6748

src/non_local_detector/likelihoods/clusterless_kde.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,10 @@
1212
block_kde,
1313
gaussian_pdf,
1414
get_position_at_time,
15+
get_spike_time_bin_ind,
1516
)
1617

1718

18-
def get_spike_time_bin_ind(spike_times: np.ndarray, time: np.ndarray) -> np.ndarray:
19-
"""Get the index of the time bin for each spike time.
20-
21-
Parameters
22-
----------
23-
spike_times : np.ndarray, shape (n_spikes,)
24-
time : np.ndarray, shape (n_time_bins,)
25-
Bin edges.
26-
27-
Returns
28-
-------
29-
ind : np.ndarray, shape (n_spikes,)
30-
"""
31-
return np.digitize(spike_times, time[1:-1])
32-
33-
3419
def kde_distance(
3520
eval_points: jnp.ndarray, samples: jnp.ndarray, std: jnp.ndarray
3621
) -> jnp.ndarray:

src/non_local_detector/likelihoods/clusterless_kde_log.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,12 @@
1212
block_kde,
1313
gaussian_pdf,
1414
get_position_at_time,
15+
get_spike_time_bin_ind,
1516
log_gaussian_pdf,
1617
safe_log,
1718
)
1819

1920

20-
def get_spike_time_bin_ind(spike_times: np.ndarray, time: np.ndarray) -> np.ndarray:
21-
"""Get the index of the time bin for each spike time.
22-
23-
Parameters
24-
----------
25-
spike_times : np.ndarray, shape (n_spikes,)
26-
time : np.ndarray, shape (n_time_bins,)
27-
Bin edges.
28-
29-
Returns
30-
-------
31-
ind : np.ndarray, shape (n_spikes,)
32-
"""
33-
return np.digitize(spike_times, time[1:-1])
34-
35-
3621
@jax.jit
3722
def kde_distance(
3823
eval_points: jnp.ndarray, samples: jnp.ndarray, std: jnp.ndarray
@@ -123,6 +108,7 @@ def log_gaussian_per_dim(eval_dim, sample_dim, sigma):
123108
return jnp.sum(per_dim_log_distances, axis=0)
124109

125110

111+
@jax.jit
126112
def log_kde_distance_streaming(
127113
eval_points: jnp.ndarray,
128114
samples: jnp.ndarray,
@@ -152,6 +138,10 @@ def log_kde_distance_streaming(
152138
153139
Notes
154140
-----
141+
This function is JIT-compiled and will be specialized for each unique combination
142+
of input shapes. The shape dimensions (n_dims, n_samp, n_eval) are traced during
143+
compilation, so different shapes will result in separate compiled versions.
144+
155145
Memory usage:
156146
- log_kde_distance (vmap): O(D×n_samp×n_eval) peak
157147
- log_kde_distance_streaming (fori_loop): O(n_samp×n_eval) peak
@@ -489,6 +479,7 @@ def process_pos_tile(
489479
)(log_pos_tile, logK_mark_chunk.T)
490480

491481
# Update output with this tile
482+
# fori_loop handles in-place updates efficiently when possible
492483
return jax.lax.dynamic_update_slice(
493484
log_marginal_chunk, log_marginal_tile, (0, pos_start)
494485
)
@@ -838,15 +829,6 @@ def block_estimate_log_joint_mark_intensity(
838829
if n_decoding_spikes == 0:
839830
return jnp.full((0, n_position_bins), LOG_EPS)
840831

841-
# Create JIT-compiled update function with buffer donation
842-
# donate_argnums=(0,) allows JAX to reuse the output buffer in-place
843-
_update_block = jax.jit(
844-
lambda out_array, block_result, start_idx: jax.lax.dynamic_update_slice(
845-
out_array, block_result, (start_idx, 0)
846-
),
847-
donate_argnums=(0,),
848-
)
849-
850832
out = jnp.zeros((n_decoding_spikes, n_position_bins))
851833
for start_ind in range(0, n_decoding_spikes, block_size):
852834
block_inds = slice(start_ind, start_ind + block_size)
@@ -865,7 +847,7 @@ def block_estimate_log_joint_mark_intensity(
865847
position_eval_points=position_eval_points,
866848
position_std=position_std,
867849
)
868-
out = _update_block(out, block_result, start_ind)
850+
out = jax.lax.dynamic_update_slice(out, block_result, (start_ind, 0))
869851

870852
return jnp.clip(out, a_min=LOG_EPS, a_max=None)
871853

@@ -1171,12 +1153,12 @@ def predict_clusterless_kde_log_likelihood(
11711153
enc_tile_size=enc_tile_size,
11721154
pos_tile_size=pos_tile_size,
11731155
use_streaming=use_streaming,
1174-
encoding_positions=electrode_encoding_positions
1175-
if use_streaming
1176-
else None,
1177-
position_eval_points=interior_place_bin_centers
1178-
if use_streaming
1179-
else None,
1156+
encoding_positions=(
1157+
electrode_encoding_positions if use_streaming else None
1158+
),
1159+
position_eval_points=(
1160+
interior_place_bin_centers if use_streaming else None
1161+
),
11801162
position_std=position_std if use_streaming else None,
11811163
),
11821164
get_spike_time_bin_ind(electrode_spike_times, time),

src/non_local_detector/likelihoods/common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ def get_position_at_time(
5050
return position_at_spike_times
5151

5252

53+
def get_spike_time_bin_ind(spike_times: np.ndarray, time: np.ndarray) -> np.ndarray:
54+
"""Get the index of the time bin for each spike time.
55+
56+
Parameters
57+
----------
58+
spike_times : np.ndarray, shape (n_spikes,)
59+
time : np.ndarray, shape (n_time_bins,)
60+
Bin edges.
61+
62+
Returns
63+
-------
64+
ind : np.ndarray, shape (n_spikes,)
65+
"""
66+
return np.digitize(spike_times, time[1:-1])
67+
68+
5369
@jax.jit
5470
def log_gaussian_pdf(
5571
x: jnp.ndarray, mean: jnp.ndarray, sigma: jnp.ndarray

0 commit comments

Comments
 (0)