Skip to content

Commit 33d70ee

Browse files
committed
Refactor log likelihood initialization in clusterless GMM
Replaces the broadcasting approach for initializing log_likelihood with a more efficient multiplication using jnp.ones. Also simplifies the in-bounds filtering and conversion to JAX arrays for electrode features.
1 parent 68d7f79 commit 33d70ee

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,10 @@ def predict_clusterless_gmm_log_likelihood(
469469
n_bins = interior_place_bin_centers.shape[0]
470470

471471
# Start with the expected-counts (ground process) term, broadcast over time
472-
log_likelihood = (
473-
(-summed_ground_process_intensity).reshape(1, -1).repeat(n_time, axis=0)
474-
) # (n_time, n_bins)
472+
# log_likelihood = (
473+
# (-summed_ground_process_intensity).reshape(1, -1).repeat(n_time, axis=0)
474+
# ) # (n_time, n_bins)
475+
log_likelihood = -1.0 * summed_ground_process_intensity * jnp.ones((n_time, 1))
475476

476477
# Per-electrode contributions in log-space
477478
for elect_feats, elect_times, joint_gmm, mean_rate in tqdm(
@@ -486,11 +487,7 @@ def predict_clusterless_gmm_log_likelihood(
486487
# Clip to decoding window
487488
in_bounds = np.logical_and(elect_times >= time[0], elect_times <= time[-1])
488489
elect_times = elect_times[in_bounds]
489-
elect_feats = elect_feats[in_bounds]
490-
elect_feats = _as_jnp(elect_feats)
491-
492-
if elect_times.shape[0] == 0:
493-
continue
490+
elect_feats = _as_jnp(elect_feats[in_bounds])
494491

495492
# Bin spikes
496493
seg_ids = get_spike_time_bin_ind(elect_times, time) # (n_spikes,)

0 commit comments

Comments
 (0)