@@ -228,6 +228,7 @@ def fit_clusterless_gmm_encoding_model(
228228 gmm_covariance_type_joint : str = "full" ,
229229 gmm_random_state : int | None = 0 ,
230230 disable_progress_bar : bool = False ,
231+ ** kwargs , # Accept but ignore KDE-specific parameters for API compatibility
231232) -> EncodingModel :
232233 """
233234 Fit the clusterless encoding model using GMMs.
@@ -285,7 +286,10 @@ def fit_clusterless_gmm_encoding_model(
285286 - disable_progress_bar
286287 """
287288 position = _as_jnp (position if position .ndim > 1 else position [:, None ])
288- position_time = _as_jnp (position_time )
289+ # NOTE: Do NOT convert position_time to JAX! It causes float64→float32 precision loss
290+ # with large timestamp values (e.g., Unix timestamps), creating apparent duplicates.
291+ # Keep as numpy for interpolation (scipy.interpolate.interpn requires numpy anyway).
292+ position_time = np .asarray (position_time )
289293
290294 # Interior bins (cached)
291295 if environment .is_track_interior_ is not None :
@@ -302,7 +306,6 @@ def fit_clusterless_gmm_encoding_model(
302306 weights = jnp .ones ((position .shape [0 ],), dtype = position .dtype )
303307 else :
304308 weights = _as_jnp (weights )
305- total_weight = float (jnp .sum (weights ))
306309
307310 # If environment has a graph and positions are 2D+, linearize to 1D for occupancy/GPI
308311 if environment .track_graph is not None and position .shape [1 ] > 1 :
@@ -378,9 +381,8 @@ def fit_clusterless_gmm_encoding_model(
378381 mean_rates .append (mean_rate )
379382
380383 # Positions at spike times
381- # Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
382384 enc_pos = get_position_at_time (
383- np . asarray ( position_time ) , np .asarray (position ), elect_times , environment
385+ position_time , np .asarray (position ), elect_times , environment
384386 )
385387
386388 # GPI GMM (position only)
@@ -711,9 +713,8 @@ def compute_local_log_likelihood(
711713 # Interpolate position at bin times (use bin centers)
712714 # We'll take the midpoints as "time of interest" for local evaluation
713715 t_centers = 0.5 * (time [:- 1 ] + time [1 :])
714- # Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
715716 interp_pos = get_position_at_time (
716- np . asarray ( position_time ) , np .asarray (position ), t_centers , env
717+ position_time , np .asarray (position ), t_centers , env
717718 ) # (n_time, pos_dims)
718719
719720 # Occupancy density and its log at the animal's position
@@ -744,9 +745,8 @@ def compute_local_log_likelihood(
744745
745746 # Spike contributions at their true positions
746747 if elect_times .shape [0 ] > 0 :
747- # Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
748748 pos_at_spike_time = get_position_at_time (
749- np . asarray ( position_time ) , np .asarray (position ), elect_times , env
749+ position_time , np .asarray (position ), elect_times , env
750750 ) # (n_spikes, pos_dims)
751751 eval_points = jnp .concatenate (
752752 [pos_at_spike_time , elect_feats ], axis = 1
0 commit comments