Skip to content

Commit 06093d8

Browse files
committed
Fix interpolation and mean rate calculation in GMM model
Corrects the mean rate calculation to use the number of time bins instead of position samples, ensuring units match the Poisson likelihood formula. Also ensures numpy arrays are passed to get_position_at_time for compatibility with scipy.interpolate.interpn, and makes predict_clusterless_gmm_log_likelihood accept (and ignore) extra keyword arguments for compatibility with the model interface.
1 parent e442d6e commit 06093d8

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -369,14 +369,18 @@ def fit_clusterless_gmm_encoding_model(
369369
# Interpolate position weights at spike times
370370
elect_weights = jnp.interp(elect_times, position_time, weights)
371371

372-
# Mean rate contribution
373-
mean_rate = float(jnp.sum(elect_weights) / total_weight)
372+
# Mean rate contribution: spikes per time bin (matching KDE)
373+
# BUG FIX: Was using n_position_samples, should use n_time_bins
374+
# to match units required by Poisson likelihood formula
375+
n_time_bins = int((position_time[-1] - position_time[0]) * sampling_frequency)
376+
mean_rate = float(len(elect_times) / n_time_bins)
374377
mean_rate = jnp.clip(mean_rate, a_min=EPS) # avoid 0 rate
375378
mean_rates.append(mean_rate)
376379

377380
# Positions at spike times
381+
# Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
378382
enc_pos = get_position_at_time(
379-
position_time, position, elect_times, environment
383+
np.asarray(position_time), np.asarray(position), elect_times, environment
380384
)
381385

382386
# GPI GMM (position only)
@@ -444,6 +448,7 @@ def predict_clusterless_gmm_log_likelihood(
444448
spike_block_size: int = 1000,
445449
bin_tile_size: int | None = None,
446450
disable_progress_bar: bool = False,
451+
**kwargs, # Accept and ignore extra kwargs for compatibility with model interface
447452
) -> jnp.ndarray:
448453
"""
449454
Predict the (non-local or local) log likelihood using the fitted GMM model.
@@ -706,8 +711,9 @@ def compute_local_log_likelihood(
706711
# Interpolate position at bin times (use bin centers)
707712
# We'll take the midpoints as "time of interest" for local evaluation
708713
t_centers = 0.5 * (time[:-1] + time[1:])
714+
# Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
709715
interp_pos = get_position_at_time(
710-
position_time, position, t_centers, env
716+
np.asarray(position_time), np.asarray(position), t_centers, env
711717
) # (n_time, pos_dims)
712718

713719
# Occupancy density and its log at the animal's position
@@ -738,8 +744,9 @@ def compute_local_log_likelihood(
738744

739745
# Spike contributions at their true positions
740746
if elect_times.shape[0] > 0:
747+
# Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
741748
pos_at_spike_time = get_position_at_time(
742-
position_time, position, elect_times, env
749+
np.asarray(position_time), np.asarray(position), elect_times, env
743750
) # (n_spikes, pos_dims)
744751
eval_points = jnp.concatenate(
745752
[pos_at_spike_time, elect_feats], axis=1
@@ -759,9 +766,8 @@ def compute_local_log_likelihood(
759766
+ segment_sum(
760767
terms[:, None],
761768
seg_ids,
762-
n_time,
763-
indices_are_sorted=True,
764769
num_segments=n_time,
770+
indices_are_sorted=True,
765771
).ravel()
766772
)
767773

0 commit comments

Comments
 (0)