Skip to content

Commit 15f3b2b

Browse files
committed
Refactor GMM likelihood functions to use explicit arguments
Replaces the use of a single encoding_model dictionary with explicit function arguments for environment, GMM models, and related parameters in likelihood computation functions. This improves clarity, type safety, and reduces reliance on dictionary key lookups.
1 parent 6b60f65 commit 15f3b2b

File tree

1 file changed

+26
-32
lines changed

1 file changed

+26
-32
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -426,15 +426,6 @@ def fit_clusterless_gmm_encoding_model(
426426
"joint_models": joint_models,
427427
"mean_rates": jnp.asarray(mean_rates),
428428
"summed_ground_process_intensity": summed_ground_process_intensity,
429-
# NOTE: Don't store position_time - it's passed as a separate argument to predict()
430-
# Storing it causes "multiple values for argument 'position_time'" error
431-
"gmm_components_occupancy": gmm_components_occupancy,
432-
"gmm_components_gpi": gmm_components_gpi,
433-
"gmm_components_joint": gmm_components_joint,
434-
"gmm_covariance_type_occupancy": gmm_covariance_type_occupancy,
435-
"gmm_covariance_type_gpi": gmm_covariance_type_gpi,
436-
"gmm_covariance_type_joint": gmm_covariance_type_joint,
437-
"gmm_random_state": gmm_random_state,
438429
"disable_progress_bar": disable_progress_bar,
439430
}
440431

@@ -450,7 +441,14 @@ def predict_clusterless_gmm_log_likelihood(
450441
position: jnp.ndarray,
451442
spike_times: list[jnp.ndarray],
452443
spike_waveform_features: list[jnp.ndarray],
453-
encoding_model: dict,
444+
environment: Environment,
445+
occupancy_model: GaussianMixtureModel,
446+
interior_place_bin_centers: jnp.ndarray,
447+
log_occupancy_bins: jnp.ndarray,
448+
gpi_models: list[GaussianMixtureModel],
449+
joint_models: list[GaussianMixtureModel],
450+
mean_rates: jnp.ndarray,
451+
summed_ground_process_intensity: jnp.ndarray,
454452
is_local: bool = False,
455453
spike_block_size: int = 1000,
456454
bin_tile_size: int | None = None,
@@ -504,22 +502,20 @@ def predict_clusterless_gmm_log_likelihood(
504502
position=position,
505503
spike_times=spike_times,
506504
spike_waveform_features=spike_waveform_features,
507-
encoding_model=encoding_model,
505+
environment=environment,
506+
occupancy_model=occupancy_model,
507+
gpi_models=gpi_models,
508+
joint_models=joint_models,
509+
mean_rates=mean_rates,
508510
disable_progress_bar=disable_progress_bar,
509511
)
510512

511-
bin_centers = encoding_model["interior_place_bin_centers"]
512-
log_occ_bins = encoding_model["log_occupancy_bins"] # log density
513-
mean_rates = encoding_model["mean_rates"]
514-
joint_models = encoding_model["joint_models"]
515-
summed_ground = encoding_model["summed_ground_process_intensity"]
516-
517513
n_time = time.shape[0]
518-
n_bins = bin_centers.shape[0]
514+
n_bins = interior_place_bin_centers.shape[0]
519515

520516
# Start with the expected-counts (ground process) term, broadcast over time
521517
log_likelihood = (
522-
(-summed_ground).reshape(1, -1).repeat(n_time, axis=0)
518+
(-summed_ground_process_intensity).reshape(1, -1).repeat(n_time, axis=0)
523519
) # (n_time, n_bins)
524520

525521
# Per-electrode contributions in log-space
@@ -617,7 +613,7 @@ def _update_block_one_tile(
617613

618614
if bin_tile_size is None or bin_tile_size >= n_bins:
619615
# No bin tiling: process all bins at once (default)
620-
tiled_bins = jnp.tile(bin_centers, (block_size, 1))
616+
tiled_bins = jnp.tile(interior_place_bin_centers, (block_size, 1))
621617
repeated_feats = jnp.repeat(block_feats, n_bins, axis=0)
622618
eval_points = jnp.concatenate([tiled_bins, repeated_feats], axis=1)
623619

@@ -631,7 +627,7 @@ def _update_block_one_tile(
631627
joint_logp_block,
632628
block_seg_ids,
633629
log_mean_rate,
634-
log_occ_bins,
630+
log_occupancy_bins,
635631
n_time,
636632
)
637633
else:
@@ -643,7 +639,7 @@ def _update_block_one_tile(
643639

644640
# Build eval points for this tile
645641
tiled_bins_tile = jnp.tile(
646-
bin_centers[bin_start:bin_end], (block_size, 1)
642+
interior_place_bin_centers[bin_start:bin_end], (block_size, 1)
647643
)
648644
repeated_feats_tile = jnp.repeat(block_feats, n_tile, axis=0)
649645
eval_points_tile = jnp.concatenate(
@@ -661,7 +657,7 @@ def _update_block_one_tile(
661657
joint_logp_tile,
662658
block_seg_ids,
663659
log_mean_rate,
664-
log_occ_bins[bin_start:bin_end],
660+
log_occupancy_bins[bin_start:bin_end],
665661
n_time,
666662
)
667663

@@ -674,7 +670,11 @@ def compute_local_log_likelihood(
674670
position: jnp.ndarray,
675671
spike_times: list[jnp.ndarray],
676672
spike_waveform_features: list[jnp.ndarray],
677-
encoding_model: dict,
673+
environment: Environment,
674+
occupancy_model: GaussianMixtureModel,
675+
gpi_models: list[GaussianMixtureModel],
676+
joint_models: list[GaussianMixtureModel],
677+
mean_rates: jnp.ndarray,
678678
disable_progress_bar: bool = False,
679679
) -> jnp.ndarray:
680680
"""Local log-likelihood at the animal's interpolated position.
@@ -709,19 +709,13 @@ def compute_local_log_likelihood(
709709
position_time = np.asarray(position_time)
710710
position = _as_jnp(position if position.ndim > 1 else position[:, None])
711711

712-
env = encoding_model["environment"]
713-
occupancy_model = encoding_model["occupancy_model"]
714-
gpi_models = encoding_model["gpi_models"]
715-
joint_models = encoding_model["joint_models"]
716-
mean_rates = encoding_model["mean_rates"]
717-
718712
n_time = time.shape[0] - 1
719713

720714
# Interpolate position at bin times (use bin centers)
721715
# We'll take the midpoints as "time of interest" for local evaluation
722716
t_centers = 0.5 * (time[:-1] + time[1:])
723717
interp_pos = get_position_at_time(
724-
position_time, np.asarray(position), t_centers, env
718+
position_time, np.asarray(position), t_centers, environment
725719
) # (n_time, pos_dims)
726720

727721
# Occupancy density and its log at the animal's position
@@ -753,7 +747,7 @@ def compute_local_log_likelihood(
753747
# Spike contributions at their true positions
754748
if elect_times.shape[0] > 0:
755749
pos_at_spike_time = get_position_at_time(
756-
position_time, np.asarray(position), elect_times, env
750+
position_time, np.asarray(position), elect_times, environment
757751
) # (n_spikes, pos_dims)
758752
eval_points = jnp.concatenate(
759753
[pos_at_spike_time, elect_feats], axis=1

0 commit comments

Comments
 (0)