6161 "block_size" : 10_000 ,
6262}
6363
64+ # Valid options for return_outputs parameter
65+ VALID_OUTPUTS : set [str ] = {"filter" , "predictive" , "log_likelihood" , "all" }
66+
67+ # Mapping of single string options to sets of outputs
68+ OUTPUT_INCLUDES : dict [str , set [str ]] = {
69+ "filter" : {"filter" },
70+ "predictive" : {"predictive" },
71+ "log_likelihood" : {"log_likelihood" },
72+ "all" : {"filter" , "predictive" , "log_likelihood" },
73+ }
74+
75+
76+ def _normalize_return_outputs (
77+ return_outputs : str | list [str ] | set [str ] | None ,
78+ ) -> set [str ]:
79+ """Convert return_outputs to canonical set of output names.
80+
81+ Parameters
82+ ----------
83+ return_outputs : str, list of str, set of str, or None
84+ Controls which optional outputs are included.
85+
86+ Returns
87+ -------
88+ set of str
89+ Normalized set containing any of: 'filter', 'predictive', 'log_likelihood'
90+
91+ Raises
92+ ------
93+ ValueError
94+ If return_outputs contains invalid option names.
95+ TypeError
96+ If return_outputs is not str, list, set, or None.
97+ """
98+ if return_outputs is None :
99+ return set ()
100+
101+ if isinstance (return_outputs , str ):
102+ if return_outputs not in VALID_OUTPUTS :
103+ raise ValueError (
104+ f"Invalid return_outputs='{ return_outputs } '. "
105+ f"Must be one of: { sorted (VALID_OUTPUTS )} "
106+ )
107+ return OUTPUT_INCLUDES .get (return_outputs , {return_outputs })
108+
109+ if isinstance (return_outputs , (list , set )):
110+ outputs_set = set (return_outputs )
111+ invalid = outputs_set - VALID_OUTPUTS
112+ if invalid :
113+ raise ValueError (
114+ f"Invalid outputs: { sorted (invalid )} . "
115+ f"Valid options are: { sorted (VALID_OUTPUTS )} "
116+ )
117+ # Expand 'all' if present
118+ if "all" in outputs_set :
119+ return OUTPUT_INCLUDES ["all" ]
120+ return outputs_set
121+
122+ raise TypeError (
123+ f"return_outputs must be str, list of str, set of str, or None. "
124+ f"Got { type (return_outputs ).__name__ } "
125+ )
126+
64127
65128class _DetectorBase (BaseEstimator ):
66129 """Base class for detector objects."""
@@ -1485,6 +1548,7 @@ def _convert_results_to_xarray(
14851548 log_likelihood : np .ndarray | None = None ,
14861549 causal_posterior : np .ndarray | None = None ,
14871550 causal_state_probabilities : np .ndarray | None = None ,
1551+ predictive_state_probabilities : np .ndarray | None = None ,
14881552 ) -> xr .Dataset :
14891553 """
14901554 Convert the results to an xarray Dataset.
@@ -1505,6 +1569,8 @@ def _convert_results_to_xarray(
15051569 Causal (filtered) posterior probabilities, by default None.
15061570 causal_state_probabilities : np.ndarray, optional, shape (n_time, n_states)
15071571 Causal state probabilities, by default None.
1572+ predictive_state_probabilities : np.ndarray, optional, shape (n_time, n_states)
1573+ One-step-ahead predicted state probabilities, by default None.
15081574
15091575 Returns
15101576 -------
@@ -1626,6 +1692,12 @@ def _convert_results_to_xarray(
16261692 causal_state_probabilities ,
16271693 )
16281694
1695+ if predictive_state_probabilities is not None :
1696+ data_vars ["predictive_state_probabilities" ] = (
1697+ ("time" , "states" ),
1698+ predictive_state_probabilities ,
1699+ )
1700+
16291701 # Create Dataset with MultiIndex coordinates
16301702 results = xr .Dataset (data_vars = data_vars , coords = coords , attrs = attrs )
16311703
@@ -2109,8 +2181,9 @@ def predict(
21092181 discrete_transition_covariate_data : pd .DataFrame | dict | None = None ,
21102182 cache_likelihood : bool = False ,
21112183 n_chunks : int = 1 ,
2112- save_log_likelihood_to_results : bool = False ,
2113- save_causal_posterior_to_results : bool = False ,
2184+ return_outputs : str | list [str ] | set [str ] | None = None ,
2185+ save_log_likelihood_to_results : bool | None = None ,
2186+ save_causal_posterior_to_results : bool | None = None ,
21142187 ) -> xr .Dataset :
21152188 """
21162189 Predict the posterior probabilities for the given data.
@@ -2135,15 +2208,82 @@ def predict(
21352208 If True, log likelihoods are cached instead of recomputed for each chunk, by default True
21362209 n_chunks : int, optional
21372210 Splits data into chunks for processing, by default 1
2211+ return_outputs : str, list of str, set of str, or None, optional
2212+ Controls which optional outputs are returned.
2213+
2214+ Options:
2215+ - None: smoother only (default, minimal memory)
2216+ - 'filter': filtered (causal) posterior and state probabilities
2217+ - 'predictive': one-step-ahead predictive state distributions
2218+ - 'log_likelihood': per-timepoint log likelihoods
2219+ - 'all': all outputs above
2220+ - List/set: e.g., ['filter', 'predictive'] for multiple outputs
2221+
2222+ The smoother (acausal_posterior, acausal_state_probabilities) and
2223+ marginal_log_likelihood are ALWAYS included.
2224+
2225+ When to use each output:
2226+ - 'filter': Online/causal decoding, debugging forward pass
2227+ - 'predictive': Model evaluation, predictive checks
2228+ - 'log_likelihood': Diagnostics, per-timepoint metrics, model comparison
2229+
2230+ Memory warning: 'log_likelihood' and 'filter' can be very large
2231+ (~400 GB for 1M timepoints × 100k spatial bins). Only request
2232+ what you need for your analysis.
21382233 save_log_likelihood_to_results : bool, optional
2139- Whether to save the log likelihood to the results, by default False.
2234+ DEPRECATED. Use return_outputs='log_likelihood' instead.
2235+ Whether to save the log likelihood to the results, by default None.
21402236 save_causal_posterior_to_results : bool, optional
2141- Whether to save the causal (filtered) posterior to the results, by default False.
2237+ DEPRECATED. Use return_outputs='filter' instead.
2238+ Whether to save the causal (filtered) posterior to the results, by default None.
21422239
21432240 Returns
21442241 -------
21452242 xr.Dataset
2146- Predicted posterior probabilities.
2243+ Dataset containing decoded posterior distributions.
2244+
2245+ Always included:
2246+ - acausal_posterior : (n_time, n_state_bins)
2247+ Smoothed posterior over state bins
2248+ - acausal_state_probabilities : (n_time, n_states)
2249+ Smoothed discrete state probabilities
2250+ - marginal_log_likelihoods : float (in attrs)
2251+ Total log evidence for the model
2252+
2253+ Conditionally included based on return_outputs:
2254+ - causal_posterior : (n_time, n_state_bins) - if 'filter'
2255+ Filtered (forward-only) posterior over state bins
2256+ - causal_state_probabilities : (n_time, n_states) - if 'filter'
2257+ Filtered discrete state probabilities
2258+ - predictive_state_probabilities : (n_time, n_states) - if 'predictive'
2259+ One-step-ahead predictive distributions
2260+ - log_likelihood : (n_time, n_state_bins) - if 'log_likelihood'
2261+ Per-timepoint observation log likelihoods
2262+
2263+ Examples
2264+ --------
2265+ Get only smoother (default, minimal memory):
2266+
2267+ >>> results = model.predict(spike_times, time)
2268+ >>> results.acausal_posterior.shape
2269+ (10000, 50000)
2270+
2271+ Include filtered posterior for online decoding:
2272+
2273+ >>> results = model.predict(spike_times, time, return_outputs='filter')
2274+ >>> results.causal_posterior.shape
2275+ (10000, 50000)
2276+
2277+ Get multiple outputs for analysis:
2278+
2279+ >>> results = model.predict(
2280+ ... spike_times, time,
2281+ ... return_outputs=['filter', 'predictive']
2282+ ... )
2283+
2284+ Get everything for debugging:
2285+
2286+ >>> results = model.predict(spike_times, time, return_outputs='all')
21472287 """
21482288 if position is not None :
21492289 position = position [:, np .newaxis ] if position .ndim == 1 else position
@@ -2158,6 +2298,41 @@ def predict(
21582298 f"Length of is_missing must match length of time. Time is n_samples: { len (time )} "
21592299 )
21602300
2301+ # Handle deprecated boolean flags
2302+ import warnings
2303+
2304+ if (
2305+ save_log_likelihood_to_results is not None
2306+ or save_causal_posterior_to_results is not None
2307+ ):
2308+ warnings .warn (
2309+ "save_log_likelihood_to_results and save_causal_posterior_to_results "
2310+ "are deprecated. Use return_outputs parameter instead." ,
2311+ DeprecationWarning ,
2312+ stacklevel = 2 ,
2313+ )
2314+
2315+ # Convert old flags to new format
2316+ outputs_from_flags = set ()
2317+ if save_log_likelihood_to_results :
2318+ outputs_from_flags .add ("log_likelihood" )
2319+ if save_causal_posterior_to_results :
2320+ outputs_from_flags .add ("filter" )
2321+
2322+ if return_outputs is not None :
2323+ raise ValueError (
2324+ "Cannot specify both return_outputs and deprecated "
2325+ "save_*_to_results flags. Use return_outputs only."
2326+ )
2327+ return_outputs = outputs_from_flags if outputs_from_flags else None
2328+
2329+ # Normalize return_outputs to canonical set
2330+ requested_outputs = _normalize_return_outputs (return_outputs )
2331+
2332+ # Automatically enable caching if log_likelihood is requested
2333+ if "log_likelihood" in requested_outputs and not cache_likelihood :
2334+ cache_likelihood = True
2335+
21612336 if discrete_transition_covariate_data is not None :
21622337 if self .discrete_transition_coefficients_ is None :
21632338 raise ValueError (
@@ -2173,7 +2348,7 @@ def predict(
21732348 acausal_state_probabilities ,
21742349 marginal_log_likelihood ,
21752350 causal_state_probabilities ,
2176- _ ,
2351+ predictive_state_probabilities ,
21772352 log_likelihood ,
21782353 causal_posterior ,
21792354 ) = self ._predict (
@@ -2194,12 +2369,19 @@ def predict(
21942369 acausal_posterior ,
21952370 acausal_state_probabilities ,
21962371 marginal_log_likelihood ,
2197- log_likelihood if save_log_likelihood_to_results else None ,
2372+ log_likelihood = (
2373+ log_likelihood if "log_likelihood" in requested_outputs else None
2374+ ),
21982375 causal_posterior = (
2199- causal_posterior if save_causal_posterior_to_results else None
2376+ causal_posterior if "filter" in requested_outputs else None
22002377 ),
22012378 causal_state_probabilities = (
2202- causal_state_probabilities if save_causal_posterior_to_results else None
2379+ causal_state_probabilities if "filter" in requested_outputs else None
2380+ ),
2381+ predictive_state_probabilities = (
2382+ predictive_state_probabilities
2383+ if "predictive" in requested_outputs
2384+ else None
22032385 ),
22042386 )
22052387
@@ -2742,7 +2924,9 @@ def predict(
27422924 discrete_transition_covariate_data : pd .DataFrame | dict | None = None ,
27432925 cache_likelihood : bool = False ,
27442926 n_chunks : int = 1 ,
2745- save_log_likelihood_to_results : bool = False ,
2927+ return_outputs : str | list [str ] | set [str ] | None = None ,
2928+ save_log_likelihood_to_results : bool | None = None ,
2929+ save_causal_posterior_to_results : bool | None = None ,
27462930 ) -> xr .Dataset :
27472931 """
27482932 Predict the posterior probabilities for the given data.
@@ -2765,8 +2949,13 @@ def predict(
27652949 Whether to cache the log likelihoods, by default False.
27662950 n_chunks : int, optional
27672951 Splits data into chunks for processing, by default 1
2952+ return_outputs : str, list of str, set of str, or None, optional
2953+ Controls which optional outputs are returned. See ClusterlessDetector.predict
2954+ for full documentation. By default None.
27682955 save_log_likelihood_to_results : bool, optional
2769- Whether to save the log likelihood to the results, by default False.
2956+ DEPRECATED. Use return_outputs='log_likelihood' instead. By default None.
2957+ save_causal_posterior_to_results : bool, optional
2958+ DEPRECATED. Use return_outputs='filter' instead. By default None.
27702959
27712960 Returns
27722961 -------
@@ -2786,6 +2975,41 @@ def predict(
27862975 f"Length of is_missing must match length of time. Time is n_samples: { len (time )} "
27872976 )
27882977
2978+ # Handle deprecated boolean flags
2979+ import warnings
2980+
2981+ if (
2982+ save_log_likelihood_to_results is not None
2983+ or save_causal_posterior_to_results is not None
2984+ ):
2985+ warnings .warn (
2986+ "save_log_likelihood_to_results and save_causal_posterior_to_results "
2987+ "are deprecated. Use return_outputs parameter instead." ,
2988+ DeprecationWarning ,
2989+ stacklevel = 2 ,
2990+ )
2991+
2992+ # Convert old flags to new format
2993+ outputs_from_flags = set ()
2994+ if save_log_likelihood_to_results :
2995+ outputs_from_flags .add ("log_likelihood" )
2996+ if save_causal_posterior_to_results :
2997+ outputs_from_flags .add ("filter" )
2998+
2999+ if return_outputs is not None :
3000+ raise ValueError (
3001+ "Cannot specify both return_outputs and deprecated "
3002+ "save_*_to_results flags. Use return_outputs only."
3003+ )
3004+ return_outputs = outputs_from_flags if outputs_from_flags else None
3005+
3006+ # Normalize return_outputs to canonical set
3007+ requested_outputs = _normalize_return_outputs (return_outputs )
3008+
3009+ # Automatically enable caching if log_likelihood is requested
3010+ if "log_likelihood" in requested_outputs and not cache_likelihood :
3011+ cache_likelihood = True
3012+
27893013 if discrete_transition_covariate_data is not None :
27903014 if self .discrete_transition_coefficients_ is None :
27913015 raise ValueError (
@@ -2801,10 +3025,10 @@ def predict(
28013025 acausal_posterior ,
28023026 acausal_state_probabilities ,
28033027 marginal_log_likelihood ,
2804- _ ,
2805- _ ,
3028+ causal_state_probabilities ,
3029+ predictive_state_probabilities ,
28063030 log_likelihood ,
2807- _ ,
3031+ causal_posterior ,
28083032 ) = self ._predict (
28093033 time = time ,
28103034 log_likelihood_args = (
@@ -2822,7 +3046,20 @@ def predict(
28223046 acausal_posterior ,
28233047 acausal_state_probabilities ,
28243048 marginal_log_likelihood ,
2825- log_likelihood = log_likelihood if save_log_likelihood_to_results else None ,
3049+ log_likelihood = (
3050+ log_likelihood if "log_likelihood" in requested_outputs else None
3051+ ),
3052+ causal_posterior = (
3053+ causal_posterior if "filter" in requested_outputs else None
3054+ ),
3055+ causal_state_probabilities = (
3056+ causal_state_probabilities if "filter" in requested_outputs else None
3057+ ),
3058+ predictive_state_probabilities = (
3059+ predictive_state_probabilities
3060+ if "predictive" in requested_outputs
3061+ else None
3062+ ),
28263063 )
28273064
28283065 def estimate_parameters (
0 commit comments