Skip to content

Commit 5ce6bab

Browse files
committed
Add and improve docstrings across core modules
Expanded and standardized docstrings for classes and functions in classifier, core, decoder, environments, discrete/continuous state transitions, and likelihood modules. Docstrings now include parameter and return value descriptions, improving code readability and API documentation for trajectory decoding and classification.
1 parent 151318c commit 5ce6bab

File tree

9 files changed

+340
-136
lines changed

9 files changed

+340
-136
lines changed

replay_trajectory_classification/classifier.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
"""State space models that classify trajectories as well as decode the
2-
trajectory from population spiking
1+
"""State space models that classify trajectories and decode from population spiking.
2+
3+
This module contains classes for trajectory classification using neural population
4+
activity from both sorted spikes and clusterless data.
35
"""
46

57
from __future__ import annotations
@@ -77,7 +79,24 @@
7779

7880

7981
class _ClassifierBase(BaseEstimator, ABC):
80-
"""Base class for classifier objects."""
82+
"""Base class for trajectory classifier objects.
83+
84+
Parameters
85+
----------
86+
environments : list[Environment], optional
87+
List of spatial environments to classify trajectories within.
88+
observation_models : ObservationModel, optional
89+
Observation models for neural data, by default None.
90+
continuous_transition_types : list[list[...]], optional
91+
Continuous state transition models for each state.
92+
discrete_transition_type : DiagonalDiscrete | RandomDiscrete | UniformDiscrete | UserDefinedDiscrete, optional
93+
Discrete state transition model, by default DiagonalDiscrete(0.968).
94+
initial_conditions_type : UniformInitialConditions | UniformOneEnvironmentInitialConditions, optional
95+
Initial conditions model, by default UniformInitialConditions().
96+
infer_track_interior : bool, optional
97+
Whether to infer track interior from position data, by default True.
98+
99+
"""
81100

82101
def __init__(
83102
self,
@@ -121,13 +140,15 @@ def __init__(
121140
def fit_environments(
122141
self, position: NDArray[np.float64], environment_labels: Optional[NDArray[np.int64]] = None
123142
) -> None:
124-
"""Fits the Environment class on the position data to get information about the spatial environment.
143+
"""Fit the Environment class on position data to extract spatial information.
125144
126145
Parameters
127146
----------
128147
position : NDArray[np.float64], shape (n_time, n_position_dims)
148+
Position coordinates over time.
129149
environment_labels : NDArray[np.int64], optional, shape (n_time,)
130-
Labels for each time points about which environment it corresponds to, by default None
150+
Labels indicating which environment each time point corresponds to,
151+
by default None.
131152
132153
"""
133154
for environment in self.environments:
@@ -144,7 +165,7 @@ def fit_environments(
144165
)
145166

