Skip to content

Commit c7d747d

Browse files
committed
Preserve precision by avoiding JAX conversion of position_time
Updated fit_clusterless_gmm_encoding_model and the related get_position_at_time calls to keep position_time as a NumPy array instead of converting it to a JAX array. This prevents float64-to-float32 precision loss with large timestamp values (e.g., Unix timestamps) and ensures compatibility with scipy.interpolate.interpn, which requires NumPy arrays.
1 parent 7c33edd commit c7d747d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,7 @@ def fit_clusterless_gmm_encoding_model(
228228
gmm_covariance_type_joint: str = "full",
229229
gmm_random_state: int | None = 0,
230230
disable_progress_bar: bool = False,
231+
**kwargs, # Accept but ignore KDE-specific parameters for API compatibility
231232
) -> EncodingModel:
232233
"""
233234
Fit the clusterless encoding model using GMMs.
@@ -285,7 +286,10 @@ def fit_clusterless_gmm_encoding_model(
285286
- disable_progress_bar
286287
"""
287288
position = _as_jnp(position if position.ndim > 1 else position[:, None])
288-
position_time = _as_jnp(position_time)
289+
# NOTE: Do NOT convert position_time to JAX! It causes float64→float32 precision loss
290+
# with large timestamp values (e.g., Unix timestamps), creating apparent duplicates.
291+
# Keep as numpy for interpolation (scipy.interpolate.interpn requires numpy anyway).
292+
position_time = np.asarray(position_time)
289293

290294
# Interior bins (cached)
291295
if environment.is_track_interior_ is not None:
@@ -302,7 +306,6 @@ def fit_clusterless_gmm_encoding_model(
302306
weights = jnp.ones((position.shape[0],), dtype=position.dtype)
303307
else:
304308
weights = _as_jnp(weights)
305-
total_weight = float(jnp.sum(weights))
306309

307310
# If environment has a graph and positions are 2D+, linearize to 1D for occupancy/GPI
308311
if environment.track_graph is not None and position.shape[1] > 1:
@@ -378,9 +381,8 @@ def fit_clusterless_gmm_encoding_model(
378381
mean_rates.append(mean_rate)
379382

380383
# Positions at spike times
381-
# Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
382384
enc_pos = get_position_at_time(
383-
np.asarray(position_time), np.asarray(position), elect_times, environment
385+
position_time, np.asarray(position), elect_times, environment
384386
)
385387

386388
# GPI GMM (position only)
@@ -711,9 +713,8 @@ def compute_local_log_likelihood(
711713
# Interpolate position at bin times (use bin centers)
712714
# We'll take the midpoints as "time of interest" for local evaluation
713715
t_centers = 0.5 * (time[:-1] + time[1:])
714-
# Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
715716
interp_pos = get_position_at_time(
716-
np.asarray(position_time), np.asarray(position), t_centers, env
717+
position_time, np.asarray(position), t_centers, env
717718
) # (n_time, pos_dims)
718719

719720
# Occupancy density and its log at the animal's position
@@ -744,9 +745,8 @@ def compute_local_log_likelihood(
744745

745746
# Spike contributions at their true positions
746747
if elect_times.shape[0] > 0:
747-
# Note: get_position_at_time uses scipy.interpolate.interpn which requires numpy arrays
748748
pos_at_spike_time = get_position_at_time(
749-
np.asarray(position_time), np.asarray(position), elect_times, env
749+
position_time, np.asarray(position), elect_times, env
750750
) # (n_spikes, pos_dims)
751751
eval_points = jnp.concatenate(
752752
[pos_at_spike_time, elect_feats], axis=1

0 commit comments

Comments
 (0)