Skip to content

Commit 6eb8dcd

Browse files
committed
Add flexible return_outputs parameter to predict methods
Introduces a new string/list/set-based 'return_outputs' parameter to ClusterlessDetector and SortedSpikesDetector predict() methods, allowing users to specify which optional outputs (filter, predictive, log_likelihood, all) are returned. Deprecates old boolean flags, adds normalization logic, updates docstrings, and provides comprehensive tests for new and legacy interfaces.
1 parent 33d70ee commit 6eb8dcd

File tree

2 files changed

+602
-15
lines changed

2 files changed

+602
-15
lines changed

src/non_local_detector/models/base.py

Lines changed: 252 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,69 @@
6161
"block_size": 10_000,
6262
}
6363

64+
# Valid options for return_outputs parameter
65+
VALID_OUTPUTS: set[str] = {"filter", "predictive", "log_likelihood", "all"}
66+
67+
# Mapping of single string options to sets of outputs
68+
OUTPUT_INCLUDES: dict[str, set[str]] = {
69+
"filter": {"filter"},
70+
"predictive": {"predictive"},
71+
"log_likelihood": {"log_likelihood"},
72+
"all": {"filter", "predictive", "log_likelihood"},
73+
}
74+
75+
76+
def _normalize_return_outputs(
77+
return_outputs: str | list[str] | set[str] | None,
78+
) -> set[str]:
79+
"""Convert return_outputs to canonical set of output names.
80+
81+
Parameters
82+
----------
83+
return_outputs : str, list of str, set of str, or None
84+
Controls which optional outputs are included.
85+
86+
Returns
87+
-------
88+
set of str
89+
Normalized set containing any of: 'filter', 'predictive', 'log_likelihood'
90+
91+
Raises
92+
------
93+
ValueError
94+
If return_outputs contains invalid option names.
95+
TypeError
96+
If return_outputs is not str, list, set, or None.
97+
"""
98+
if return_outputs is None:
99+
return set()
100+
101+
if isinstance(return_outputs, str):
102+
if return_outputs not in VALID_OUTPUTS:
103+
raise ValueError(
104+
f"Invalid return_outputs='{return_outputs}'. "
105+
f"Must be one of: {sorted(VALID_OUTPUTS)}"
106+
)
107+
return OUTPUT_INCLUDES.get(return_outputs, {return_outputs})
108+
109+
if isinstance(return_outputs, (list, set)):
110+
outputs_set = set(return_outputs)
111+
invalid = outputs_set - VALID_OUTPUTS
112+
if invalid:
113+
raise ValueError(
114+
f"Invalid outputs: {sorted(invalid)}. "
115+
f"Valid options are: {sorted(VALID_OUTPUTS)}"
116+
)
117+
# Expand 'all' if present
118+
if "all" in outputs_set:
119+
return OUTPUT_INCLUDES["all"]
120+
return outputs_set
121+
122+
raise TypeError(
123+
f"return_outputs must be str, list of str, set of str, or None. "
124+
f"Got {type(return_outputs).__name__}"
125+
)
126+
64127