146167
def fit_initial_conditions(self):
147-
"""Constructs the initial probability for the state and each spatial bin."""
168+
"""Construct the initial probability distribution for states and spatial bins."""
148169
logger.info("Fitting initial conditions...")
149170
environment_names_to_state = [
150171
obs.environment_name for obs in self.observation_models
@@ -178,7 +199,7 @@ def fit_continuous_state_transition(
178199
encoding_group_labels: Optional[NDArray[np.int64]] = None,
179200
environment_labels: Optional[NDArray[np.int64]] = None,
180201
) -> None:
181-
"""Constructs the transition matrices for the continuous states.
202+
"""Construct the transition matrices for the continuous states.
182203
183204
Parameters
184205
----------
@@ -240,7 +261,7 @@ def fit_continuous_state_transition(
240261
] = st
241262

242263
def fit_discrete_state_transition(self):
243-
"""Constructs the transition matrix for the discrete states."""
264+
"""Construct the transition matrix for the discrete states."""
244265
logger.info("Fitting discrete state transition")
245266
n_states = len(self.continuous_transition_types)
246267
self.discrete_state_transition_ = (

replay_trajectory_classification/continuous_state_transitions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ def _normalize_row_probability(x: NDArray[np.float64]) -> NDArray[np.float64]:
1818
Parameters
1919
----------
2020
x : NDArray[np.float64], shape (n_rows, n_cols)
21+
Input matrix to normalize.
2122
2223
Returns
2324
-------
2425
normalized_x : NDArray[np.float64], shape (n_rows, n_cols)
26+
Row-normalized matrix where each row sums to 1.
2527
2628
"""
2729
x /= x.sum(axis=1, keepdims=True)
@@ -32,19 +34,19 @@ def _normalize_row_probability(x: NDArray[np.float64]) -> NDArray[np.float64]:
3234
def estimate_movement_var(
3335
position: NDArray[np.float64], sampling_frequency: int = 1
3436
) -> NDArray[np.float64]:
35-
"""Estimates the movement variance based on position.
37+
"""Estimate the movement variance based on position data.
3638
3739
Parameters
3840
----------
39-
position : NDArray[np.float64], shape (n_time, n_position_dim)
40-
Position of the animal
41+
position : NDArray[np.float64], shape (n_time, n_position_dims)
42+
Position coordinates of the animal over time.
4143
sampling_frequency : int, optional
42-
Number of samples per second.
44+
Number of samples per second, by default 1.
4345
4446
Returns
4547
-------
46-
movement_var : NDArray[np.float64], shape (n_position_dim,)
47-
Variance of the movement.
48+
movement_var : NDArray[np.float64], shape (n_position_dims,)
49+
Covariance matrix of movement scaled by sampling frequency.
4850
4951
"""
5052
position = atleast_2d(position)

replay_trajectory_classification/core.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
"""Core algorithms for decoding."""
1+
"""Core algorithms for decoding.
2+
3+
This module contains the fundamental algorithms for Bayesian decoding including
4+
causal and acausal state estimation and classification.
5+
"""
26

37
from __future__ import annotations
48

@@ -46,18 +50,17 @@ def get_centers(bin_edges: NDArray[np.float64]) -> NDArray[np.float64]:
4650

4751
@njit(parallel=True, error_model="numpy")
4852
def normalize_to_probability(distribution: NDArray[np.float64]) -> NDArray[np.float64]:
49-
"""Ensure the distribution integrates to 1 so that it is a probability
50-
distribution.
53+
"""Ensure the distribution integrates to 1 so that it is a probability distribution.
5154
5255
Parameters
5356
----------
5457
distribution : NDArray[np.float64]
55-
Probability distribution to normalize
58+
Probability distribution values to normalize.
5659
5760
Returns
5861
-------
59-
NDArray[np.float64]
60-
Normalized probability distribution that sums to 1
62+
normalized_distribution : NDArray[np.float64]
63+
Normalized probability distribution that sums to 1.
6164
6265
"""
6366
return distribution / np.nansum(distribution)
@@ -73,13 +76,18 @@ def _causal_decode(
7376
Parameters
7477
----------
7578
initial_conditions : NDArray[np.float64], shape (n_bins,)
79+
Initial probability distribution over state bins.
7680
state_transition : NDArray[np.float64], shape (n_bins, n_bins)
81+
Transition probability matrix between state bins.
7782
likelihood : NDArray[np.float64], shape (n_time, n_bins)
83+
Likelihood values for each time point and state bin.
7884
7985
Returns
8086
-------
8187
posterior : NDArray[np.float64], shape (n_time, n_bins)
88+
Posterior probability distribution over time and state bins.
8289
log_data_likelihood : float
90+
Log-likelihood of the observed data.
8391
8492
"""
8593

@@ -109,11 +117,14 @@ def _acausal_decode(
109117
Parameters
110118
----------
111119
causal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
120+
Causal (forward-pass) posterior probabilities.
112121
state_transition : NDArray[np.float64], shape (n_bins, n_bins)
122+
Transition probability matrix between state bins.
113123
114-
Return
115-
------
124+
Returns
125+
-------
116126
acausal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
127+
Acausal (forward-backward) posterior probabilities.
117128
118129
"""
119130
acausal_posterior = np.zeros_like(causal_posterior)
@@ -150,14 +161,20 @@ def _causal_classify(
150161
Parameters
151162
----------
152163
initial_conditions : NDArray[np.float64], shape (n_states, n_bins, 1)
164+
Initial probability distribution for each state and spatial bin.
153165
continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
166+
Continuous-space transition probabilities between states and bins.
154167
discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
168+
Discrete state transition probabilities.
155169
likelihood : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
170+
Likelihood values for each time point, state, and spatial bin.
156171
157172
Returns
158173
-------
159174
causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
175+
Causal posterior probabilities over time, states, and spatial bins.
160176
log_data_likelihood : float
177+
Log-likelihood of the observed data.
161178
162179
"""
163180
n_time, n_states, n_bins, _ = likelihood.shape
@@ -196,12 +213,16 @@ def _acausal_classify(
196213
Parameters
197214
----------
198215
causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
216+
Causal (forward-pass) posterior probabilities.
199217
continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
218+
Continuous-space transition probabilities between states and bins.
200219
discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
220+
Discrete state transition probabilities.
201221
202-
Return
203-
------
222+
Returns
223+
-------
204224
acausal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
225+
Acausal (forward-backward) posterior probabilities.
205226
206227
"""
207228
acausal_posterior = np.zeros_like(causal_posterior)
@@ -242,11 +263,14 @@ def scaled_likelihood(log_likelihood: NDArray[np.float64], axis: int = 1) -> NDA
242263
Parameters
243264
----------
244265
log_likelihood : NDArray[np.float64], shape (n_time, n_bins)
245-
axis : int
266+
Log-likelihood values to be scaled.
267+
axis : int, optional
268+
Axis along which to find the maximum, by default 1.
246269
247270
Returns
248271
-------
249272
scaled_log_likelihood : NDArray[np.float64], shape (n_time, n_bins)
273+
Likelihood values scaled so that the maximum is 1.
250274
251275
"""
252276
max_log_likelihood = np.nanmax(log_likelihood, axis=axis, keepdims=True)
@@ -269,11 +293,14 @@ def mask(value: NDArray[np.float64], is_track_interior: NDArray[np.bool_]) -> ND
269293
Parameters
270294
----------
271295
value : NDArray[np.float64], shape (..., n_bins)
296+
Input values to be masked.
272297
is_track_interior : NDArray[np.bool_], shape (n_bins,)
298+
Boolean array indicating which bins are part of the track interior.
273299
274300
Returns
275301
-------
276-
masked_value : NDArray[np.float64]
302+
masked_value : NDArray[np.float64], shape (..., n_bins)
303+
Values with non-track bins set to NaN.
277304
278305
"""
279306
try:
@@ -288,24 +315,27 @@ def check_converged(
288315
previous_log_likelihood: NDArray[np.float64],
289316
tolerance: float = 1e-4,
290317
) -> Tuple[bool, bool]:
291-
"""We have converged if the slope of the log-likelihood function falls below 'tolerance',
318+
"""Check if log-likelihood has converged.
292319
293-
i.e., |f(t) - f(t-1)| / avg < tolerance,
294-
where avg = (|f(t)| + |f(t-1)|)/2 and f(t) is log lik at iteration t.
320+
We have converged if the slope of the log-likelihood function falls below
321+
'tolerance', i.e., |f(t) - f(t-1)| / avg < tolerance, where
322+
avg = (|f(t)| + |f(t-1)|)/2 and f(t) is log likelihood at iteration t.
295323
296324
Parameters
297325
----------
298326
log_likelihood : NDArray[np.float64]
299-
Current log likelihood
327+
Current log likelihood values.
300328
previous_log_likelihood : NDArray[np.float64]
301-
Previous log likelihood
329+
Previous log likelihood values.
302330
tolerance : float, optional
303-
threshold for similarity, by default 1e-4
331+
Threshold for similarity, by default 1e-4.
304332
305333
Returns
306334
-------
307335
is_converged : bool
336+
Whether the log-likelihood has converged.
308337
is_increasing : bool
338+
Whether the log-likelihood is increasing.
309339
310340
"""
311341
delta_log_likelihood = abs(log_likelihood - previous_log_likelihood)
@@ -334,13 +364,18 @@ def _causal_decode_gpu(
334364
Parameters
335365
----------
336366
initial_conditions : NDArray[np.float64], shape (n_bins,)
367+
Initial probability distribution over state bins.
337368
state_transition : NDArray[np.float64], shape (n_bins, n_bins)
369+
Transition probability matrix between state bins.
338370
likelihood : NDArray[np.float64], shape (n_time, n_bins)
371+
Likelihood values for each time point and state bin.
339372
340373
Returns
341374
-------
342375
posterior : NDArray[np.float64], shape (n_time, n_bins)
376+
Posterior probability distribution over time and state bins.
343377
log_data_likelihood : float
378+
Log-likelihood of the observed data.
344379
345380
"""
346381

@@ -373,11 +408,14 @@ def _acausal_decode_gpu(
373408
Parameters
374409
----------
375410
causal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
411+
Causal (forward-pass) posterior probabilities.
376412
state_transition : NDArray[np.float64], shape (n_bins, n_bins)
413+
Transition probability matrix between state bins.
377414
378-
Return
379-
------
415+
Returns
416+
-------
380417
acausal_posterior : NDArray[np.float64], shape (n_time, n_bins, 1)
418+
Acausal (forward-backward) posterior probabilities.
381419
382420
"""
383421
causal_posterior = cp.asarray(causal_posterior, dtype=cp.float32)
@@ -414,14 +452,20 @@ def _causal_classify_gpu(
414452
Parameters
415453
----------
416454
initial_conditions : NDArray[np.float64], shape (n_states, n_bins, 1)
455+
Initial probability distribution for each state and spatial bin.
417456
continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
457+
Continuous-space transition probabilities between states and bins.
418458
discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
459+
Discrete state transition probabilities.
419460
likelihood : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
461+
Likelihood values for each time point, state, and spatial bin.
420462
421463
Returns
422464
-------
423465
causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
466+
Causal posterior probabilities over time, states, and spatial bins.
424467
log_data_likelihood : float
468+
Log-likelihood of the observed data.
425469
426470
"""
427471
initial_conditions = cp.asarray(initial_conditions, dtype=cp.float32)
@@ -467,12 +511,16 @@ def _acausal_classify_gpu(
467511
Parameters
468512
----------
469513
causal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
514+
Causal (forward-pass) posterior probabilities.
470515
continuous_state_transition : NDArray[np.float64], shape (n_states, n_states, n_bins, n_bins)
516+
Continuous-space transition probabilities between states and bins.
471517
discrete_state_transition : NDArray[np.float64], shape (n_states, n_states)
518+
Discrete state transition probabilities.
472519
473-
Return
474-
------
520+
Returns
521+
-------
475522
acausal_posterior : NDArray[np.float64], shape (n_time, n_states, n_bins, 1)
523+
Acausal (forward-backward) posterior probabilities.
476524
477525
"""
478526
causal_posterior = cp.asarray(causal_posterior, dtype=cp.float32)

0 commit comments

Comments
 (0)