@@ -273,6 +273,9 @@ def fit_clusterless_gmm_encoding_model(
273273 # Keep as numpy for interpolation (scipy.interpolate.interpn requires numpy anyway).
274274 position_time = np .asarray (position_time )
275275
276+ # Ignore weights for now
277+ weights = None
278+
276279 # Interior bins (cached)
277280 if environment .is_track_interior_ is not None :
278281 is_track_interior = environment .is_track_interior_ .ravel ()
@@ -283,12 +286,6 @@ def fit_clusterless_gmm_encoding_model(
283286 )
284287 is_track_interior = jnp .ones (len (environment .place_bin_centers_ ), dtype = bool )
285288
286- # Occupancy weights over trajectory
287- if weights is None :
288- weights = jnp .ones ((position .shape [0 ],), dtype = position .dtype )
289- else :
290- weights = _as_jnp (weights )
291-
292289 # If environment has a graph and positions are 2D+, linearize to 1D for occupancy/GPI
293290 if environment .track_graph is not None and position .shape [1 ] > 1 :
294291 position1D = get_linearized_position (
@@ -324,7 +321,7 @@ def fit_clusterless_gmm_encoding_model(
324321
325322 # Fit per-electrode models
326323 for elect_feats , elect_times in tqdm (
327- zip (spike_waveform_features , spike_times , strict = False ),
324+ zip (spike_waveform_features , spike_times , strict = True ),
328325 desc = "Encoding models (GMM)" ,
329326 unit = "electrode" ,
330327 disable = disable_progress_bar ,
@@ -335,12 +332,7 @@ def fit_clusterless_gmm_encoding_model(
335332 elect_times >= position_time [0 ], elect_times <= position_time [- 1 ]
336333 )
337334 elect_times = elect_times [in_bounds ]
338- elect_feats = elect_feats [in_bounds ]
339- elect_feats = _as_jnp (elect_feats )
340-
341- # Skip electrodes with no spikes in encoding window
342- if elect_times .shape [0 ] == 0 :
343- continue
335+ elect_feats = _as_jnp (elect_feats [in_bounds ])
344336
345337 # Mean firing rate
346338 mean_rate = float (len (elect_times ) / n_time_bins )
@@ -376,8 +368,7 @@ def fit_clusterless_gmm_encoding_model(
376368 # Expected-counts term at bins: mean_rate * (gpi / occupancy)
377369 gpi_density = _gmm_density (gpi_gmm , interior_place_bin_centers )
378370 summed_ground_process_intensity += jnp .clip (
379- mean_rates [- 1 ]
380- * safe_divide (gpi_density , occupancy , condition = occupancy > EPS ),
371+ mean_rate * jnp .where (occupancy > 0.0 , gpi_density / occupancy , EPS ),
381372 a_min = EPS ,
382373 a_max = None ,
383374 )
0 commit comments