Skip to content

Commit 58ae32a

Browse files
committed (author and date omitted in this extraction)
Expand predictive output options and update tests
Added 'predictive_posterior' as a valid output option for model results, updated output normalization logic, and clarified documentation for output choices and memory usage. Refactored and extended tests to verify correct behavior for new and existing output combinations, and improved code formatting and assertion messages for clarity and consistency.
1 parent 096471e commit 58ae32a

File tree

9 files changed

+183
-89
lines changed

9 files changed

+183
-89
lines changed

src/non_local_detector/likelihoods/clusterless_gmm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616
from non_local_detector.environment import Environment
1717
from non_local_detector.likelihoods.common import (
1818
EPS,
19-
LOG_EPS,
2019
get_position_at_time,
2120
get_spike_time_bin_ind,
22-
safe_divide,
2321
)
2422
from non_local_detector.likelihoods.gmm import GaussianMixtureModel
2523

@@ -326,7 +324,6 @@ def fit_clusterless_gmm_encoding_model(
326324
unit="electrode",
327325
disable=disable_progress_bar,
328326
):
329-
330327
# Clip to encoding window
331328
in_bounds = np.logical_and(
332329
elect_times >= position_time[0], elect_times <= position_time[-1]
@@ -483,7 +480,6 @@ def predict_clusterless_gmm_log_likelihood(
483480
unit="electrode",
484481
disable=disable_progress_bar,
485482
):
486-
487483
# Clip to decoding window
488484
in_bounds = np.logical_and(elect_times >= time[0], elect_times <= time[-1])
489485
elect_times = elect_times[in_bounds]

src/non_local_detector/models/base.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,21 @@
6262
}
6363

6464
# Valid options for return_outputs parameter
65-
VALID_OUTPUTS: set[str] = {"filter", "predictive", "log_likelihood", "all"}
65+
VALID_OUTPUTS: set[str] = {
66+
"filter",
67+
"predictive",
68+
"predictive_posterior",
69+
"log_likelihood",
70+
"all",
71+
}
6672

6773
# Mapping of single string options to sets of outputs
6874
OUTPUT_INCLUDES: dict[str, set[str]] = {
6975
"filter": {"filter"},
70-
"predictive": {"predictive"},
76+
"predictive": {"predictive", "predictive_posterior"},
77+
"predictive_posterior": {"predictive_posterior"},
7178
"log_likelihood": {"log_likelihood"},
72-
"all": {"filter", "predictive", "log_likelihood"},
79+
"all": {"filter", "predictive", "predictive_posterior", "log_likelihood"},
7380
}
7481

7582

