Skip to content

Commit 68d7f79

Browse files
committed
Refactor GMM encoding model fitting logic
Set weights to None and remove unused occupancy weights logic. Enforce strict zipping of spike features and times, streamline electrode feature selection, and simplify expected-counts calculation for clarity and correctness.
1 parent 959accd commit 68d7f79

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ def fit_clusterless_gmm_encoding_model(
273273
# Keep as numpy for interpolation (scipy.interpolate.interpn requires numpy anyway).
274274
position_time = np.asarray(position_time)
275275

276+
# Ignore weights for now
277+
weights = None
278+
276279
# Interior bins (cached)
277280
if environment.is_track_interior_ is not None:
278281
is_track_interior = environment.is_track_interior_.ravel()
@@ -283,12 +286,6 @@ def fit_clusterless_gmm_encoding_model(
283286
)
284287
is_track_interior = jnp.ones(len(environment.place_bin_centers_), dtype=bool)
285288

286-
# Occupancy weights over trajectory
287-
if weights is None:
288-
weights = jnp.ones((position.shape[0],), dtype=position.dtype)
289-
else:
290-
weights = _as_jnp(weights)
291-
292289
# If environment has a graph and positions are 2D+, linearize to 1D for occupancy/GPI
293290
if environment.track_graph is not None and position.shape[1] > 1:
294291
position1D = get_linearized_position(
@@ -324,7 +321,7 @@ def fit_clusterless_gmm_encoding_model(
324321

325322
# Fit per-electrode models
326323
for elect_feats, elect_times in tqdm(
327-
zip(spike_waveform_features, spike_times, strict=False),
324+
zip(spike_waveform_features, spike_times, strict=True),
328325
desc="Encoding models (GMM)",
329326
unit="electrode",
330327
disable=disable_progress_bar,
@@ -335,12 +332,7 @@ def fit_clusterless_gmm_encoding_model(
335332
elect_times >= position_time[0], elect_times <= position_time[-1]
336333
)
337334
elect_times = elect_times[in_bounds]
338-
elect_feats = elect_feats[in_bounds]
339-
elect_feats = _as_jnp(elect_feats)
340-
341-
# Skip electrodes with no spikes in encoding window
342-
if elect_times.shape[0] == 0:
343-
continue
335+
elect_feats = _as_jnp(elect_feats[in_bounds])
344336

345337
# Mean firing rate
346338
mean_rate = float(len(elect_times) / n_time_bins)
@@ -376,8 +368,7 @@ def fit_clusterless_gmm_encoding_model(
376368
# Expected-counts term at bins: mean_rate * (gpi / occupancy)
377369
gpi_density = _gmm_density(gpi_gmm, interior_place_bin_centers)
378370
summed_ground_process_intensity += jnp.clip(
379-
mean_rates[-1]
380-
* safe_divide(gpi_density, occupancy, condition=occupancy > EPS),
371+
mean_rate * jnp.where(occupancy > 0.0, gpi_density / occupancy, EPS),
381372
a_min=EPS,
382373
a_max=None,
383374
)

0 commit comments

Comments
 (0)