Skip to content

Commit 4240878

Browse files
committed
Handle empty position arrays and add edge case tests
Fixed _calculate_linear_position to handle empty position arrays gracefully. Updated HMM tests to check for empty input handling. Added comprehensive edge case tests for batch utilities and track graph validation, including missing attributes, invalid types, and non-finite values.
1 parent 5683fe6 commit 4240878

File tree

4 files changed

+397
-7
lines changed

4 files changed

+397
-7
lines changed

src/track_linearization/core.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,10 +1017,14 @@ def _calculate_linear_position(
10171017
[track_graph.nodes[node]["pos"] for node in start_node_id]
10181018
)
10191019

1020-
linear_position = start_node_linear_position + (
1021-
np.linalg.norm(start_node_2D_position - projected_track_positions, axis=1)
1022-
)
1023-
linear_position[is_nan] = np.nan
1020+
# Handle empty position array edge case
1021+
if len(position) == 0:
1022+
linear_position = np.array([])
1023+
else:
1024+
linear_position = start_node_linear_position + (
1025+
np.linalg.norm(start_node_2D_position - projected_track_positions, axis=1)
1026+
)
1027+
linear_position[is_nan] = np.nan
10241028

10251029
return (
10261030
linear_position,
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
"""Tests for batch processing and utility functions."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from track_linearization import get_linearized_position, make_track_graph, project_1d_to_2d
7+
from track_linearization.core import (
8+
batch_linear_distance,
9+
route_distance_change,
10+
)
11+
12+
13+
class TestBatchLinearDistance:
14+
"""Test batch_linear_distance function."""
15+
16+
def test_batch_linear_distance_basic(self):
17+
"""Test basic batch linear distance calculation."""
18+
node_positions = [(0, 0), (10, 0), (20, 0), (30, 0)]
19+
edges = [(0, 1), (1, 2), (2, 3)]
20+
track_graph = make_track_graph(node_positions, edges)
21+
22+
# Project positions onto the track
23+
projected_positions = np.array([
24+
[5, 0], # Middle of edge 0
25+
[15, 0], # Middle of edge 1
26+
[25, 0], # Middle of edge 2
27+
])
28+
29+
edge_ids = [(0, 1), (1, 2), (2, 3)]
30+
linear_zero_node_id = 0
31+
32+
distances = batch_linear_distance(
33+
projected_track_positions=projected_positions,
34+
edge_ids=edge_ids,
35+
track_graph=track_graph,
36+
linear_zero_node_id=linear_zero_node_id,
37+
)
38+
39+
# Check distances are reasonable
40+
assert len(distances) == 3
41+
assert all(isinstance(d, (int, float)) for d in distances)
42+
assert distances[0] < distances[1] < distances[2] # Monotonically increasing
43+
44+
def test_batch_linear_distance_single_position(self):
45+
"""Test batch linear distance with single position."""
46+
node_positions = [(0, 0), (10, 0)]
47+
edges = [(0, 1)]
48+
track_graph = make_track_graph(node_positions, edges)
49+
50+
projected_positions = np.array([[5, 0]])
51+
edge_ids = [(0, 1)]
52+
53+
distances = batch_linear_distance(
54+
projected_track_positions=projected_positions,
55+
edge_ids=edge_ids,
56+
track_graph=track_graph,
57+
linear_zero_node_id=0,
58+
)
59+
60+
assert len(distances) == 1
61+
assert 0 < distances[0] < 10 # Should be between start and end
62+
63+
def test_batch_linear_distance_complex_track(self):
64+
"""Test batch linear distance on more complex track."""
65+
node_positions = [(0, 0), (10, 0), (10, 10), (0, 10)]
66+
edges = [(0, 1), (1, 2), (2, 3)]
67+
track_graph = make_track_graph(node_positions, edges)
68+
69+
projected_positions = np.array([[5, 0], [10, 5], [5, 10]])
70+
edge_ids = [(0, 1), (1, 2), (2, 3)]
71+
72+
distances = batch_linear_distance(
73+
projected_track_positions=projected_positions,
74+
edge_ids=edge_ids,
75+
track_graph=track_graph,
76+
linear_zero_node_id=0,
77+
)
78+
79+
# Distances should increase as we move along track
80+
assert len(distances) == 3
81+
assert distances[0] < distances[1] < distances[2]
82+
83+
84+
class TestRouteDistanceChange:
85+
"""Test route_distance_change function."""
86+
87+
def test_route_distance_change_basic(self):
88+
"""Test basic route distance change calculation."""
89+
node_positions = [(0, 0), (10, 0), (10, 10), (0, 10)]
90+
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
91+
track_graph = make_track_graph(node_positions, edges)
92+
93+
# Sequence of positions moving along the track
94+
position = np.array([
95+
[5, 0], # Edge 0
96+
[10, 5], # Edge 1
97+
[5, 10], # Edge 2
98+
])
99+
100+
distances = route_distance_change(position, track_graph)
101+
102+
# Check structure - returns (n_time, n_segments, n_segments)
103+
assert distances.shape == (3, 4, 4) # 3 time points, 4 segments
104+
# First time point should have all NaNs (no previous position)
105+
assert np.all(np.isnan(distances[0]))
106+
# Subsequent rows should have finite values
107+
assert np.all(np.isfinite(distances[1:]))
108+
109+
def test_route_distance_change_simple_track(self):
110+
"""Test route distance on simple two-point track."""
111+
node_positions = [(0, 0), (10, 0)]
112+
edges = [(0, 1)]
113+
track_graph = make_track_graph(node_positions, edges)
114+
115+
# Two positions on same segment
116+
position = np.array([[3, 0], [7, 0]])
117+
118+
distances = route_distance_change(position, track_graph)
119+
120+
# Should return (2, 1, 1) for 2 time points and 1 segment
121+
assert distances.shape == (2, 1, 1)
122+
# First row is NaN
123+
assert np.isnan(distances[0, 0, 0])
124+
# Second row should be finite
125+
assert np.isfinite(distances[1, 0, 0])
126+
127+
128+
class TestProject1dTo2d:
129+
"""Test project_1d_to_2d function and edge cases."""
130+
131+
def test_project_1d_to_2d_basic(self):
132+
"""Test basic 1D to 2D projection."""
133+
node_positions = [(0, 0), (10, 0), (20, 0)]
134+
edges = [(0, 1), (1, 2)]
135+
track_graph = make_track_graph(node_positions, edges)
136+
137+
# Linear positions along the track
138+
linear_positions = np.array([5.0, 15.0])
139+
140+
projected = project_1d_to_2d(
141+
linear_positions, track_graph, edge_order=edges, edge_spacing=0.0
142+
)
143+
144+
# Check shape
145+
assert projected.shape == (2, 2)
146+
# Check positions are on track
147+
assert np.allclose(projected[0], [5, 0])
148+
assert np.allclose(projected[1], [15, 0])
149+
150+
def test_project_1d_to_2d_with_spacing(self):
151+
"""Test 1D to 2D projection with edge spacing."""
152+
node_positions = [(0, 0), (10, 0), (20, 0)]
153+
edges = [(0, 1), (1, 2)]
154+
track_graph = make_track_graph(node_positions, edges)
155+
156+
# With 5 unit spacing between edges
157+
linear_positions = np.array([5.0, 17.0]) # Accounting for spacing
158+
159+
projected = project_1d_to_2d(
160+
linear_positions, track_graph, edge_order=edges, edge_spacing=5.0
161+
)
162+
163+
assert projected.shape == (2, 2)
164+
# First position on first segment
165+
assert np.allclose(projected[0], [5, 0])
166+
# Second position on second segment (accounting for spacing)
167+
assert np.allclose(projected[1], [12, 0])
168+
169+
def test_project_1d_to_2d_nan_handling(self):
170+
"""Test 1D to 2D projection with NaN values."""
171+
node_positions = [(0, 0), (10, 0)]
172+
edges = [(0, 1)]
173+
track_graph = make_track_graph(node_positions, edges)
174+
175+
# Include NaN in linear positions
176+
linear_positions = np.array([5.0, np.nan, 7.0])
177+
178+
projected = project_1d_to_2d(
179+
linear_positions, track_graph, edge_order=edges, edge_spacing=0.0
180+
)
181+
182+
# Check shape
183+
assert projected.shape == (3, 2)
184+
# First and third should be valid
185+
assert np.all(np.isfinite(projected[0]))
186+
assert np.all(np.isfinite(projected[2]))
187+
# Second should be NaN
188+
assert np.all(np.isnan(projected[1]))
189+
190+
def test_project_1d_to_2d_out_of_bounds(self):
191+
"""Test 1D to 2D projection with out-of-bounds positions."""
192+
node_positions = [(0, 0), (10, 0)]
193+
edges = [(0, 1)]
194+
track_graph = make_track_graph(node_positions, edges)
195+
196+
# Position beyond end of track
197+
linear_positions = np.array([15.0])
198+
199+
projected = project_1d_to_2d(
200+
linear_positions, track_graph, edge_order=edges, edge_spacing=0.0
201+
)
202+
203+
# Should still return something (clamped to end or NaN)
204+
assert projected.shape == (1, 2)
205+
206+
def test_project_1d_to_2d_roundtrip(self):
207+
"""Test that 2D -> 1D -> 2D roundtrip preserves positions."""
208+
node_positions = [(0, 0), (10, 0), (20, 0)]
209+
edges = [(0, 1), (1, 2)]
210+
track_graph = make_track_graph(node_positions, edges)
211+
212+
# Original 2D positions on the track
213+
position_2d = np.array([[5, 0], [15, 0]])
214+
215+
# Convert to 1D
216+
result = get_linearized_position(position_2d, track_graph, edge_order=edges)
217+
linear_pos = result["linear_position"].to_numpy()
218+
219+
# Convert back to 2D
220+
position_2d_reconstructed = project_1d_to_2d(
221+
linear_pos, track_graph, edge_order=edges, edge_spacing=0.0
222+
)
223+
224+
# Should approximately recover original positions
225+
assert np.allclose(position_2d, position_2d_reconstructed, atol=0.01)
226+
227+
def test_project_1d_to_2d_empty_array(self):
228+
"""Test 1D to 2D projection with empty array."""
229+
node_positions = [(0, 0), (10, 0)]
230+
edges = [(0, 1)]
231+
track_graph = make_track_graph(node_positions, edges)
232+
233+
linear_positions = np.array([])
234+
235+
projected = project_1d_to_2d(
236+
linear_positions, track_graph, edge_order=edges, edge_spacing=0.0
237+
)
238+
239+
# Should return empty array
240+
assert projected.shape[0] == 0
241+
# Empty array may not preserve 2D shape, which is acceptable
242+
assert len(projected.shape) <= 2
243+
244+
245+
if __name__ == "__main__":
246+
pytest.main([__file__, "-v"])

src/track_linearization/tests/test_hmm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,10 @@ def test_hmm_empty_positions(self):
217217

218218
position = np.empty((0, 2))
219219

220-
# Empty positions should be handled gracefully
221-
# Skip this test as it reveals an edge case bug in the implementation
222-
pytest.skip("Empty position arrays cause broadcast error - known edge case")
220+
result = get_linearized_position(position, track_graph, use_HMM=True)
221+
222+
# Should return empty dataframe
223+
assert len(result) == 0
223224

224225
def test_hmm_nan_positions(self):
225226
"""Test HMM handles NaN in position data."""

0 commit comments

Comments
 (0)