1616from non_local_detector .environment import Environment
1717from non_local_detector .likelihoods .common import (
1818 EPS ,
19+ LOG_EPS ,
1920 get_position_at_time ,
2021 get_spike_time_bin_ind ,
21- safe_divide ,
2222)
2323from non_local_detector .likelihoods .gmm import GaussianMixtureModel
2424
@@ -297,21 +297,11 @@ def fit_clusterless_gmm_encoding_model(
297297 edge_spacing = environment .edge_spacing ,
298298 ).linear_position .to_numpy ()[:, None ]
299299 pos_for_occ = _as_jnp (position1D )
300-
301- # CRITICAL: Linearize bin centers too (must match occupancy space)
302- raw_bin_centers = environment .place_bin_centers_ [is_track_interior ]
303- lin_bins = get_linearized_position (
304- np .asarray (raw_bin_centers ),
305- environment .track_graph ,
306- edge_order = environment .edge_order ,
307- edge_spacing = environment .edge_spacing ,
308- ).linear_position .to_numpy ()[:, None ]
309- interior_place_bin_centers = _as_jnp (lin_bins )
310300 else :
311301 pos_for_occ = position
312- interior_place_bin_centers = _as_jnp (
313- environment .place_bin_centers_ [ is_track_interior ]
314- )
302+
303+ is_track_interior = environment .is_track_interior_ . ravel ()
304+ interior_place_bin_centers = environment . place_bin_centers_ [ is_track_interior ]
315305
316306 # Fit occupancy GMM and precompute per-bin terms
317307 occupancy_model = _fit_gmm_density (
@@ -321,17 +311,12 @@ def fit_clusterless_gmm_encoding_model(
321311 random_state = gmm_random_state ,
322312 covariance_type = gmm_covariance_type_occupancy ,
323313 )
324- occupancy_bins = _gmm_density (
325- occupancy_model , interior_place_bin_centers
326- ) # (n_bins,)
327- log_occupancy_bins = _gmm_logp (
328- occupancy_model , interior_place_bin_centers
329- ) # (n_bins,)
314+ log_occupancy = _gmm_logp (occupancy_model , interior_place_bin_centers )
330315
331316 gpi_models : list [GaussianMixtureModel ] = []
332317 joint_models : list [GaussianMixtureModel ] = []
333318 mean_rates : list [float ] = []
334- summed_ground_process_intensity = jnp .zeros_like ( occupancy_bins )
319+ log_summed_ground_process_intensity = jnp .full_like ( log_occupancy , - jnp . inf )
335320
336321 # Fit per-electrode models
337322 for elect_feats , elect_times in tqdm (
@@ -392,17 +377,28 @@ def fit_clusterless_gmm_encoding_model(
392377 joint_models .append (joint_gmm )
393378
394379 # Expected-counts term at bins: mean_rate * (gpi / occupancy)
395- gpi_bins = _gmm_density (gpi_gmm , interior_place_bin_centers ) # (n_bins,)
396- summed_ground_process_intensity = summed_ground_process_intensity + jnp .clip (
397- mean_rate * safe_divide (gpi_bins , occupancy_bins ), a_min = EPS
380+ log_gp_num = _gmm_logp (gpi_gmm , interior_place_bin_centers ) # (n_bins,)
381+ log_gpi = jnp .log (mean_rate ) + log_gp_num - log_occupancy
382+
383+ log_summed_ground_process_intensity = jnp .logaddexp (
384+ log_summed_ground_process_intensity , log_gpi
398385 )
399386
387+ max_log = jnp .log (jnp .finfo (log_summed_ground_process_intensity .dtype ).max )
388+ summed_ground_process_intensity = jnp .clip (
389+ jnp .exp (
390+ jnp .clip (
391+ log_summed_ground_process_intensity , min = LOG_EPS , max = jnp .exp (max_log )
392+ )
393+ ),
394+ min = EPS ,
395+ )
396+
400397 return {
401398 "environment" : environment ,
402399 "occupancy_model" : occupancy_model ,
403400 "interior_place_bin_centers" : interior_place_bin_centers ,
404- "occupancy_bins" : occupancy_bins ,
405- "log_occupancy_bins" : log_occupancy_bins ,
401+ "log_occupancy" : log_occupancy ,
406402 "gpi_models" : gpi_models ,
407403 "joint_models" : joint_models ,
408404 "mean_rates" : jnp .asarray (mean_rates ),
@@ -425,7 +421,7 @@ def predict_clusterless_gmm_log_likelihood(
425421 environment : Environment ,
426422 occupancy_model : GaussianMixtureModel ,
427423 interior_place_bin_centers : jnp .ndarray ,
428- log_occupancy_bins : jnp .ndarray ,
424+ log_occupancy : jnp .ndarray ,
429425 gpi_models : list [GaussianMixtureModel ],
430426 joint_models : list [GaussianMixtureModel ],
431427 mean_rates : jnp .ndarray ,
@@ -471,7 +467,6 @@ def predict_clusterless_gmm_log_likelihood(
471467 If non-local: jnp.ndarray, shape (n_time, n_bins)
472468 If local : jnp.ndarray, shape (n_time, 1)
473469 """
474- time = _as_jnp (time )
475470 # NOTE: Keep position_time as numpy to avoid float64→float32 precision loss
476471 position_time = np .asarray (position_time )
477472 position = _as_jnp (position if position .ndim > 1 else position [:, None ])
@@ -608,7 +603,7 @@ def _update_block_one_tile(
608603 joint_logp_block ,
609604 block_seg_ids ,
610605 log_mean_rate ,
611- log_occupancy_bins ,
606+ log_occupancy ,
612607 n_time ,
613608 )
614609 else :
@@ -638,7 +633,7 @@ def _update_block_one_tile(
638633 joint_logp_tile ,
639634 block_seg_ids ,
640635 log_mean_rate ,
641- log_occupancy_bins [bin_start :bin_end ],
636+ log_occupancy [bin_start :bin_end ],
642637 n_time ,
643638 )
644639
0 commit comments