65128
class _DetectorBase(BaseEstimator):
66129
"""Base class for detector objects."""
@@ -1485,6 +1548,7 @@ def _convert_results_to_xarray(
14851548
log_likelihood: np.ndarray | None = None,
14861549
causal_posterior: np.ndarray | None = None,
14871550
causal_state_probabilities: np.ndarray | None = None,
1551+
predictive_state_probabilities: np.ndarray | None = None,
14881552
) -> xr.Dataset:
14891553
"""
14901554
Convert the results to an xarray Dataset.
@@ -1505,6 +1569,8 @@ def _convert_results_to_xarray(
15051569
Causal (filtered) posterior probabilities, by default None.
15061570
causal_state_probabilities : np.ndarray, optional, shape (n_time, n_states)
15071571
Causal state probabilities, by default None.
1572+
predictive_state_probabilities : np.ndarray, optional, shape (n_time, n_states)
1573+
One-step-ahead predicted state probabilities, by default None.
15081574
15091575
Returns
15101576
-------
@@ -1626,6 +1692,12 @@ def _convert_results_to_xarray(
16261692
causal_state_probabilities,
16271693
)
16281694

1695+
if predictive_state_probabilities is not None:
1696+
data_vars["predictive_state_probabilities"] = (
1697+
("time", "states"),
1698+
predictive_state_probabilities,
1699+
)
1700+
16291701
# Create Dataset with MultiIndex coordinates
16301702
results = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
16311703

@@ -2109,8 +2181,9 @@ def predict(
21092181
discrete_transition_covariate_data: pd.DataFrame | dict | None = None,
21102182
cache_likelihood: bool = False,
21112183
n_chunks: int = 1,
2112-
save_log_likelihood_to_results: bool = False,
2113-
save_causal_posterior_to_results: bool = False,
2184+
return_outputs: str | list[str] | set[str] | None = None,
2185+
save_log_likelihood_to_results: bool | None = None,
2186+
save_causal_posterior_to_results: bool | None = None,
21142187
) -> xr.Dataset:
21152188
"""
21162189
Predict the posterior probabilities for the given data.
@@ -2135,15 +2208,82 @@ def predict(
21352208
If True, log likelihoods are cached instead of recomputed for each chunk, by default True
21362209
n_chunks : int, optional
21372210
Splits data into chunks for processing, by default 1
2211+
return_outputs : str, list of str, set of str, or None, optional
2212+
Controls which optional outputs are returned.
2213+
2214+
Options:
2215+
- None: smoother only (default, minimal memory)
2216+
- 'filter': filtered (causal) posterior and state probabilities
2217+
- 'predictive': one-step-ahead predictive state distributions
2218+
- 'log_likelihood': per-timepoint log likelihoods
2219+
- 'all': all outputs above
2220+
- List/set: e.g., ['filter', 'predictive'] for multiple outputs
2221+
2222+
The smoother (acausal_posterior, acausal_state_probabilities) and
2223+
marginal_log_likelihood are ALWAYS included.
2224+
2225+
When to use each output:
2226+
- 'filter': Online/causal decoding, debugging forward pass
2227+
- 'predictive': Model evaluation, predictive checks
2228+
- 'log_likelihood': Diagnostics, per-timepoint metrics, model comparison
2229+
2230+
Memory warning: 'log_likelihood' and 'filter' can be very large
2231+
(~400 GB for 1M timepoints × 100k spatial bins). Only request
2232+
what you need for your analysis.
21382233
save_log_likelihood_to_results : bool, optional
2139-
Whether to save the log likelihood to the results, by default False.
2234+
DEPRECATED. Use return_outputs='log_likelihood' instead.
2235+
Whether to save the log likelihood to the results, by default None.
21402236
save_causal_posterior_to_results : bool, optional
2141-
Whether to save the causal (filtered) posterior to the results, by default False.
2237+
DEPRECATED. Use return_outputs='filter' instead.
2238+
Whether to save the causal (filtered) posterior to the results, by default None.
21422239
21432240
Returns
21442241
-------
21452242
xr.Dataset
2146-
Predicted posterior probabilities.
2243+
Dataset containing decoded posterior distributions.
2244+
2245+
Always included:
2246+
- acausal_posterior : (n_time, n_state_bins)
2247+
Smoothed posterior over state bins
2248+
- acausal_state_probabilities : (n_time, n_states)
2249+
Smoothed discrete state probabilities
2250+
- marginal_log_likelihoods : float (in attrs)
2251+
Total log evidence for the model
2252+
2253+
Conditionally included based on return_outputs:
2254+
- causal_posterior : (n_time, n_state_bins) - if 'filter'
2255+
Filtered (forward-only) posterior over state bins
2256+
- causal_state_probabilities : (n_time, n_states) - if 'filter'
2257+
Filtered discrete state probabilities
2258+
- predictive_state_probabilities : (n_time, n_states) - if 'predictive'
2259+
One-step-ahead predictive distributions
2260+
- log_likelihood : (n_time, n_state_bins) - if 'log_likelihood'
2261+
Per-timepoint observation log likelihoods
2262+
2263+
Examples
2264+
--------
2265+
Get only smoother (default, minimal memory):
2266+
2267+
>>> results = model.predict(spike_times, time)
2268+
>>> results.acausal_posterior.shape
2269+
(10000, 50000)
2270+
2271+
Include filtered posterior for online decoding:
2272+
2273+
>>> results = model.predict(spike_times, time, return_outputs='filter')
2274+
>>> results.causal_posterior.shape
2275+
(10000, 50000)
2276+
2277+
Get multiple outputs for analysis:
2278+
2279+
>>> results = model.predict(
2280+
... spike_times, time,
2281+
... return_outputs=['filter', 'predictive']
2282+
... )
2283+
2284+
Get everything for debugging:
2285+
2286+
>>> results = model.predict(spike_times, time, return_outputs='all')
21472287
"""
21482288
if position is not None:
21492289
position = position[:, np.newaxis] if position.ndim == 1 else position
@@ -2158,6 +2298,41 @@ def predict(
21582298
f"Length of is_missing must match length of time. Time is n_samples: {len(time)}"
21592299
)
21602300

2301+
# Handle deprecated boolean flags
2302+
import warnings
2303+
2304+
if (
2305+
save_log_likelihood_to_results is not None
2306+
or save_causal_posterior_to_results is not None
2307+
):
2308+
warnings.warn(
2309+
"save_log_likelihood_to_results and save_causal_posterior_to_results "
2310+
"are deprecated. Use return_outputs parameter instead.",
2311+
DeprecationWarning,
2312+
stacklevel=2,
2313+
)
2314+
2315+
# Convert old flags to new format
2316+
outputs_from_flags = set()
2317+
if save_log_likelihood_to_results:
2318+
outputs_from_flags.add("log_likelihood")
2319+
if save_causal_posterior_to_results:
2320+
outputs_from_flags.add("filter")
2321+
2322+
if return_outputs is not None:
2323+
raise ValueError(
2324+
"Cannot specify both return_outputs and deprecated "
2325+
"save_*_to_results flags. Use return_outputs only."
2326+
)
2327+
return_outputs = outputs_from_flags if outputs_from_flags else None
2328+
2329+
# Normalize return_outputs to canonical set
2330+
requested_outputs = _normalize_return_outputs(return_outputs)
2331+
2332+
# Automatically enable caching if log_likelihood is requested
2333+
if "log_likelihood" in requested_outputs and not cache_likelihood:
2334+
cache_likelihood = True
2335+
21612336
if discrete_transition_covariate_data is not None:
21622337
if self.discrete_transition_coefficients_ is None:
21632338
raise ValueError(
@@ -2173,7 +2348,7 @@ def predict(
21732348
acausal_state_probabilities,
21742349
marginal_log_likelihood,
21752350
causal_state_probabilities,
2176-
_,
2351+
predictive_state_probabilities,
21772352
log_likelihood,
21782353
causal_posterior,
21792354
) = self._predict(
@@ -2194,12 +2369,19 @@ def predict(
21942369
acausal_posterior,
21952370
acausal_state_probabilities,
21962371
marginal_log_likelihood,
2197-
log_likelihood if save_log_likelihood_to_results else None,
2372+
log_likelihood=(
2373+
log_likelihood if "log_likelihood" in requested_outputs else None
2374+
),
21982375
causal_posterior=(
2199-
causal_posterior if save_causal_posterior_to_results else None
2376+
causal_posterior if "filter" in requested_outputs else None
22002377
),
22012378
causal_state_probabilities=(
2202-
causal_state_probabilities if save_causal_posterior_to_results else None
2379+
causal_state_probabilities if "filter" in requested_outputs else None
2380+
),
2381+
predictive_state_probabilities=(
2382+
predictive_state_probabilities
2383+
if "predictive" in requested_outputs
2384+
else None
22032385
),
22042386
)
22052387

@@ -2742,7 +2924,9 @@ def predict(
27422924
discrete_transition_covariate_data: pd.DataFrame | dict | None = None,
27432925
cache_likelihood: bool = False,
27442926
n_chunks: int = 1,
2745-
save_log_likelihood_to_results: bool = False,
2927+
return_outputs: str | list[str] | set[str] | None = None,
2928+
save_log_likelihood_to_results: bool | None = None,
2929+
save_causal_posterior_to_results: bool | None = None,
27462930
) -> xr.Dataset:
27472931
"""
27482932
Predict the posterior probabilities for the given data.
@@ -2765,8 +2949,13 @@ def predict(
27652949
Whether to cache the log likelihoods, by default False.
27662950
n_chunks : int, optional
27672951
Splits data into chunks for processing, by default 1
2952+
return_outputs : str, list of str, set of str, or None, optional
2953+
Controls which optional outputs are returned. See ClusterlessDetector.predict
2954+
for full documentation. By default None.
27682955
save_log_likelihood_to_results : bool, optional
2769-
Whether to save the log likelihood to the results, by default False.
2956+
DEPRECATED. Use return_outputs='log_likelihood' instead. By default None.
2957+
save_causal_posterior_to_results : bool, optional
2958+
DEPRECATED. Use return_outputs='filter' instead. By default None.
27702959
27712960
Returns
27722961
-------
@@ -2786,6 +2975,41 @@ def predict(
27862975
f"Length of is_missing must match length of time. Time is n_samples: {len(time)}"
27872976
)
27882977

2978+
# Handle deprecated boolean flags
2979+
import warnings
2980+
2981+
if (
2982+
save_log_likelihood_to_results is not None
2983+
or save_causal_posterior_to_results is not None
2984+
):
2985+
warnings.warn(
2986+
"save_log_likelihood_to_results and save_causal_posterior_to_results "
2987+
"are deprecated. Use return_outputs parameter instead.",
2988+
DeprecationWarning,
2989+
stacklevel=2,
2990+
)
2991+
2992+
# Convert old flags to new format
2993+
outputs_from_flags = set()
2994+
if save_log_likelihood_to_results:
2995+
outputs_from_flags.add("log_likelihood")
2996+
if save_causal_posterior_to_results:
2997+
outputs_from_flags.add("filter")
2998+
2999+
if return_outputs is not None:
3000+
raise ValueError(
3001+
"Cannot specify both return_outputs and deprecated "
3002+
"save_*_to_results flags. Use return_outputs only."
3003+
)
3004+
return_outputs = outputs_from_flags if outputs_from_flags else None
3005+
3006+
# Normalize return_outputs to canonical set
3007+
requested_outputs = _normalize_return_outputs(return_outputs)
3008+
3009+
# Automatically enable caching if log_likelihood is requested
3010+
if "log_likelihood" in requested_outputs and not cache_likelihood:
3011+
cache_likelihood = True
3012+
27893013
if discrete_transition_covariate_data is not None:
27903014
if self.discrete_transition_coefficients_ is None:
27913015
raise ValueError(
@@ -2801,10 +3025,10 @@ def predict(
28013025
acausal_posterior,
28023026
acausal_state_probabilities,
28033027
marginal_log_likelihood,
2804-
_,
2805-
_,
3028+
causal_state_probabilities,
3029+
predictive_state_probabilities,
28063030
log_likelihood,
2807-
_,
3031+
causal_posterior,
28083032
) = self._predict(
28093033
time=time,
28103034
log_likelihood_args=(
@@ -2822,7 +3046,20 @@ def predict(
28223046
acausal_posterior,
28233047
acausal_state_probabilities,
28243048
marginal_log_likelihood,
2825-
log_likelihood=log_likelihood if save_log_likelihood_to_results else None,
3049+
log_likelihood=(
3050+
log_likelihood if "log_likelihood" in requested_outputs else None
3051+
),
3052+
causal_posterior=(
3053+
causal_posterior if "filter" in requested_outputs else None
3054+
),
3055+
causal_state_probabilities=(
3056+
causal_state_probabilities if "filter" in requested_outputs else None
3057+
),
3058+
predictive_state_probabilities=(
3059+
predictive_state_probabilities
3060+
if "predictive" in requested_outputs
3061+
else None
3062+
),
28263063
)
28273064

28283065
def estimate_parameters(

0 commit comments

Comments
 (0)