1010import matplotlib .pyplot as plt
1111import networkx as nx
1212import numpy as np
13- from numpy .typing import NDArray
1413import pandas as pd
1514from numba import njit
15+ from numpy .typing import NDArray
1616from scipy import ndimage
1717from scipy .interpolate import interp1d
1818from sklearn .neighbors import NearestNeighbors
@@ -85,7 +85,9 @@ def __eq__(self, other: object) -> bool:
8585 return self .environment_name == other .environment_name
8686
8787 def fit_place_grid (
88- self , position : Optional [NDArray [np .float64 ]] = None , infer_track_interior : bool = True
88+ self ,
89+ position : Optional [NDArray [np .float64 ]] = None ,
90+ infer_track_interior : bool = True ,
8991 ):
9092 """Fits a discrete grid of the spatial environment.
9193
@@ -102,6 +104,8 @@ def fit_place_grid(
102104
103105 """
104106 if self .track_graph is None :
107+ if position is None :
108+ raise ValueError ("Must provide position if no track graph given." )
105109 (
106110 self .edges_ ,
107111 self .place_bin_edges_ ,
@@ -217,7 +221,7 @@ def get_n_bins(
217221 position : NDArray [np .float64 ],
218222 bin_size : float = 2.5 ,
219223 position_range : Optional [list [NDArray [np .float64 ]]] = None ,
220- ) -> int :
224+ ) -> NDArray [ np . int32 ] :
221225 """Get number of bins need to span a range given a bin size.
222226
223227 Parameters
@@ -229,7 +233,7 @@ def get_n_bins(
229233
230234 Returns
231235 -------
232- n_bins : int
236+ n_bins : NDArray[np.int32], shape (n_position_dims,)
233237
234238 """
235239 if position_range is not None :
@@ -734,7 +738,9 @@ def get_track_grid(
734738
735739
736740def get_track_boundary (
737- is_track_interior : NDArray [np .bool_ ], n_position_dims : int = 2 , connectivity : int = 1
741+ is_track_interior : NDArray [np .bool_ ],
742+ n_position_dims : int = 2 ,
743+ connectivity : int = 1 ,
738744) -> NDArray [np .bool_ ]:
739745 """Determines the boundary of the valid interior track bins. The boundary
740746 are not bins on the track but surround it.
@@ -798,7 +804,9 @@ def order_boundary(boundary: NDArray[np.float64]) -> NDArray[np.float64]:
798804
799805
800806def get_track_boundary_points (
801- is_track_interior : NDArray [np .bool_ ], edges : list [NDArray [np .float64 ]], connectivity : int = 1
807+ is_track_interior : NDArray [np .bool_ ],
808+ edges : list [NDArray [np .float64 ]],
809+ connectivity : int = 1 ,
802810) -> NDArray [np .float64 ]:
803811 """
804812
@@ -991,7 +999,9 @@ def diffuse_each_bin(
991999 return diffused_grid
9921000
9931001
994- def get_bin_ind (sample : NDArray [np .float64 ], edges : list [NDArray [np .float64 ]]) -> NDArray [np .int64 ]:
1002+ def get_bin_ind (
1003+ sample : NDArray [np .float64 ], edges : list [NDArray [np .float64 ]]
1004+ ) -> NDArray [np .int64 ]:
9951005 """Figure out which bin a given sample falls into.
9961006
9971007 Extracted from np.histogramdd.
0 commit comments