Skip to content

Commit fe45c00

Browse files
committed
Refactor type annotations and return values for clarity
Updated type annotations to use Sequence and Mapping where appropriate, replacing list and dict. Refactored several functions to assign return values to variables before returning, improving code readability and type checking. No functional changes were made.
1 parent 393ca0f commit fe45c00

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

src/track_linearization/core.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterator
1+
from collections.abc import Iterator, Mapping, Sequence
22
from math import sqrt
33
from typing import Any
44

@@ -187,9 +187,10 @@ def project_points_to_segment(
187187

188188
np.clip(nx_param, 0.0, 1.0, out=nx_param)
189189

190-
return node1[np.newaxis, ...] + ( # type: ignore[no-any-return]
190+
result: np.ndarray = node1[np.newaxis, ...] + (
191191
nx_param[:, :, np.newaxis] * segment_diff[np.newaxis, ...]
192192
)
193+
return result
193194

194195

195196
def find_projected_point_distance(
@@ -209,11 +210,12 @@ def find_projected_point_distance(
209210
distances : np.ndarray, shape (n_time, n_segments)
210211
Euclidean distance from each point to its projection on each segment.
211212
"""
212-
return np.linalg.norm( # type: ignore[no-any-return]
213+
result: np.ndarray = np.linalg.norm(
213214
position[:, np.newaxis, :]
214215
- project_points_to_segment(track_segments, position),
215216
axis=2,
216217
)
218+
return result
217219

218220

219221
def find_nearest_segment(
@@ -235,7 +237,8 @@ def find_nearest_segment(
235237
Index of the nearest track segment for each time point.
236238
"""
237239
distance = find_projected_point_distance(track_segments, position)
238-
return np.argmin(distance, axis=1) # type: ignore[no-any-return]
240+
result: np.ndarray = np.argmin(distance, axis=1)
241+
return result
239242

240243

241244
def euclidean_distance_change(position: np.ndarray) -> np.ndarray:
@@ -329,11 +332,11 @@ def route_distance(
329332
start_node_ind = node_ind[n_original_nodes::2] # Corresponds to t_0_ nodes
330333
end_node_ind = node_ind[n_original_nodes + 1 :: 2] # Corresponds to t_1_ nodes
331334

332-
dist_matrix_slice = path_distance[start_node_ind][:, end_node_ind]
335+
dist_matrix_slice: np.ndarray = path_distance[start_node_ind][:, end_node_ind]
333336

334337
track_graph.remove_nodes_from(node_names) # Clean up graph
335338

336-
return dist_matrix_slice # type: ignore[no-any-return]
339+
return dist_matrix_slice
337340

338341

339342
def batch(n_samples: int, batch_size: int = 1) -> Iterator[range]:
@@ -462,9 +465,10 @@ def calculate_position_likelihood(
462465
projected_position_distance = find_projected_point_distance(
463466
track_segments, position
464467
)
465-
return np.exp(-0.5 * (projected_position_distance / sigma) ** 2) / ( # type: ignore[no-any-return]
468+
result: np.ndarray = np.exp(-0.5 * (projected_position_distance / sigma) ** 2) / (
466469
np.sqrt(2 * np.pi) * sigma
467470
)
471+
return result
468472

469473

470474
def normalize_to_probability(x: np.ndarray, axis: int = -1) -> np.ndarray:
@@ -484,7 +488,8 @@ def normalize_to_probability(x: np.ndarray, axis: int = -1) -> np.ndarray:
484488
If a sum along the axis is 0, the original values in that slice will
485489
result in NaNs or Infs after division.
486490
"""
487-
return x / x.sum(axis=axis, keepdims=True) # type: ignore[no-any-return]
491+
result: np.ndarray = x / x.sum(axis=axis, keepdims=True)
492+
return result
488493

489494

490495
def calculate_empirical_state_transition(
@@ -785,7 +790,7 @@ def batch_linear_distance(
785790

786791

787792
def _normalize_edge_spacing(
788-
edge_spacing: float | list[float], n_edges: int
793+
edge_spacing: float | Sequence[float], n_edges: int
789794
) -> list[float]:
790795
"""Convert edge_spacing to list format with validation.
791796
@@ -839,10 +844,10 @@ def _normalize_edge_spacing(
839844
def _apply_edge_map_to_positions(
840845
linear_position: np.ndarray,
841846
track_segment_id: np.ndarray,
842-
edge_map: dict[int, int],
847+
edge_map: Mapping[int, int | str],
843848
track_graph: Graph,
844849
edge_order: list[Edge],
845-
edge_spacing: float | list[float],
850+
edge_spacing: float | Sequence[float],
846851
) -> tuple[np.ndarray, np.ndarray]:
847852
"""Apply edge_map to adjust linear positions and segment IDs.
848853
@@ -939,7 +944,7 @@ def _calculate_linear_position(
939944
position: np.ndarray,
940945
track_segment_id: np.ndarray,
941946
edge_order: list[Edge],
942-
edge_spacing: float | list[float],
947+
edge_spacing: float | Sequence[float],
943948
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
944949
"""Determines the linear position given a 2D position and a track graph.
945950
@@ -1028,12 +1033,12 @@ def get_linearized_position(
10281033
position: np.ndarray,
10291034
track_graph: Graph,
10301035
edge_order: list[Edge] | None = None,
1031-
edge_spacing: float | list[float] = 0.0,
1036+
edge_spacing: float | Sequence[float] = 0.0,
10321037
use_HMM: bool = False,
10331038
route_euclidean_distance_scaling: float = 1.0,
10341039
sensor_std_dev: float = 5.0,
10351040
diagonal_bias: float = 0.1,
1036-
edge_map: dict[int, int] | None = None,
1041+
edge_map: Mapping[int, int | str] | None = None,
10371042
) -> pd.DataFrame:
10381043
"""Linearize 2D position based on graph representation of track.
10391044
@@ -1238,7 +1243,7 @@ def project_1d_to_2d(
12381243
linear_position: np.ndarray,
12391244
track_graph: nx.Graph,
12401245
edge_order: list[Edge],
1241-
edge_spacing: float | list[float] = 0.0,
1246+
edge_spacing: float | Sequence[float] = 0.0,
12421247
) -> np.ndarray:
12431248
"""
12441249
Map 1-D linear positions back to 2-D coordinates on the track graph.
@@ -1297,8 +1302,8 @@ def project_1d_to_2d(
12971302
u = np.array([node_pos[edge_order[i][0]] for i in idx])
12981303
v = np.array([node_pos[edge_order[i][1]] for i in idx])
12991304

1300-
coords = (1.0 - t[:, None]) * u + t[:, None] * v
1305+
coords: np.ndarray = (1.0 - t[:, None]) * u + t[:, None] * v
13011306

13021307
# propagate NaNs from the input
13031308
coords[nan_mask] = np.nan
1304-
return coords # type: ignore[no-any-return]
1309+
return coords

src/track_linearization/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def _plot_linear_segment(
267267
def plot_graph_as_1D(
268268
track_graph: nx.Graph,
269269
edge_order: list[Edge] | None = None,
270-
edge_spacing: float | list[float] = 0,
270+
edge_spacing: float | Sequence[float] = 0,
271271
ax: plt.Axes | None = None,
272272
axis: str = "x",
273273
other_axis_start: float = 0.0,
@@ -430,7 +430,8 @@ def _get_projected_track_position(
430430
track_segments = get_track_segments_from_graph(track_graph)
431431
projected_track_position = project_points_to_segment(track_segments, position)
432432
n_time = projected_track_position.shape[0]
433-
return projected_track_position[(np.arange(n_time), track_segment_id)] # type: ignore[no-any-return]
433+
result: np.ndarray = projected_track_position[(np.arange(n_time), track_segment_id)]
434+
return result
434435

435436

436437
def make_actual_vs_linearized_position_movie(

src/track_linearization/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ def get_projection_confidence(
163163
scale = 1.0 # All points on track
164164

165165
# Calculate confidence using Gaussian-like decay
166-
confidence = np.exp(-(distances**2) / (2 * scale**2))
166+
confidence: np.ndarray = np.exp(-(distances**2) / (2 * scale**2))
167167

168-
return confidence # type: ignore[no-any-return]
168+
return confidence
169169

170170

171171
def detect_linearization_outliers(

0 commit comments

Comments
 (0)