@@ -369,14 +369,18 @@ def fit_clusterless_gmm_encoding_model(
369369 # Interpolate position weights at spike times
370370 elect_weights = jnp .interp (elect_times , position_time , weights )
371371
372- # Mean rate contribution
373- mean_rate = float (jnp .sum (elect_weights ) / total_weight )
372+ # Mean rate contribution: spikes per time bin (matching KDE)
373+ # BUG FIX: Was using n_position_samples, should use n_time_bins
374+ # to match units required by Poisson likelihood formula
375+ n_time_bins = int ((position_time [- 1 ] - position_time [0 ]) * sampling_frequency )
376+ mean_rate = float (len (elect_times ) / n_time_bins )
374377 mean_rate = jnp .clip (mean_rate , a_min = EPS ) # avoid 0 rate
375378 mean_rates .append (mean_rate )
376379
377380 # Positions at spike times
381+ # Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
378382 enc_pos = get_position_at_time (
379- position_time , position , elect_times , environment
383+ np . asarray ( position_time ), np . asarray ( position ) , elect_times , environment
380384 )
381385
382386 # GPI GMM (position only)
@@ -444,6 +448,7 @@ def predict_clusterless_gmm_log_likelihood(
444448 spike_block_size : int = 1000 ,
445449 bin_tile_size : int | None = None ,
446450 disable_progress_bar : bool = False ,
451+ ** kwargs , # Accept and ignore extra kwargs for compatibility with model interface
447452) -> jnp .ndarray :
448453 """
449454 Predict the (non-local or local) log likelihood using the fitted GMM model.
@@ -706,8 +711,9 @@ def compute_local_log_likelihood(
706711 # Interpolate position at bin times (use bin centers)
707712 # We'll take the midpoints as "time of interest" for local evaluation
708713 t_centers = 0.5 * (time [:- 1 ] + time [1 :])
714+ # Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
709715 interp_pos = get_position_at_time (
710- position_time , position , t_centers , env
716+ np . asarray ( position_time ), np . asarray ( position ) , t_centers , env
711717 ) # (n_time, pos_dims)
712718
713719 # Occupancy density and its log at the animal's position
@@ -738,8 +744,9 @@ def compute_local_log_likelihood(
738744
739745 # Spike contributions at their true positions
740746 if elect_times .shape [0 ] > 0 :
747+ # Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
741748 pos_at_spike_time = get_position_at_time (
742- position_time , position , elect_times , env
749+ np . asarray ( position_time ), np . asarray ( position ) , elect_times , env
743750 ) # (n_spikes, pos_dims)
744751 eval_points = jnp .concatenate (
745752 [pos_at_spike_time , elect_feats ], axis = 1
@@ -759,9 +766,8 @@ def compute_local_log_likelihood(
759766 + segment_sum (
760767 terms [:, None ],
761768 seg_ids ,
762- n_time ,
763- indices_are_sorted = True ,
764769 num_segments = n_time ,
770+ indices_are_sorted = True ,
765771 ).ravel ()
766772 )
767773
0 commit comments