Skip to content

Commit 31521a1

Browse files
committed
Remove redundant time conversion in compute_local_log_likelihood
Eliminated unnecessary conversion of the 'time' variable to a JAX array and adjusted interpolation to use original time values. Also removed redundant conversion of 'elect_times' to JAX array, as it is not needed.
1 parent ac578c8 commit 31521a1

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -680,18 +680,16 @@ def compute_local_log_likelihood(
680680
log_likelihood : jnp.ndarray, shape (n_time, 1)
681681
Log likelihood at the animal's position for each time bin.
682682
"""
683-
time = _as_jnp(time)
684683
# NOTE: Keep position_time as numpy to avoid float64→float32 precision loss
685684
position_time = np.asarray(position_time)
686685
position = _as_jnp(position if position.ndim > 1 else position[:, None])
687686

688687
n_time = time.shape[0] - 1
689688

690689
# Interpolate position at bin times (use bin centers)
691-
# We'll take the midpoints as "time of interest" for local evaluation
692-
t_centers = 0.5 * (time[:-1] + time[1:])
690+
693691
interp_pos = get_position_at_time(
694-
position_time, np.asarray(position), t_centers, environment
692+
position_time, np.asarray(position), time, environment
695693
) # (n_time, pos_dims)
696694

697695
# Occupancy density and its log at the animal's position
@@ -712,7 +710,6 @@ def compute_local_log_likelihood(
712710
unit="electrode",
713711
disable=disable_progress_bar,
714712
):
715-
elect_times = _as_jnp(elect_times)
716713
elect_feats = _as_jnp(elect_feats)
717714

718715
# Clip to decoding window

0 commit comments

Comments
 (0)