Skip to content

Commit 6b60f65

Browse files
committed
Fix GMM encoding for empty spike windows and precision
Skip electrodes with no spikes in the encoding window to prevent errors. Avoid storing position_time in the encoding model to prevent argument conflicts, and keep position_time as a numpy array in likelihood functions to preserve float64 precision.
1 parent c7d747d commit 6b60f65

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ def fit_clusterless_gmm_encoding_model(
369369
elect_times = elect_times[in_bounds]
370370
elect_feats = elect_feats[in_bounds]
371371

372+
# Skip electrodes with no spikes in encoding window
373+
if elect_times.shape[0] == 0:
374+
continue
375+
372376
# Interpolate position weights at spike times
373377
elect_weights = jnp.interp(elect_times, position_time, weights)
374378

@@ -422,7 +426,8 @@ def fit_clusterless_gmm_encoding_model(
422426
"joint_models": joint_models,
423427
"mean_rates": jnp.asarray(mean_rates),
424428
"summed_ground_process_intensity": summed_ground_process_intensity,
425-
"position_time": position_time,
429+
# NOTE: Don't store position_time - it's passed as a separate argument to predict()
430+
# Storing it causes "multiple values for argument 'position_time'" error
426431
"gmm_components_occupancy": gmm_components_occupancy,
427432
"gmm_components_gpi": gmm_components_gpi,
428433
"gmm_components_joint": gmm_components_joint,
@@ -488,7 +493,8 @@ def predict_clusterless_gmm_log_likelihood(
488493
If local : jnp.ndarray, shape (n_time, 1)
489494
"""
490495
time = _as_jnp(time)
491-
position_time = _as_jnp(position_time)
496+
# NOTE: Keep position_time as numpy to avoid float64→float32 precision loss
497+
position_time = np.asarray(position_time)
492498
position = _as_jnp(position if position.ndim > 1 else position[:, None])
493499

494500
if is_local:
@@ -699,7 +705,8 @@ def compute_local_log_likelihood(
699705
Log likelihood at the animal's position for each time bin.
700706
"""
701707
time = _as_jnp(time)
702-
position_time = _as_jnp(position_time)
708+
# NOTE: Keep position_time as numpy to avoid float64→float32 precision loss
709+
position_time = np.asarray(position_time)
703710
position = _as_jnp(position if position.ndim > 1 else position[:, None])
704711

705712
env = encoding_model["environment"]

0 commit comments

Comments
 (0)