Skip to content

Commit ac578c8

Browse files
committed
Refactor clusterless GMM to use log occupancy values
Replaces occupancy bin density arrays with log occupancy values throughout the clusterless GMM encoding and prediction functions. This improves numerical stability and simplifies the computation of summed ground process intensity using log-space operations.
1 parent 04d0a4d commit ac578c8

File tree

1 file changed

+25
-30
lines changed

1 file changed

+25
-30
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from non_local_detector.environment import Environment
1717
from non_local_detector.likelihoods.common import (
1818
EPS,
19+
LOG_EPS,
1920
get_position_at_time,
2021
get_spike_time_bin_ind,
21-
safe_divide,
2222
)
2323
from non_local_detector.likelihoods.gmm import GaussianMixtureModel
2424

@@ -297,21 +297,11 @@ def fit_clusterless_gmm_encoding_model(
297297
edge_spacing=environment.edge_spacing,
298298
).linear_position.to_numpy()[:, None]
299299
pos_for_occ = _as_jnp(position1D)
300-
301-
# CRITICAL: Linearize bin centers too (must match occupancy space)
302-
raw_bin_centers = environment.place_bin_centers_[is_track_interior]
303-
lin_bins = get_linearized_position(
304-
np.asarray(raw_bin_centers),
305-
environment.track_graph,
306-
edge_order=environment.edge_order,
307-
edge_spacing=environment.edge_spacing,
308-
).linear_position.to_numpy()[:, None]
309-
interior_place_bin_centers = _as_jnp(lin_bins)
310300
else:
311301
pos_for_occ = position
312-
interior_place_bin_centers = _as_jnp(
313-
environment.place_bin_centers_[is_track_interior]
314-
)
302+
303+
is_track_interior = environment.is_track_interior_.ravel()
304+
interior_place_bin_centers = environment.place_bin_centers_[is_track_interior]
315305

316306
# Fit occupancy GMM and precompute per-bin terms
317307
occupancy_model = _fit_gmm_density(
@@ -321,17 +311,12 @@ def fit_clusterless_gmm_encoding_model(
321311
random_state=gmm_random_state,
322312
covariance_type=gmm_covariance_type_occupancy,
323313
)
324-
occupancy_bins = _gmm_density(
325-
occupancy_model, interior_place_bin_centers
326-
) # (n_bins,)
327-
log_occupancy_bins = _gmm_logp(
328-
occupancy_model, interior_place_bin_centers
329-
) # (n_bins,)
314+
log_occupancy = _gmm_logp(occupancy_model, interior_place_bin_centers)
330315

331316
gpi_models: list[GaussianMixtureModel] = []
332317
joint_models: list[GaussianMixtureModel] = []
333318
mean_rates: list[float] = []
334-
summed_ground_process_intensity = jnp.zeros_like(occupancy_bins)
319+
log_summed_ground_process_intensity = jnp.full_like(log_occupancy, -jnp.inf)
335320

336321
# Fit per-electrode models
337322
for elect_feats, elect_times in tqdm(
@@ -392,17 +377,28 @@ def fit_clusterless_gmm_encoding_model(
392377
joint_models.append(joint_gmm)
393378

394379
# Expected-counts term at bins: mean_rate * (gpi / occupancy)
395-
gpi_bins = _gmm_density(gpi_gmm, interior_place_bin_centers) # (n_bins,)
396-
summed_ground_process_intensity = summed_ground_process_intensity + jnp.clip(
397-
mean_rate * safe_divide(gpi_bins, occupancy_bins), a_min=EPS
380+
log_gp_num = _gmm_logp(gpi_gmm, interior_place_bin_centers) # (n_bins,)
381+
log_gpi = jnp.log(mean_rate) + log_gp_num - log_occupancy
382+
383+
log_summed_ground_process_intensity = jnp.logaddexp(
384+
log_summed_ground_process_intensity, log_gpi
398385
)
399386

387+
max_log = jnp.log(jnp.finfo(log_summed_ground_process_intensity.dtype).max)
388+
summed_ground_process_intensity = jnp.clip(
389+
jnp.exp(
390+
jnp.clip(
391+
log_summed_ground_process_intensity, min=LOG_EPS, max=jnp.exp(max_log)
392+
)
393+
),
394+
min=EPS,
395+
)
396+
400397
return {
401398
"environment": environment,
402399
"occupancy_model": occupancy_model,
403400
"interior_place_bin_centers": interior_place_bin_centers,
404-
"occupancy_bins": occupancy_bins,
405-
"log_occupancy_bins": log_occupancy_bins,
401+
"log_occupancy": log_occupancy,
406402
"gpi_models": gpi_models,
407403
"joint_models": joint_models,
408404
"mean_rates": jnp.asarray(mean_rates),
@@ -425,7 +421,7 @@ def predict_clusterless_gmm_log_likelihood(
425421
environment: Environment,
426422
occupancy_model: GaussianMixtureModel,
427423
interior_place_bin_centers: jnp.ndarray,
428-
log_occupancy_bins: jnp.ndarray,
424+
log_occupancy: jnp.ndarray,
429425
gpi_models: list[GaussianMixtureModel],
430426
joint_models: list[GaussianMixtureModel],
431427
mean_rates: jnp.ndarray,
@@ -471,7 +467,6 @@ def predict_clusterless_gmm_log_likelihood(
471467
If non-local: jnp.ndarray, shape (n_time, n_bins)
472468
If local : jnp.ndarray, shape (n_time, 1)
473469
"""
474-
time = _as_jnp(time)
475470
# NOTE: Keep position_time as numpy to avoid float64→float32 precision loss
476471
position_time = np.asarray(position_time)
477472
position = _as_jnp(position if position.ndim > 1 else position[:, None])
@@ -608,7 +603,7 @@ def _update_block_one_tile(
608603
joint_logp_block,
609604
block_seg_ids,
610605
log_mean_rate,
611-
log_occupancy_bins,
606+
log_occupancy,
612607
n_time,
613608
)
614609
else:
@@ -638,7 +633,7 @@ def _update_block_one_tile(
638633
joint_logp_tile,
639634
block_seg_ids,
640635
log_mean_rate,
641-
log_occupancy_bins[bin_start:bin_end],
636+
log_occupancy[bin_start:bin_end],
642637
n_time,
643638
)
644639

0 commit comments

Comments
 (0)