1- """Core algorithms for decoding."""
1+ """Core algorithms for decoding.
2+
3+ This module contains the fundamental algorithms for Bayesian decoding including
4+ causal and acausal state estimation and classification.
5+ """
26
37from __future__ import annotations
48
@@ -46,18 +50,17 @@ def get_centers(bin_edges: NDArray[np.float64]) -> NDArray[np.float64]:
4650
4751@njit (parallel = True , error_model = "numpy" )
4852def normalize_to_probability (distribution : NDArray [np .float64 ]) -> NDArray [np .float64 ]:
49- """Ensure the distribution integrates to 1 so that it is a probability
50- distribution.
53+ """Ensure the distribution integrates to 1 so that it is a probability distribution.
5154
5255 Parameters
5356 ----------
5457 distribution : NDArray[np.float64]
55- Probability distribution to normalize
58+ Probability distribution values to normalize.
5659
5760 Returns
5861 -------
59- NDArray[np.float64]
60- Normalized probability distribution that sums to 1
62+ normalized_distribution : NDArray[np.float64]
63+ Normalized probability distribution that sums to 1.
6164
6265 """
6366 return distribution / np .nansum (distribution )
@@ -73,13 +76,18 @@ def _causal_decode(
7376 Parameters
7477 ----------
7578 initial_conditions : NDArray[np.float64], shape (n_bins,)
79+ Initial probability distribution over state bins.
7680 state_transition : NDArray[np.float64], shape (n_bins, n_bins)
81+ Transition probability matrix between state bins.
7782 likelihood : NDArray[np.float64], shape (n_time, n_bins)
83+ Likelihood values for each time point and state bin.
7884
7985 Returns
8086 -------
8187 posterior : NDArray[np.float64], shape (n_time, n_bins)
88+ Posterior probability distribution over time and state bins.
8289 log_data_likelihood : float
90+ Log-likelihood of the observed data.
8391
8492 """
8593
@@ -109,11 +117,14 @@ def _acausal_decode(
109117 Parameters
110118 ----------
111119 causal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
120+ Causal (forward-pass) posterior probabilities.
112121 state_transition : NDArray[np.float64], shape (n_bins, n_bins)
122+ Transition probability matrix between state bins.
113123
114- Return
115- ------
124+ Returns
125+ -------
116126 acausal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
127+ Acausal (forward-backward) posterior probabilities.
117128
118129 """
119130 acausal_posterior = np .zeros_like (causal_posterior )
@@ -150,14 +161,20 @@ def _causal_classify(
150161 Parameters
151162 ----------
152163 initial_conditions : NDArray[np.float64], shape (n_states, n_bins, 1)
164+ Initial probability distribution for each state and spatial bin.
153165 continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
166+ Continuous-space transition probabilities between states and bins.
154167 discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
168+ Discrete state transition probabilities.
155169 likelihood : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
170+ Likelihood values for each time point, state, and spatial bin.
156171
157172 Returns
158173 -------
159174 causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
175+ Causal posterior probabilities over time, states, and spatial bins.
160176 log_data_likelihood : float
177+ Log-likelihood of the observed data.
161178
162179 """
163180 n_time , n_states , n_bins , _ = likelihood .shape
@@ -196,12 +213,16 @@ def _acausal_classify(
196213 Parameters
197214 ----------
198215 causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
216+ Causal (forward-pass) posterior probabilities.
199217 continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
218+ Continuous-space transition probabilities between states and bins.
200219 discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
220+ Discrete state transition probabilities.
201221
202- Return
203- ------
222+ Returns
223+ -------
204224 acausal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
225+ Acausal (forward-backward) posterior probabilities.
205226
206227 """
207228 acausal_posterior = np .zeros_like (causal_posterior )
@@ -242,11 +263,14 @@ def scaled_likelihood(log_likelihood: NDArray[np.float64], axis: int = 1) -> NDA
242263 Parameters
243264 ----------
244265 log_likelihood : NDArray[np.float64], shape (n_time, n_bins)
245- axis : int
266+ Log-likelihood values to be scaled.
267+ axis : int, optional
268+ Axis along which to find the maximum, by default 1.
246269
247270 Returns
248271 -------
249272 scaled_log_likelihood : NDArray[np.float64], shape (n_time, n_bins)
273+ Likelihood values scaled so that the maximum is 1.
250274
251275 """
252276 max_log_likelihood = np .nanmax (log_likelihood , axis = axis , keepdims = True )
@@ -269,11 +293,14 @@ def mask(value: NDArray[np.float64], is_track_interior: NDArray[np.bool_]) -> ND
269293 Parameters
270294 ----------
271295 value : NDArray[np.float64], shape (..., n_bins)
296+ Input values to be masked.
272297 is_track_interior : NDArray[np.bool_], shape (n_bins,)
298+ Boolean array indicating which bins are part of the track interior.
273299
274300 Returns
275301 -------
276- masked_value : NDArray[np.float64]
302+ masked_value : NDArray[np.float64], shape (..., n_bins)
303+ Values with non-track bins set to NaN.
277304
278305 """
279306 try :
@@ -288,24 +315,27 @@ def check_converged(
288315 previous_log_likelihood : NDArray [np .float64 ],
289316 tolerance : float = 1e-4 ,
290317) -> Tuple [bool , bool ]:
291- """We have converged if the slope of the log-likelihood function falls below 'tolerance',
318+ """Check if log-likelihood has converged.
292319
293- i.e., |f(t) - f(t-1)| / avg < tolerance,
294- where avg = (|f(t)| + |f(t-1)|)/2 and f(t) is log lik at iteration t.
320+ We have converged if the slope of the log-likelihood function falls below
321+ 'tolerance', i.e., |f(t) - f(t-1)| / avg < tolerance, where
322+ avg = (|f(t)| + |f(t-1)|)/2 and f(t) is log likelihood at iteration t.
295323
296324 Parameters
297325 ----------
298326 log_likelihood : NDArray[np.float64]
299- Current log likelihood
327+ Current log likelihood values.
300328 previous_log_likelihood : NDArray[np.float64]
301- Previous log likelihood
329+ Previous log likelihood values.
302330 tolerance : float, optional
303- threshold for similarity, by default 1e-4
331+ Threshold for similarity, by default 1e-4.
304332
305333 Returns
306334 -------
307335 is_converged : bool
336+ Whether the log-likelihood has converged.
308337 is_increasing : bool
338+ Whether the log-likelihood is increasing.
309339
310340 """
311341 delta_log_likelihood = abs (log_likelihood - previous_log_likelihood )
@@ -334,13 +364,18 @@ def _causal_decode_gpu(
334364 Parameters
335365 ----------
336366 initial_conditions : NDArray[np.float64], shape (n_bins,)
367+ Initial probability distribution over state bins.
337368 state_transition : NDArray[np.float64], shape (n_bins, n_bins)
369+ Transition probability matrix between state bins.
338370 likelihood : NDArray[np.float64], shape (n_time, n_bins)
371+ Likelihood values for each time point and state bin.
339372
340373 Returns
341374 -------
342375 posterior : NDArray[np.float64], shape (n_time, n_bins)
376+ Posterior probability distribution over time and state bins.
343377 log_data_likelihood : float
378+ Log-likelihood of the observed data.
344379
345380 """
346381
@@ -373,11 +408,14 @@ def _acausal_decode_gpu(
373408 Parameters
374409 ----------
375410 causal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
411+ Causal (forward-pass) posterior probabilities.
376412 state_transition : NDArray[np.float64], shape (n_bins, n_bins)
413+ Transition probability matrix between state bins.
377414
378- Return
379- ------
415+ Returns
416+ -------
380417 acausal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
418+ Acausal (forward-backward) posterior probabilities.
381419
382420 """
383421 causal_posterior = cp .asarray (causal_posterior , dtype = cp .float32 )
@@ -414,14 +452,20 @@ def _causal_classify_gpu(
414452 Parameters
415453 ----------
416454 initial_conditions : NDArray[np.float64], shape (n_states, n_bins, 1)
455+ Initial probability distribution for each state and spatial bin.
417456 continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
457+ Continuous-space transition probabilities between states and bins.
418458 discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
459+ Discrete state transition probabilities.
419460 likelihood : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
461+ Likelihood values for each time point, state, and spatial bin.
420462
421463 Returns
422464 -------
423465 causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
466+ Causal posterior probabilities over time, states, and spatial bins.
424467 log_data_likelihood : float
468+ Log-likelihood of the observed data.
425469
426470 """
427471 initial_conditions = cp .asarray (initial_conditions , dtype = cp .float32 )
@@ -467,12 +511,16 @@ def _acausal_classify_gpu(
467511 Parameters
468512 ----------
469513 causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
514+ Causal (forward-pass) posterior probabilities.
470515 continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
516+ Continuous-space transition probabilities between states and bins.
471517 discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
518+ Discrete state transition probabilities.
472519
473- Return
474- ------
520+ Returns
521+ -------
475522 acausal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
523+ Acausal (forward-backward) posterior probabilities.
476524
477525 """
478526 causal_posterior = cp .asarray (causal_posterior , dtype = cp .float32 )
0 commit comments