Skip to content

Commit 3a2be39

Browse files
committed
Refactor environment functions and type annotations
Moved NDArray import for consistency, updated type annotations for get_n_bins to return NDArray instead of int, and improved argument formatting for several functions to enhance readability and maintainability. Added error handling in fit_place_grid for missing position when track_graph is None.
1 parent 4d85c34 commit 3a2be39

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

replay_trajectory_classification/environments.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import matplotlib.pyplot as plt
1111
import networkx as nx
1212
import numpy as np
13-
from numpy.typing import NDArray
1413
import pandas as pd
1514
from numba import njit
15+
from numpy.typing import NDArray
1616
from scipy import ndimage
1717
from scipy.interpolate import interp1d
1818
from sklearn.neighbors import NearestNeighbors
@@ -85,7 +85,9 @@ def __eq__(self, other: object) -> bool:
8585
return self.environment_name == other.environment_name
8686

8787
def fit_place_grid(
88-
self, position: Optional[NDArray[np.float64]] = None, infer_track_interior: bool = True
88+
self,
89+
position: Optional[NDArray[np.float64]] = None,
90+
infer_track_interior: bool = True,
8991
):
9092
"""Fits a discrete grid of the spatial environment.
9193
@@ -102,6 +104,8 @@ def fit_place_grid(
102104
103105
"""
104106
if self.track_graph is None:
107+
if position is None:
108+
raise ValueError("Must provide position if no track graph given.")
105109
(
106110
self.edges_,
107111
self.place_bin_edges_,
@@ -217,7 +221,7 @@ def get_n_bins(
217221
position: NDArray[np.float64],
218222
bin_size: float = 2.5,
219223
position_range: Optional[list[NDArray[np.float64]]] = None,
220-
) -> int:
224+
) -> NDArray[np.int32]:
221225
"""Get number of bins need to span a range given a bin size.
222226
223227
Parameters
@@ -229,7 +233,7 @@ def get_n_bins(
229233
230234
Returns
231235
-------
232-
n_bins : int
236+
n_bins : NDArray[np.int32], shape (n_position_dims,)
233237
234238
"""
235239
if position_range is not None:
@@ -734,7 +738,9 @@ def get_track_grid(
734738

735739

736740
def get_track_boundary(
737-
is_track_interior: NDArray[np.bool_], n_position_dims: int = 2, connectivity: int = 1
741+
is_track_interior: NDArray[np.bool_],
742+
n_position_dims: int = 2,
743+
connectivity: int = 1,
738744
) -> NDArray[np.bool_]:
739745
"""Determines the boundary of the valid interior track bins. The boundary
740746
are not bins on the track but surround it.
@@ -798,7 +804,9 @@ def order_boundary(boundary: NDArray[np.float64]) -> NDArray[np.float64]:
798804

799805

800806
def get_track_boundary_points(
801-
is_track_interior: NDArray[np.bool_], edges: list[NDArray[np.float64]], connectivity: int = 1
807+
is_track_interior: NDArray[np.bool_],
808+
edges: list[NDArray[np.float64]],
809+
connectivity: int = 1,
802810
) -> NDArray[np.float64]:
803811
"""
804812
@@ -991,7 +999,9 @@ def diffuse_each_bin(
991999
return diffused_grid
9921000

9931001

994-
def get_bin_ind(sample: NDArray[np.float64], edges: list[NDArray[np.float64]]) -> NDArray[np.int64]:
1002+
def get_bin_ind(
1003+
sample: NDArray[np.float64], edges: list[NDArray[np.float64]]
1004+
) -> NDArray[np.int64]:
9951005
"""Figure out which bin a given sample falls into.
9961006
9971007
Extracted from np.histogramdd.

0 commit comments

Comments
 (0)