Skip to content
Open
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
26e4a8f
Update module docstring
edeno May 2, 2025
7132d50
Add some helper functions
edeno May 5, 2025
d4f00e8
Update docstring
edeno May 5, 2025
dbc3eac
Add typing, update docstring for get_n_bins
edeno May 5, 2025
c666560
Rename `get_grid` to `make_grid`
edeno May 5, 2025
e73734b
Rename `get_track_interior` to `infer_track_interior`
edeno May 5, 2025
ce23e6c
boundary bins False by default
edeno May 5, 2025
91fe471
Rename function to private to avoid conflict
edeno May 5, 2025
7fc2cee
Handle bin size sequence
edeno May 5, 2025
05381ec
Move get_track_boundary
edeno May 5, 2025
4f51089
Remove unneeded parameter
edeno May 5, 2025
4967979
Alias fit_place_grid to fit
edeno May 5, 2025
d61f966
Add gap closing as a parameter
edeno May 5, 2025
748260b
Turn on boundary for now
edeno May 5, 2025
a8c1603
Refactor track_graphDD to track_graph_nd
edeno May 5, 2025
9635575
Indicate functions private
edeno May 5, 2025
d893050
Refactor into more private functions. Return place_bin_centers graph
edeno May 5, 2025
f39c29d
Update plotting
edeno May 5, 2025
5252262
Rearange helper function location
edeno May 5, 2025
6ac617f
Add ability to calculate linear position on environment
edeno May 5, 2025
03e2476
Add method to get fitted track graph
edeno May 5, 2025
1e2838e
Fix call
edeno May 5, 2025
bf6d6af
More private functions and docstring cleanup
edeno May 5, 2025
44db5e7
Add bin information to nd_graph
edeno May 7, 2025
aec9d52
Use function from track_linearization
edeno May 7, 2025
4d6484d
Shorten name
edeno May 7, 2025
b7f9b43
Add bin ind to 1D graph
edeno May 7, 2025
173af25
Add exterior bins to graph
edeno May 7, 2025
e1f38db
Use bin ind attribute for distance
edeno May 7, 2025
38df09b
Use helper function for distance
edeno May 7, 2025
d9714ad
Fix leftover track_graphDD references
edeno May 7, 2025
68c8e85
Make distance_between_bins a cached property
edeno May 7, 2025
098c43b
Fix syntax
edeno May 8, 2025
a809b21
More efficient and handle 1D/nD case `get_bin_ind`
edeno May 8, 2025
d734685
Fix arguments
edeno May 8, 2025
51e6c9a
Update environment.py
edeno May 8, 2025
8fe2072
Fix names
edeno May 8, 2025
1773f1e
Ignore division by zero
edeno May 8, 2025
a81d0a0
Calculate rw movement variance function
edeno May 8, 2025
96c01df
Fix name
edeno May 8, 2025
992c176
Add functions for bin coordinates and serialization
edeno May 8, 2025
165de52
Remove unused attributes
edeno May 8, 2025
9da2da8
Set index name
edeno May 8, 2025
9543853
Simpler names for `load_env`, `save_env`, `fit_place_grid`
edeno May 8, 2025
1dc3a08
Remove 1D dataframes and track graphs
edeno May 8, 2025
d9cc95d
Use indexing
edeno May 8, 2025
f18ca13
Fix names and minor errors
edeno May 8, 2025
0d2a0ad
Fix typing and docstring
edeno May 8, 2025
9e2cd36
Handle n-D
edeno May 8, 2025
172b56b
Fix docstring
edeno May 8, 2025
5afe05b
Fix criterion
edeno May 8, 2025
952b840
Simpler nd graph construction
edeno May 8, 2025
b410f0e
Make sure functions work with n-D
edeno May 8, 2025
eec7bbd
Only return the diagonal
edeno May 8, 2025
b45ec91
Don't add boundary bins
edeno May 8, 2025
b8024d8
Update so distance between each bin is added
edeno May 8, 2025
777ef61
Make functions private, remove unused functions
edeno May 8, 2025
7ba9f3f
Handle lists
edeno May 9, 2025
6223f3a
Add tests
edeno May 9, 2025
7e82fb1
Add more tests
edeno May 9, 2025
7b5f5ae
Remove unintended line
edeno May 9, 2025
a94df3f
Add more helper functions
edeno May 9, 2025
f7a2911
Cleanup tests
edeno May 9, 2025
d9a471d
Add diffusion code
edeno May 9, 2025
cc87a3b
Add notebook
edeno May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/non_local_detector/analysis/distance1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def _get_MAP_estimate_2d_position_edges(
try:
place_bin_center_2D_position = env.place_bin_center_2D_position_
except AttributeError:
place_bin_center_2D_position = np.asarray(
env.place_bin_centers_nodes_df_.loc[:, ["x_position", "y_position"]]
place_bin_center_2D_position = (
env.get_bin_center_dataframe().loc[:, ["pos_x", "pos_y"]].to_numpy()
)

mental_position_2d = place_bin_center_2D_position[map_position_ind]
Expand All @@ -94,7 +94,7 @@ def _get_MAP_estimate_2d_position_edges(
try:
edge_id = env.place_bin_center_ind_to_edge_id_
except AttributeError:
edge_id = np.asarray(env.place_bin_centers_nodes_df_.edge_id)
edge_id = env.get_bin_center_dataframe().edge_id.to_numpy()

track_segment_id = edge_id[map_position_ind]
mental_position_edges = np.asarray(list(track_graph.edges))[track_segment_id]
Expand Down Expand Up @@ -422,7 +422,7 @@ def get_ahead_behind_distance(

def get_map_speed(
posterior: xr.DataArray,
track_graph_with_bin_centers_edges: nx.Graph,
track_graph_bin_centers_edges: nx.Graph,
place_bin_center_ind_to_node: np.ndarray,
sampling_frequency: float = 500.0,
smooth_sigma: float = 0.0025,
Expand All @@ -432,7 +432,7 @@ def get_map_speed(
Parameters
----------
posterior : xr.DataArray
track_graph_with_bin_centers_edges : nx.Graph
track_graph_bin_centers_edges : nx.Graph
Track graph with bin centers as nodes and edges
place_bin_center_ind_to_node : np.ndarray
Mapping of place bin center index to node ID
Expand Down Expand Up @@ -460,7 +460,7 @@ def get_map_speed(
speed,
0,
nx.shortest_path_length(
track_graph_with_bin_centers_edges,
track_graph_bin_centers_edges,
source=node_ids[0],
target=node_ids[1],
weight="distance",
Expand All @@ -471,7 +471,7 @@ def get_map_speed(
speed,
-1,
nx.shortest_path_length(
track_graph_with_bin_centers_edges,
track_graph_bin_centers_edges,
source=node_ids[-2],
target=node_ids[-1],
weight="distance",
Expand All @@ -483,7 +483,7 @@ def get_map_speed(
for node1, node2 in zip(node_ids[:-2], node_ids[2:]):
speed.append(
nx.shortest_path_length(
track_graph_with_bin_centers_edges,
track_graph_bin_centers_edges,
source=node1,
target=node2,
weight="distance",
Expand All @@ -495,7 +495,7 @@ def get_map_speed(
speed,
0,
nx.shortest_path_length(
track_graph_with_bin_centers_edges,
track_graph_bin_centers_edges,
source=node_ids[0],
target=node_ids[1],
weight="distance",
Expand All @@ -506,7 +506,7 @@ def get_map_speed(
speed,
-1,
nx.shortest_path_length(
track_graph_with_bin_centers_edges,
track_graph_bin_centers_edges,
source=node_ids[-2],
target=node_ids[-1],
weight="distance",
Expand Down
52 changes: 42 additions & 10 deletions src/non_local_detector/continuous_state_transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def _normalize_row_probability(x: np.ndarray) -> np.ndarray:
"""
# Handle cases where the sum is zero to avoid division by zero -> NaN
row_sums = x.sum(axis=1, keepdims=True)
# Use np.errstate to temporarily ignore invalid division warnings
with np.errstate(invalid="ignore"):
# Use np.errstate to temporarily ignore invalid division/zero warnings
with np.errstate(invalid="ignore", divide="ignore"):
normalized_x = np.where(row_sums > 0, x / row_sums, 0.0)
# Ensure any remaining NaNs (though unlikely with the above) are zero
normalized_x[np.isnan(normalized_x)] = 0.0
Expand Down Expand Up @@ -187,8 +187,8 @@ def _handle_no_track_graph(self) -> np.ndarray:
else:
transition_matrix = (
multivariate_normal(mean=self.movement_mean, cov=self.movement_var)
.pdf(self.environment.distance_between_nodes_.flat)
.reshape(self.environment.distance_between_nodes_.shape)
.pdf(self.environment.distance_between_bins.flat)
.reshape(self.environment.distance_between_bins.shape)
)

if self.direction is not None:
Expand All @@ -198,14 +198,14 @@ def _handle_no_track_graph(self) -> np.ndarray:
}.get(self.direction.lower(), None)

centrality = nx.closeness_centrality(
self.environment.track_graphDD, distance="distance"
self.environment.track_graph_nd_, distance="distance"
)
center_node_id = list(centrality.keys())[
np.argmax(list(centrality.values()))
]
transition_matrix *= direction_func(
self.environment.distance_between_nodes_[:, [center_node_id]],
self.environment.distance_between_nodes_[[center_node_id]],
self.environment.distance_between_bins[:, [center_node_id]],
self.environment.distance_between_bins[[center_node_id]],
)

return transition_matrix
Expand All @@ -218,15 +218,15 @@ def _handle_with_track_graph(self) -> np.ndarray:
"Random walk with track graph is only implemented for 1D environments"
)

place_bin_center_ind_to_node = np.asarray(
self.environment.place_bin_centers_nodes_df_.node_id
place_bin_center_ind_to_node = (
self.environment.get_bin_center_dataframe().reset_index().node_id.to_numpy()
)
return _random_walk_on_track_graph(
self.environment.place_bin_centers_,
self.movement_mean,
self.movement_var,
place_bin_center_ind_to_node,
self.environment.distance_between_nodes_,
self.environment.distance_between_bins,
)


Expand Down Expand Up @@ -514,3 +514,35 @@ def make_state_transition(self, *args, **kwargs) -> np.ndarray:
state_transition_matrix : np.ndarray, shape (1, 1)
"""
return np.ones((1, 1))


def calculate_rw_movement_variance(speed: float, time_step: float) -> float:
"""Calculates the variance for a RandomWalk model based on speed.

Assumes the characteristic distance traveled in one time step
(speed * time_step) corresponds to the standard deviation of
displacement per dimension.

Parameters
----------
speed : float
Characteristic speed (e.g., in cm/s).
time_step : float
Duration of one time step. Units must be consistent
with speed's time unit (e.g., in 0.002 s).

Returns
-------
movement_variance : float
The calculated variance (movement_var) for the RandomWalk model
(e.g., in cm^2).
"""
if time_step <= 0:
raise ValueError("time_step must be positive.")

# Calculate characteristic distance
# (interpreted as standard deviation sigma)
sigma = speed * time_step

# Variance is sigma squared
return sigma**2
Loading