@@ -86,7 +93,8 @@ def _normalize_return_outputs(
8693
Returns
8794
-------
8895
set of str
89-
Normalized set containing any of: 'filter', 'predictive', 'log_likelihood'
96+
Normalized set containing any of: 'filter', 'predictive',
97+
'predictive_posterior', 'log_likelihood'
9098
9199
Raises
92100
------
@@ -2233,22 +2241,24 @@ def predict(
22332241
Options:
22342242
- None: smoother only (default, minimal memory)
22352243
- 'filter': filtered (causal) posterior and state probabilities
2236-
- 'predictive': one-step-ahead predictive state distributions
2244+
- 'predictive': both aggregated and full predictive distributions
2245+
- 'predictive_posterior': only full predictive posterior (state bins)
22372246
- 'log_likelihood': per-timepoint log likelihoods
22382247
- 'all': all outputs above
2239-
- List/set: e.g., ['filter', 'predictive'] for multiple outputs
2248+
- List/set: e.g., ['filter', 'log_likelihood'] for multiple outputs
22402249
22412250
The smoother (acausal_posterior, acausal_state_probabilities) and
22422251
marginal_log_likelihood are ALWAYS included.
22432252
22442253
When to use each output:
22452254
- 'filter': Online/causal decoding, debugging forward pass
2246-
- 'predictive': Model evaluation, predictive checks
2255+
- 'predictive': Model evaluation, predictive checks (includes both formats)
2256+
- 'predictive_posterior': When you only need full distribution, not aggregated
22472257
- 'log_likelihood': Diagnostics, per-timepoint metrics, model comparison
22482258
2249-
Memory warning: 'log_likelihood' and 'filter' can be very large
2250-
(~400 GB for 1M timepoints × 100k spatial bins). Only request
2251-
what you need for your analysis.
2259+
Memory warning: 'log_likelihood', 'filter', 'predictive', and
2260+
'predictive_posterior' can be very large (~400 GB for 1M timepoints × 100k
2261+
spatial bins). Use None for minimal memory (smoother only).
22522262
save_log_likelihood_to_results : bool, optional
22532263
DEPRECATED. Use return_outputs='log_likelihood' instead.
22542264
Whether to save the log likelihood to the results, by default None.
@@ -2276,8 +2286,9 @@ def predict(
22762286
Filtered discrete state probabilities
22772287
- predictive_state_probabilities : (n_time, n_states) - if 'predictive'
22782288
One-step-ahead predictive distributions over discrete states
2279-
- predictive_posterior : (n_time, n_state_bins) - if 'predictive'
2289+
- predictive_posterior : (n_time, n_state_bins) - if 'predictive_posterior'
22802290
One-step-ahead predictive distributions over state bins
2291+
(Warning: can be very large, ~same size as causal_posterior)
22812292
- log_likelihood : (n_time, n_state_bins) - if 'log_likelihood'
22822293
Per-timepoint observation log likelihoods
22832294
@@ -2406,7 +2417,9 @@ def predict(
24062417
else None
24072418
),
24082419
predictive_posterior=(
2409-
predictive_posterior if "predictive" in requested_outputs else None
2420+
predictive_posterior
2421+
if "predictive_posterior" in requested_outputs
2422+
else None
24102423
),
24112424
)
24122425

@@ -3087,7 +3100,9 @@ def predict(
30873100
else None
30883101
),
30893102
predictive_posterior=(
3090-
predictive_posterior if "predictive" in requested_outputs else None
3103+
predictive_posterior
3104+
if "predictive_posterior" in requested_outputs
3105+
else None
30913106
),
30923107
)
30933108

src/non_local_detector/tests/likelihoods/test_clusterless_kde_optimization.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ def test_numerical_equivalence_high_dim(self):
3737
result_original = kde_distance(eval_points, samples, std)
3838
result_vectorized = kde_distance_vectorized(eval_points, samples, std)
3939

40-
assert jnp.allclose(result_original, result_vectorized, rtol=1e-5, atol=1e-8), \
41-
f"Failed for {n_features}D"
40+
assert jnp.allclose(
41+
result_original, result_vectorized, rtol=1e-5, atol=1e-8
42+
), f"Failed for {n_features}D"
4243

4344
def test_numerical_stability(self):
4445
"""Test numerical stability with small std values."""
@@ -198,8 +199,12 @@ def test_jit_compilation_caching(self):
198199

199200
# Warmup call
200201
result_warmup = estimate_log_joint_mark_intensity(
201-
dec_features, enc_features, waveform_stds,
202-
occupancy, mean_rate, position_distance
202+
dec_features,
203+
enc_features,
204+
waveform_stds,
205+
occupancy,
206+
mean_rate,
207+
position_distance,
203208
)
204209
result_warmup.block_until_ready()
205210

@@ -208,24 +213,37 @@ def test_jit_compilation_caching(self):
208213
for _ in range(10):
209214
start = time.perf_counter()
210215
result = estimate_log_joint_mark_intensity(
211-
dec_features, enc_features, waveform_stds,
212-
occupancy, mean_rate, position_distance
216+
dec_features,
217+
enc_features,
218+
waveform_stds,
219+
occupancy,
220+
mean_rate,
221+
position_distance,
213222
)
214223
result.block_until_ready()
215224
times.append(time.perf_counter() - start)
216225

217226
# All calls should be fast (< 10ms) after compilation
218227
avg_time = np.mean(times)
219-
assert avg_time < 0.01, \
228+
assert avg_time < 0.01, (
220229
f"Average call time ({avg_time:.4f}s) too slow, JIT may not be working"
230+
)
221231

222232
# Results should be consistent
223233
result1 = estimate_log_joint_mark_intensity(
224-
dec_features, enc_features, waveform_stds,
225-
occupancy, mean_rate, position_distance
234+
dec_features,
235+
enc_features,
236+
waveform_stds,
237+
occupancy,
238+
mean_rate,
239+
position_distance,
226240
)
227241
result2 = estimate_log_joint_mark_intensity(
228-
dec_features, enc_features, waveform_stds,
229-
occupancy, mean_rate, position_distance
242+
dec_features,
243+
enc_features,
244+
waveform_stds,
245+
occupancy,
246+
mean_rate,
247+
position_distance,
230248
)
231249
assert jnp.allclose(result1, result2)

src/non_local_detector/tests/likelihoods/test_enc_tile_size.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ def test_enc_tile_size_with_pos_tile_size(enc_tile_size, pos_tile_size):
148148

149149
# Should match
150150
max_diff = np.max(np.abs(result_baseline - result_both_tiling))
151-
assert np.allclose(
152-
result_baseline, result_both_tiling, rtol=1e-5, atol=1e-7
153-
), f"enc_tile_size={enc_tile_size}, pos_tile_size={pos_tile_size}: Max diff = {max_diff}"
151+
assert np.allclose(result_baseline, result_both_tiling, rtol=1e-5, atol=1e-7), (
152+
f"enc_tile_size={enc_tile_size}, pos_tile_size={pos_tile_size}: Max diff = {max_diff}"
153+
)
154154

155155

156156
@pytest.mark.unit

src/non_local_detector/tests/likelihoods/test_gmm_kde_convergence.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,24 @@ def normalize_per_time(ll):
170170
peak_gmm = np.argmax(ll_gmm, axis=1)
171171
peak_agreement = np.mean(peak_kde == peak_gmm)
172172

173-
print(f"{n_comp:4d} | {corr:11.4f} | {mse:16.6f} | {peak_agreement:14.1%}")
173+
print(
174+
f"{n_comp:4d} | {corr:11.4f} | {mse:16.6f} | {peak_agreement:14.1%}"
175+
)
174176

175177
# Verify convergence trend
176178
print("\n=== Convergence Analysis ===")
177179
print(f"Correlation improvement: {correlations[0]:.4f}{correlations[-1]:.4f}")
178180
print(f"MSE improvement: {mse_values[0]:.6f}{mse_values[-1]:.6f}")
179181

180182
# Key assertion: correlation should increase with more components
181-
assert correlations[-1] > correlations[0], \
183+
assert correlations[-1] > correlations[0], (
182184
f"Correlation should increase: {correlations[0]:.3f}{correlations[-1]:.3f}"
185+
)
183186

184187
# MSE should decrease
185-
assert mse_values[-1] < mse_values[0], \
188+
assert mse_values[-1] < mse_values[0], (
186189
f"MSE should decrease: {mse_values[0]:.4f}{mse_values[-1]:.4f}"
190+
)
187191

188192

189193
def test_mathematical_formula_consistency(convergence_test_data):
@@ -219,12 +223,17 @@ def test_mathematical_formula_consistency(convergence_test_data):
219223
gmm_formula = np.log(mean_rate) + np.log(marginal_density) - np.log(occupancy)
220224

221225
print("\n=== Formula Verification ===")
222-
print(f"KDE: log({mean_rate} * {marginal_density} / {occupancy}) = {kde_formula:.6f}")
223-
print(f"GMM: log({mean_rate}) + log({marginal_density}) - log({occupancy}) = {gmm_formula:.6f}")
226+
print(
227+
f"KDE: log({mean_rate} * {marginal_density} / {occupancy}) = {kde_formula:.6f}"
228+
)
229+
print(
230+
f"GMM: log({mean_rate}) + log({marginal_density}) - log({occupancy}) = {gmm_formula:.6f}"
231+
)
224232
print(f"Difference: {abs(kde_formula - gmm_formula):.10f}")
225233

226-
assert np.isclose(kde_formula, gmm_formula, rtol=1e-10), \
234+
assert np.isclose(kde_formula, gmm_formula, rtol=1e-10), (
227235
"KDE and GMM formulas should be mathematically identical"
236+
)
228237

229238

230239
def test_ground_process_intensity_calculation(convergence_test_data):
@@ -298,18 +307,25 @@ def test_segment_sum_correctness(convergence_test_data):
298307
segment_ids = jnp.array([0, 0, 1, 1, 2])
299308

300309
# KDE way
301-
result_kde = jax.ops.segment_sum(values, segment_ids, num_segments=3, indices_are_sorted=True)
310+
result_kde = jax.ops.segment_sum(
311+
values, segment_ids, num_segments=3, indices_are_sorted=True
312+
)
302313

303314
# GMM way (from jax.ops import segment_sum)
304315
from jax.ops import segment_sum
305-
result_gmm = segment_sum(values, segment_ids, num_segments=3, indices_are_sorted=True)
316+
317+
result_gmm = segment_sum(
318+
values, segment_ids, num_segments=3, indices_are_sorted=True
319+
)
306320

307321
print("\n=== segment_sum Verification ===")
308322
print(f"Input: values={values}, segment_ids={segment_ids}")
309323
print(f"KDE result: {result_kde}")
310324
print(f"GMM result: {result_gmm}")
311325

312-
assert jnp.allclose(result_kde, result_gmm), "segment_sum implementations should match"
326+
assert jnp.allclose(result_kde, result_gmm), (
327+
"segment_sum implementations should match"
328+
)
313329

314330

315331
def test_log_space_operations(convergence_test_data):

src/non_local_detector/tests/likelihoods/test_kde_gmm_comparison.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ def shared_simulation_data():
6666

6767
# Decoding period spikes (subset of encoding spikes for simplicity)
6868
decoding_spike_times = [times[:20] for times in encoding_spike_times]
69-
decoding_spike_features = [
70-
feats[:20] for feats in encoding_spike_features
71-
]
69+
decoding_spike_features = [feats[:20] for feats in encoding_spike_features]
7270

7371
# Create and fit environment
7472
environment = Environment(position_range=[(0, 10), (-3, 3)])
@@ -159,9 +157,7 @@ def test_kde_end_to_end_pipeline(shared_simulation_data):
159157
encoding_positions=kde_encoding["encoding_positions"],
160158
environment=data["environment"],
161159
mean_rates=jnp.asarray(kde_encoding["mean_rates"]),
162-
summed_ground_process_intensity=kde_encoding[
163-
"summed_ground_process_intensity"
164-
],
160+
summed_ground_process_intensity=kde_encoding["summed_ground_process_intensity"],
165161
position_std=jnp.asarray(kde_encoding["position_std"]),
166162
waveform_std=jnp.asarray(kde_encoding["waveform_std"]),
167163
is_local=False,
@@ -191,9 +187,7 @@ def test_kde_end_to_end_pipeline(shared_simulation_data):
191187
encoding_positions=kde_encoding["encoding_positions"],
192188
environment=data["environment"],
193189
mean_rates=jnp.asarray(kde_encoding["mean_rates"]),
194-
summed_ground_process_intensity=kde_encoding[
195-
"summed_ground_process_intensity"
196-
],
190+
summed_ground_process_intensity=kde_encoding["summed_ground_process_intensity"],
197191
position_std=jnp.asarray(kde_encoding["position_std"]),
198192
waveform_std=jnp.asarray(kde_encoding["waveform_std"]),
199193
is_local=True,
@@ -404,9 +398,7 @@ def test_api_consistency_predict_functions(shared_simulation_data):
404398
encoding_positions=kde_encoding["encoding_positions"],
405399
environment=data["environment"],
406400
mean_rates=kde_encoding["mean_rates"],
407-
summed_ground_process_intensity=kde_encoding[
408-
"summed_ground_process_intensity"
409-
],
401+
summed_ground_process_intensity=kde_encoding["summed_ground_process_intensity"],
410402
position_std=kde_encoding["position_std"],
411403
waveform_std=kde_encoding["waveform_std"],
412404
is_local=False,
@@ -454,7 +446,10 @@ def test_kde_gmm_output_shape_consistency(shared_simulation_data):
454446

455447
kde_enc = fit_clusterless_kde_encoding_model(**common_params, position_std=1.0)
456448
gmm_enc = fit_clusterless_gmm_encoding_model(
457-
**common_params, gmm_components_occupancy=4, gmm_components_gpi=4, gmm_components_joint=8
449+
**common_params,
450+
gmm_components_occupancy=4,
451+
gmm_components_gpi=4,
452+
gmm_components_joint=8,
458453
)
459454

460455
# Predict with both
@@ -522,7 +517,10 @@ def test_both_support_local_and_nonlocal_modes(shared_simulation_data):
522517
# Fit models
523518
kde_enc = fit_clusterless_kde_encoding_model(**common_params, position_std=1.0)
524519
gmm_enc = fit_clusterless_gmm_encoding_model(
525-
**common_params, gmm_components_occupancy=4, gmm_components_gpi=4, gmm_components_joint=8
520+
**common_params,
521+
gmm_components_occupancy=4,
522+
gmm_components_gpi=4,
523+
gmm_components_joint=8,
526524
)
527525

528526
time = jnp.asarray(data["time"])

0 commit comments

Comments (0)