Skip to content

Commit 096471e

Browse files
committed
Add predictive_posterior output to filter/smoother
Extended core filter/smoother functions and model base classes to compute and return predictive_posterior (one-step-ahead predicted state probabilities over state bins). Updated xarray dataset conversion and output handling to include predictive_posterior. Adjusted tests to verify presence and correctness of predictive_posterior in results.
1 parent fb5da73 commit 096471e

File tree

4 files changed

+85
-9
lines changed

4 files changed

+85
-9
lines changed

src/non_local_detector/core.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,14 @@ def chunked_filter_smoother(
263263
cache_log_likelihoods: bool = True,
264264
dtype: jnp.dtype = jnp.float32,
265265
) -> tuple[
266-
np.ndarray, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray
266+
np.ndarray,
267+
np.ndarray,
268+
float,
269+
np.ndarray,
270+
np.ndarray,
271+
np.ndarray,
272+
np.ndarray,
273+
np.ndarray,
267274
]:
268275
"""Filter and smooth the state probabilities in chunks.
269276
@@ -302,8 +309,11 @@ def chunked_filter_smoother(
302309
Log likelihoods for each state at each time point
303310
causal_posterior : np.ndarray, shape (n_time, n_state_bins)
304311
Filtered state probabilities
312+
predictive_posterior : np.ndarray, shape (n_time, n_state_bins)
313+
One-step-ahead predicted state probabilities over state bins
305314
"""
306315
causal_posterior = []
316+
predictive_posterior = []
307317
predictive_state_probabilities = []
308318
causal_state_probabilities = []
309319
acausal_posterior = []
@@ -394,6 +404,7 @@ def chunked_filter_smoother(
394404
causal_state_probabilities.append(
395405
causal_posterior_chunk @ state_aggregation_matrix
396406
)
407+
predictive_posterior.append(predicted_probs_chunk)
397408
predictive_state_probabilities.append(
398409
predicted_probs_chunk @ state_aggregation_matrix
399410
)
@@ -403,6 +414,7 @@ def chunked_filter_smoother(
403414
# Concatenate JAX arrays on device
404415
causal_posterior_jax = jnp.concatenate(causal_posterior)
405416
causal_state_probabilities_jax = jnp.concatenate(causal_state_probabilities)
417+
predictive_posterior_jax = jnp.concatenate(predictive_posterior)
406418
predictive_state_probabilities_jax = jnp.concatenate(predictive_state_probabilities)
407419

408420
# Backward pass: accumulate JAX arrays
@@ -439,6 +451,7 @@ def chunked_filter_smoother(
439451
np.asarray(predictive_state_probabilities_jax),
440452
log_likelihoods, # Keep as original (may be None or NumPy)
441453
np.asarray(causal_posterior_jax),
454+
np.asarray(predictive_posterior_jax),
442455
)
443456

444457

@@ -751,7 +764,14 @@ def chunked_filter_smoother_covariate_dependent(
751764
cache_log_likelihoods: bool = True,
752765
dtype: jnp.dtype = jnp.float32,
753766
) -> tuple[
754-
np.ndarray, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray
767+
np.ndarray,
768+
np.ndarray,
769+
float,
770+
np.ndarray,
771+
np.ndarray,
772+
np.ndarray,
773+
np.ndarray,
774+
np.ndarray,
755775
]:
756776
"""Filter and smooth the state probabilities in chunks with covariate dependent transitions.
757777
@@ -790,8 +810,11 @@ def chunked_filter_smoother_covariate_dependent(
790810
Log likelihoods for each state at each time point
791811
causal_posterior : np.ndarray, shape (n_time, n_state_bins)
792812
Filtered state probabilities
813+
predictive_posterior : np.ndarray, shape (n_time, n_state_bins)
814+
One-step-ahead predicted state probabilities over state bins
793815
"""
794816
causal_posterior = []
817+
predictive_posterior = []
795818
predictive_state_probabilities = []
796819
causal_state_probabilities = []
797820
acausal_posterior = []
@@ -893,6 +916,7 @@ def chunked_filter_smoother_covariate_dependent(
893916
causal_state_probabilities.append(
894917
causal_posterior_chunk @ state_aggregation_matrix
895918
)
919+
predictive_posterior.append(predicted_probs_chunk)
896920
predictive_state_probabilities.append(
897921
predicted_probs_chunk @ state_aggregation_matrix
898922
)
@@ -902,6 +926,7 @@ def chunked_filter_smoother_covariate_dependent(
902926
# Concatenate JAX arrays on device
903927
causal_posterior_jax = jnp.concatenate(causal_posterior)
904928
causal_state_probabilities_jax = jnp.concatenate(causal_state_probabilities)
929+
predictive_posterior_jax = jnp.concatenate(predictive_posterior)
905930
predictive_state_probabilities_jax = jnp.concatenate(predictive_state_probabilities)
906931

907932
# Backward pass: accumulate JAX arrays
@@ -946,6 +971,7 @@ def chunked_filter_smoother_covariate_dependent(
946971
np.asarray(predictive_state_probabilities_jax),
947972
log_likelihoods, # Keep as original (may be None or NumPy)
948973
np.asarray(causal_posterior_jax),
974+
np.asarray(predictive_posterior_jax),
949975
)
950976

951977

src/non_local_detector/models/base.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _normalize_return_outputs(
106106
)
107107
return OUTPUT_INCLUDES.get(return_outputs, {return_outputs})
108108

109-
if isinstance(return_outputs, (list, set)):
109+
if isinstance(return_outputs, list | set):
110110
outputs_set = set(return_outputs)
111111
invalid = outputs_set - VALID_OUTPUTS
112112
if invalid:
@@ -1063,7 +1063,14 @@ def _predict(
10631063
cache_likelihood: bool = True,
10641064
n_chunks: int = 1,
10651065
) -> tuple[
1066-
np.ndarray, np.ndarray, float, np.ndarray, np.ndarray, np.ndarray, np.ndarray
1066+
np.ndarray,
1067+
np.ndarray,
1068+
float,
1069+
np.ndarray,
1070+
np.ndarray,
1071+
np.ndarray,
1072+
np.ndarray,
1073+
np.ndarray,
10671074
]:
10681075
"""
10691076
Compute the posterior probabilities.
@@ -1092,6 +1099,7 @@ def _predict(
10921099
predictive_state_probabilities : np.ndarray, shape (n_time, n_states)
10931100
log_likelihoods : np.ndarray, shape (n_time, n_state_bins)
10941101
causal_posterior : np.ndarray, shape (n_time, n_state_bins)
1102+
predictive_posterior : np.ndarray, shape (n_time, n_state_bins)
10951103
"""
10961104

10971105
logger.info("Computing posterior...")
@@ -1549,6 +1557,7 @@ def _convert_results_to_xarray(
15491557
causal_posterior: np.ndarray | None = None,
15501558
causal_state_probabilities: np.ndarray | None = None,
15511559
predictive_state_probabilities: np.ndarray | None = None,
1560+
predictive_posterior: np.ndarray | None = None,
15521561
) -> xr.Dataset:
15531562
"""
15541563
Convert the results to an xarray Dataset.
@@ -1571,6 +1580,8 @@ def _convert_results_to_xarray(
15711580
Causal state probabilities, by default None.
15721581
predictive_state_probabilities : np.ndarray, optional, shape (n_time, n_states)
15731582
One-step-ahead predicted state probabilities, by default None.
1583+
predictive_posterior : np.ndarray, optional, shape (n_time, n_state_bins)
1584+
One-step-ahead predicted posterior probabilities over state bins, by default None.
15741585
15751586
Returns
15761587
-------
@@ -1698,6 +1709,14 @@ def _convert_results_to_xarray(
16981709
predictive_state_probabilities,
16991710
)
17001711

1712+
if predictive_posterior is not None:
1713+
data_vars["predictive_posterior"] = (
1714+
("time", "state_bins"),
1715+
self._create_masked_posterior(
1716+
predictive_posterior, is_track_interior, n_total_bins
1717+
),
1718+
)
1719+
17011720
# Create Dataset with MultiIndex coordinates
17021721
results = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
17031722

@@ -2256,7 +2275,9 @@ def predict(
22562275
- causal_state_probabilities : (n_time, n_states) - if 'filter'
22572276
Filtered discrete state probabilities
22582277
- predictive_state_probabilities : (n_time, n_states) - if 'predictive'
2259-
One-step-ahead predictive distributions
2278+
One-step-ahead predictive distributions over discrete states
2279+
- predictive_posterior : (n_time, n_state_bins) - if 'predictive'
2280+
One-step-ahead predictive distributions over state bins
22602281
- log_likelihood : (n_time, n_state_bins) - if 'log_likelihood'
22612282
Per-timepoint observation log likelihoods
22622283
@@ -2351,6 +2372,7 @@ def predict(
23512372
predictive_state_probabilities,
23522373
log_likelihood,
23532374
causal_posterior,
2375+
predictive_posterior,
23542376
) = self._predict(
23552377
time=time,
23562378
log_likelihood_args=(
@@ -2383,6 +2405,9 @@ def predict(
23832405
if "predictive" in requested_outputs
23842406
else None
23852407
),
2408+
predictive_posterior=(
2409+
predictive_posterior if "predictive" in requested_outputs else None
2410+
),
23862411
)
23872412

23882413
def estimate_parameters(
@@ -3029,6 +3054,7 @@ def predict(
30293054
predictive_state_probabilities,
30303055
log_likelihood,
30313056
causal_posterior,
3057+
predictive_posterior,
30323058
) = self._predict(
30333059
time=time,
30343060
log_likelihood_args=(
@@ -3060,6 +3086,9 @@ def predict(
30603086
if "predictive" in requested_outputs
30613087
else None
30623088
),
3089+
predictive_posterior=(
3090+
predictive_posterior if "predictive" in requested_outputs else None
3091+
),
30633092
)
30643093

30653094
def estimate_parameters(

src/non_local_detector/tests/core/test_chunked_parity.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def log_likelihood_func(time_idx, *args):
7070
predictive_state_probabilities, # one-step-ahead predictions
7171
_,
7272
causal_posterior,
73+
_, # predictive_posterior (not used in this test)
7374
) = chunked_filter_smoother(
7475
time=time,
7576
state_ind=state_ind,
@@ -114,6 +115,7 @@ def log_likelihood_func(time_idx, *args):
114115
_, # predictive_state_probabilities
115116
_,
116117
_,
118+
_, # predictive_posterior (not used in this test)
117119
) = chunked_filter_smoother(
118120
time=time,
119121
state_ind=state_ind,
@@ -155,6 +157,7 @@ def log_likelihood_func(time_idx, *args):
155157
_, # predictive_state_probabilities
156158
_,
157159
_,
160+
_, # predictive_posterior (not used in this test)
158161
) = chunked_filter_smoother(
159162
time=time,
160163
state_ind=state_ind,
@@ -175,6 +178,7 @@ def log_likelihood_func(time_idx, *args):
175178
_, # predictive_state_probabilities
176179
_,
177180
_,
181+
_, # predictive_posterior (not used in this test)
178182
) = chunked_filter_smoother(
179183
time=time,
180184
state_ind=state_ind,
@@ -211,7 +215,7 @@ def log_likelihood_func(time_idx, *args):
211215
total_log_like_std = log_marginals.sum()
212216

213217
# Act - Chunked version
214-
(_, _, log_like_chunked, _, _, _, _) = chunked_filter_smoother(
218+
(_, _, log_like_chunked, _, _, _, _, _) = chunked_filter_smoother(
215219
time=time,
216220
state_ind=state_ind,
217221
initial_distribution=np.array(init),
@@ -251,6 +255,7 @@ def log_likelihood_func(time_idx, *args):
251255
_, # predictive_state_probabilities
252256
_,
253257
_,
258+
_, # predictive_posterior (not used in this test)
254259
) = chunked_filter_smoother(
255260
time=time,
256261
state_ind=state_ind,
@@ -286,6 +291,7 @@ def log_likelihood_func(time_idx, *args):
286291
_, # predictive_state_probabilities
287292
_,
288293
_,
294+
_, # predictive_posterior (not used in this test)
289295
) = chunked_filter_smoother(
290296
time=time,
291297
state_ind=state_ind,
@@ -330,6 +336,7 @@ def log_likelihood_func(time_idx, *args):
330336
_, # predictive_state_probabilities
331337
_,
332338
_,
339+
_, # predictive_posterior (not used in this test)
333340
) = chunked_filter_smoother(
334341
time=time,
335342
state_ind=state_ind,
@@ -369,6 +376,7 @@ def log_likelihood_func(time_idx, *args):
369376
_,
370377
_,
371378
_,
379+
_, # predictive_posterior (not used in this test)
372380
) = chunked_filter_smoother(
373381
time=time,
374382
state_ind=state_ind,
@@ -390,6 +398,7 @@ def log_likelihood_func(time_idx, *args):
390398
_,
391399
_,
392400
_,
401+
_, # predictive_posterior (not used in this test)
393402
) = chunked_filter_smoother(
394403
time=time,
395404
state_ind=state_ind,
@@ -459,6 +468,7 @@ def log_likelihood_func(time_idx, *args):
459468
_, # predictive_state_probabilities
460469
_,
461470
_,
471+
_, # predictive_posterior (not used in this test)
462472
) = chunked_filter_smoother(
463473
time=time,
464474
state_ind=state_ind,

src/non_local_detector/tests/models/test_return_outputs.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,31 @@ def test_return_predictive_string(self, simple_fitted_detector):
112112
# Should have smoother (always)
113113
assert "acausal_posterior" in results
114114

115-
# Should have predictive
115+
# Should have predictive (both aggregated and full)
116116
assert "predictive_state_probabilities" in results
117+
assert "predictive_posterior" in results
117118

118119
# Should NOT have filter or log_likelihood
119120
assert "causal_posterior" not in results
120121
assert "log_likelihood" not in results
121122

122-
# Verify shape (aggregated to discrete states)
123+
# Verify shape of aggregated version (discrete states)
123124
n_time = len(time)
124125
n_states = results.acausal_state_probabilities.shape[1]
125126
assert results.predictive_state_probabilities.shape == (n_time, n_states)
126127

127-
# Verify probabilities sum to 1
128+
# Verify shape of full version (state bins)
129+
n_state_bins = results.acausal_posterior.shape[1]
130+
assert results.predictive_posterior.shape == (n_time, n_state_bins)
131+
132+
# Verify probabilities sum to 1 (aggregated version)
128133
predictive_sums = results.predictive_state_probabilities.sum(dim="states")
129134
assert np.allclose(predictive_sums.values, 1.0, atol=1e-10)
130135

136+
# Verify probabilities sum to 1 (full version)
137+
predictive_posterior_sums = results.predictive_posterior.sum(dim="state_bins")
138+
assert np.allclose(predictive_posterior_sums.values, 1.0, atol=1e-10)
139+
131140
def test_return_log_likelihood_string(self, simple_fitted_detector):
132141
"""Test return_outputs='log_likelihood' returns log likelihoods."""
133142
detector, spike_times, time, position = simple_fitted_detector
@@ -173,6 +182,7 @@ def test_return_all_string(self, simple_fitted_detector):
173182
assert "causal_posterior" in results
174183
assert "causal_state_probabilities" in results
175184
assert "predictive_state_probabilities" in results
185+
assert "predictive_posterior" in results
176186
assert "log_likelihood" in results
177187
assert "marginal_log_likelihoods" in results.attrs
178188

@@ -195,6 +205,7 @@ def test_return_multiple_outputs_list(self, simple_fitted_detector):
195205
assert "causal_posterior" in results
196206
assert "causal_state_probabilities" in results
197207
assert "predictive_state_probabilities" in results
208+
assert "predictive_posterior" in results
198209

199210
# Should NOT have log_likelihood
200211
assert "log_likelihood" not in results

0 commit comments

Comments
 (0)