@@ -426,15 +426,6 @@ def fit_clusterless_gmm_encoding_model(
426426 "joint_models" : joint_models ,
427427 "mean_rates" : jnp .asarray (mean_rates ),
428428 "summed_ground_process_intensity" : summed_ground_process_intensity ,
429- # NOTE: Don't store position_time - it's passed as a separate argument to predict()
430- # Storing it causes "multiple values for argument 'position_time'" error
431- "gmm_components_occupancy" : gmm_components_occupancy ,
432- "gmm_components_gpi" : gmm_components_gpi ,
433- "gmm_components_joint" : gmm_components_joint ,
434- "gmm_covariance_type_occupancy" : gmm_covariance_type_occupancy ,
435- "gmm_covariance_type_gpi" : gmm_covariance_type_gpi ,
436- "gmm_covariance_type_joint" : gmm_covariance_type_joint ,
437- "gmm_random_state" : gmm_random_state ,
438429 "disable_progress_bar" : disable_progress_bar ,
439430 }
440431
@@ -450,7 +441,14 @@ def predict_clusterless_gmm_log_likelihood(
450441 position : jnp .ndarray ,
451442 spike_times : list [jnp .ndarray ],
452443 spike_waveform_features : list [jnp .ndarray ],
453- encoding_model : dict ,
444+ environment : Environment ,
445+ occupancy_model : GaussianMixtureModel ,
446+ interior_place_bin_centers : jnp .ndarray ,
447+ log_occupancy_bins : jnp .ndarray ,
448+ gpi_models : list [GaussianMixtureModel ],
449+ joint_models : list [GaussianMixtureModel ],
450+ mean_rates : jnp .ndarray ,
451+ summed_ground_process_intensity : jnp .ndarray ,
454452 is_local : bool = False ,
455453 spike_block_size : int = 1000 ,
456454 bin_tile_size : int | None = None ,
@@ -504,22 +502,20 @@ def predict_clusterless_gmm_log_likelihood(
504502 position = position ,
505503 spike_times = spike_times ,
506504 spike_waveform_features = spike_waveform_features ,
507- encoding_model = encoding_model ,
505+ environment = environment ,
506+ occupancy_model = occupancy_model ,
507+ gpi_models = gpi_models ,
508+ joint_models = joint_models ,
509+ mean_rates = mean_rates ,
508510 disable_progress_bar = disable_progress_bar ,
509511 )
510512
511- bin_centers = encoding_model ["interior_place_bin_centers" ]
512- log_occ_bins = encoding_model ["log_occupancy_bins" ] # log density
513- mean_rates = encoding_model ["mean_rates" ]
514- joint_models = encoding_model ["joint_models" ]
515- summed_ground = encoding_model ["summed_ground_process_intensity" ]
516-
517513 n_time = time .shape [0 ]
518- n_bins = bin_centers .shape [0 ]
514+ n_bins = interior_place_bin_centers .shape [0 ]
519515
520516 # Start with the expected-counts (ground process) term, broadcast over time
521517 log_likelihood = (
522- (- summed_ground ).reshape (1 , - 1 ).repeat (n_time , axis = 0 )
518+ (- summed_ground_process_intensity ).reshape (1 , - 1 ).repeat (n_time , axis = 0 )
523519 ) # (n_time, n_bins)
524520
525521 # Per-electrode contributions in log-space
@@ -617,7 +613,7 @@ def _update_block_one_tile(
617613
618614 if bin_tile_size is None or bin_tile_size >= n_bins :
619615 # No bin tiling: process all bins at once (default)
620- tiled_bins = jnp .tile (bin_centers , (block_size , 1 ))
616+ tiled_bins = jnp .tile (interior_place_bin_centers , (block_size , 1 ))
621617 repeated_feats = jnp .repeat (block_feats , n_bins , axis = 0 )
622618 eval_points = jnp .concatenate ([tiled_bins , repeated_feats ], axis = 1 )
623619
@@ -631,7 +627,7 @@ def _update_block_one_tile(
631627 joint_logp_block ,
632628 block_seg_ids ,
633629 log_mean_rate ,
634- log_occ_bins ,
630+ log_occupancy_bins ,
635631 n_time ,
636632 )
637633 else :
@@ -643,7 +639,7 @@ def _update_block_one_tile(
643639
644640 # Build eval points for this tile
645641 tiled_bins_tile = jnp .tile (
646- bin_centers [bin_start :bin_end ], (block_size , 1 )
642+ interior_place_bin_centers [bin_start :bin_end ], (block_size , 1 )
647643 )
648644 repeated_feats_tile = jnp .repeat (block_feats , n_tile , axis = 0 )
649645 eval_points_tile = jnp .concatenate (
@@ -661,7 +657,7 @@ def _update_block_one_tile(
661657 joint_logp_tile ,
662658 block_seg_ids ,
663659 log_mean_rate ,
664- log_occ_bins [bin_start :bin_end ],
660+ log_occupancy_bins [bin_start :bin_end ],
665661 n_time ,
666662 )
667663
@@ -674,7 +670,11 @@ def compute_local_log_likelihood(
674670 position : jnp .ndarray ,
675671 spike_times : list [jnp .ndarray ],
676672 spike_waveform_features : list [jnp .ndarray ],
677- encoding_model : dict ,
673+ environment : Environment ,
674+ occupancy_model : GaussianMixtureModel ,
675+ gpi_models : list [GaussianMixtureModel ],
676+ joint_models : list [GaussianMixtureModel ],
677+ mean_rates : jnp .ndarray ,
678678 disable_progress_bar : bool = False ,
679679) -> jnp .ndarray :
680680 """Local log-likelihood at the animal's interpolated position.
@@ -709,19 +709,13 @@ def compute_local_log_likelihood(
709709 position_time = np .asarray (position_time )
710710 position = _as_jnp (position if position .ndim > 1 else position [:, None ])
711711
712- env = encoding_model ["environment" ]
713- occupancy_model = encoding_model ["occupancy_model" ]
714- gpi_models = encoding_model ["gpi_models" ]
715- joint_models = encoding_model ["joint_models" ]
716- mean_rates = encoding_model ["mean_rates" ]
717-
718712 n_time = time .shape [0 ] - 1
719713
720714 # Interpolate position at bin times (use bin centers)
721715 # We'll take the midpoints as "time of interest" for local evaluation
722716 t_centers = 0.5 * (time [:- 1 ] + time [1 :])
723717 interp_pos = get_position_at_time (
724- position_time , np .asarray (position ), t_centers , env
718+ position_time , np .asarray (position ), t_centers , environment
725719 ) # (n_time, pos_dims)
726720
727721 # Occupancy density and its log at the animal's position
@@ -753,7 +747,7 @@ def compute_local_log_likelihood(
753747 # Spike contributions at their true positions
754748 if elect_times .shape [0 ] > 0 :
755749 pos_at_spike_time = get_position_at_time (
756- position_time , np .asarray (position ), elect_times , env
750+ position_time , np .asarray (position ), elect_times , environment
757751 ) # (n_spikes, pos_dims)
758752 eval_points = jnp .concatenate (
759753 [pos_at_spike_time , elect_feats ], axis = 1
0 commit comments