Skip to content

Commit 7703e37

Browse files
authored
71 task write trails utility functions valid ink trajectory updates (#101)
* added seciton of code to calculate ink_points * move valid_ink_trajectory into lineSegment and update impors for models.py and trails_utils.py * moved unit test and reformatted * mypy error * ruff reformat
1 parent 704fce8 commit 7703e37

4 files changed

Lines changed: 164 additions & 138 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
import datetime
5-
import typing
5+
from typing import Callable, List, Optional, Tuple
66

77
import numpy as np
88
import pandas as pd
@@ -95,7 +95,7 @@ def all(cls) -> set[str]:
9595
@classmethod
9696
def get_extractors(
9797
cls, spiral: Drawing, reference_spiral: np.ndarray
98-
) -> dict[str, typing.Callable[[], dict[str, float]]]:
98+
) -> dict[str, Callable[[], dict[str, float]]]:
9999
"""Get all feature extractors with appropriate inputs.
100100
101101
Args:
@@ -189,6 +189,7 @@ class LineSegment:
189189
points: pd.DataFrame
190190
is_error: bool
191191
line_number: int
192+
ink_points: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
192193

193194
ink_time: float = 0.0
194195
think_time: float = 0.0
@@ -200,8 +201,62 @@ class LineSegment:
200201
smoothness: float = 0.0
201202
hesitation_count: int = 0
202203
hesitation_duration: float = 0.0
203-
velocities: typing.List[float] = dataclasses.field(default_factory=list)
204-
accelerations: typing.List[float] = dataclasses.field(default_factory=list)
204+
velocities: List[float] = dataclasses.field(default_factory=list)
205+
accelerations: List[float] = dataclasses.field(default_factory=list)
206+
207+
def valid_ink_trajectory(
208+
self,
209+
start_circle: CircleTarget,
210+
end_circle: CircleTarget,
211+
) -> Tuple[Optional[int], Optional[int]]:
212+
"""Determine whether an ink trajectory exists from a start to end circle.
213+
214+
An "ink trajectory" is defined as the first contiguous sequence of
215+
points that:
216+
1. Begins **after** the pen leaves the start circle, and
217+
2. Ends when the pen first enters the end circle.
218+
219+
The function scans point-by-point in order. The ink start index is the
220+
first point whose (x, y) location is *outside* the start circle. The
221+
ink end index is the first subsequent point whose (x, y) location falls
222+
*inside* the end circle. If either of these conditions never occurs,
223+
the trajectory is considered invalid. If a valid trajectory is found,
224+
the ink_points attribute is updated to contain only the points within
225+
this trajectory.
226+
227+
Args:
228+
points: DataFrame of points with 'x' and 'y' columns.
229+
start_circle: CircleTarget representing the start circle.
230+
end_circle: CircleTarget representing the end circle.
231+
232+
Returns:
233+
Tuple of (ink_start_idx: int, ink_end_idx: int) if valid
234+
trajectory exists, else (None, None).
235+
"""
236+
ink_start_idx = None
237+
ink_end_idx = None
238+
239+
for idx, row in self.points.iterrows():
240+
if (
241+
not start_circle.contains_point(row["x"], row["y"])
242+
and ink_start_idx is None
243+
):
244+
ink_start_idx = idx
245+
246+
if ink_start_idx is not None and end_circle.contains_point(
247+
row["x"], row["y"]
248+
):
249+
ink_end_idx = idx
250+
break
251+
252+
if (
253+
ink_start_idx is not None
254+
and ink_end_idx is not None
255+
and ink_end_idx > ink_start_idx
256+
):
257+
self.ink_points = self.points.iloc[ink_start_idx : ink_end_idx + 1].copy()
258+
259+
return ink_start_idx, ink_end_idx
205260

206261
def calculate_path_optimality(
207262
self,

src/graphomotor/utils/trails_utils.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,12 @@
11
"""Utility functions for trails management."""
22

3-
from typing import Dict, List, Optional, Tuple
3+
from typing import Dict, List
44

55
import pandas as pd
66

77
from graphomotor.core import models
88

99

10-
def valid_ink_trajectory(
11-
points: pd.DataFrame,
12-
start_circle: models.CircleTarget,
13-
end_circle: models.CircleTarget,
14-
) -> Tuple[Optional[int], Optional[int]]:
15-
"""Determine whether an ink trajectory exists from a start circle to an end circle.
16-
17-
An "ink trajectory" is defined as the first contiguous sequence of points that:
18-
1. Begins **after** the pen leaves the start circle, and
19-
2. Ends when the pen first enters the end circle.
20-
21-
The function scans point-by-point in order. The ink start index is the first point
22-
whose (x, y) location is *outside* the start circle. The ink end index is the first
23-
subsequent point whose (x, y) location falls *inside* the end circle. If either of
24-
these conditions never occurs, the trajectory is considered invalid.
25-
26-
Args:
27-
points: DataFrame of points with 'x' and 'y' columns.
28-
start_circle: CircleTarget representing the start circle.
29-
end_circle: CircleTarget representing the end circle.
30-
31-
Returns:
32-
Tuple of (ink_start_idx: int, ink_end_idx: int) if valid trajectory exists,
33-
else (None, None).
34-
"""
35-
ink_start_idx = None
36-
ink_end_idx = None
37-
38-
for idx, row in points.iterrows():
39-
if (
40-
not start_circle.contains_point(row["x"], row["y"])
41-
and ink_start_idx is None
42-
):
43-
ink_start_idx = idx
44-
45-
if ink_start_idx is not None and end_circle.contains_point(row["x"], row["y"]):
46-
ink_end_idx = idx
47-
break
48-
49-
return ink_start_idx, ink_end_idx
50-
51-
5210
def segment_lines(
5311
trail_data: pd.DataFrame,
5412
trail_id: str,

tests/unit/test_models.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test cases for the Spiral model."""
22

33
import datetime
4+
from typing import Dict, cast
45

56
import pandas as pd
67
import pytest
@@ -135,3 +136,106 @@ def test_path_optimality_non_positive_distance() -> None:
135136
segment.calculate_path_optimality(start, end)
136137

137138
assert segment.path_optimality == 0.0
139+
140+
141+
@pytest.mark.parametrize(
142+
"points_data,start_params,end_params,expected_start,expected_end,test_id",
143+
[
144+
(
145+
{"x": [0, 1, 2, 3, 4, 5], "y": [0, 1, 2, 3, 4, 5]},
146+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
147+
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
148+
1,
149+
5,
150+
"valid_trajectory",
151+
),
152+
(
153+
{"x": [0.1, 0.2, 0.3], "y": [0.1, 0.2, 0.3]},
154+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
155+
{"order": 2, "label": "B", "center_x": 10, "center_y": 10, "radius": 1.0},
156+
None,
157+
None,
158+
"no_exit_from_start",
159+
),
160+
(
161+
{"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
162+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
163+
{"order": 2, "label": "B", "center_x": 10, "center_y": 10, "radius": 0.5},
164+
1,
165+
None,
166+
"never_reaches_end",
167+
),
168+
(
169+
{"x": [], "y": []},
170+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
171+
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 1.0},
172+
None,
173+
None,
174+
"empty_dataframe",
175+
),
176+
(
177+
{"x": [0.1], "y": [0.1]},
178+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
179+
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 1.0},
180+
None,
181+
None,
182+
"single_point_in_start",
183+
),
184+
(
185+
{"x": [2], "y": [2]},
186+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
187+
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
188+
0,
189+
None,
190+
"single_point_outside_start",
191+
),
192+
(
193+
{"x": [0, 2.5, 5], "y": [0, 2.5, 5]},
194+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
195+
{"order": 2, "label": "B", "center_x": 2.5, "center_y": 2.5, "radius": 1.0},
196+
1,
197+
1,
198+
"immediate_transition",
199+
),
200+
(
201+
{"x": [3, 4, 5], "y": [3, 4, 5]},
202+
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
203+
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
204+
0,
205+
2,
206+
"first_point_outside_start",
207+
),
208+
],
209+
ids=lambda x: x if isinstance(x, str) else "",
210+
)
211+
def test_valid_ink_trajectory(
212+
points_data: Dict[str, list[float]],
213+
start_params: dict[str, int | str | float],
214+
end_params: dict[str, int | str | float],
215+
expected_start: int | None,
216+
expected_end: int | None,
217+
test_id: str,
218+
) -> None:
219+
"""Test valid_ink_trajectory method with various point configurations.
220+
221+
Tests behavior with different circle boundaries.
222+
"""
223+
points_df = pd.DataFrame(points_data)
224+
225+
start_circle = models.CircleTarget(**start_params) # type: ignore[arg-type]
226+
end_circle = models.CircleTarget(**end_params) # type: ignore[arg-type]
227+
228+
line_segment = models.LineSegment(
229+
start_label=cast(str, start_params["label"]),
230+
end_label=cast(str, end_params["label"]),
231+
points=points_df,
232+
is_error=False,
233+
line_number=1,
234+
)
235+
236+
result_start, result_end = line_segment.valid_ink_trajectory(
237+
start_circle, end_circle
238+
)
239+
240+
assert result_start == expected_start, f"Start index mismatch for {test_id}"
241+
assert result_end == expected_end, f"End index mismatch for {test_id}"

tests/unit/test_trails_utils.py

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -9,97 +9,6 @@
99
from graphomotor.utils import trails_utils
1010

1111

12-
@pytest.mark.parametrize(
13-
"points_data,start_params,end_params,expected_start,expected_end,test_id",
14-
[
15-
(
16-
{"x": [0, 1, 2, 3, 4, 5], "y": [0, 1, 2, 3, 4, 5]},
17-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
18-
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
19-
1,
20-
5,
21-
"valid_trajectory",
22-
),
23-
(
24-
{"x": [0.1, 0.2, 0.3], "y": [0.1, 0.2, 0.3]},
25-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
26-
{"order": 2, "label": "B", "center_x": 10, "center_y": 10, "radius": 1.0},
27-
None,
28-
None,
29-
"no_exit_from_start",
30-
),
31-
(
32-
{"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
33-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
34-
{"order": 2, "label": "B", "center_x": 10, "center_y": 10, "radius": 0.5},
35-
1,
36-
None,
37-
"never_reaches_end",
38-
),
39-
(
40-
{"x": [], "y": []},
41-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
42-
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 1.0},
43-
None,
44-
None,
45-
"empty_dataframe",
46-
),
47-
(
48-
{"x": [0.1], "y": [0.1]},
49-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
50-
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 1.0},
51-
None,
52-
None,
53-
"single_point_in_start",
54-
),
55-
(
56-
{"x": [2], "y": [2]},
57-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
58-
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
59-
0,
60-
None,
61-
"single_point_outside_start",
62-
),
63-
(
64-
{"x": [0, 2.5, 5], "y": [0, 2.5, 5]},
65-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
66-
{"order": 2, "label": "B", "center_x": 2.5, "center_y": 2.5, "radius": 1.0},
67-
1,
68-
1,
69-
"immediate_transition",
70-
),
71-
(
72-
{"x": [3, 4, 5], "y": [3, 4, 5]},
73-
{"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
74-
{"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
75-
0,
76-
2,
77-
"first_point_outside_start",
78-
),
79-
],
80-
ids=lambda x: x if isinstance(x, str) else "",
81-
)
82-
def test_valid_ink_trajectory_scenarios(
83-
points_data: Dict[str, list],
84-
start_params: Dict,
85-
end_params: Dict,
86-
expected_start: int,
87-
expected_end: int,
88-
test_id: str,
89-
) -> None:
90-
"""Test various trajectory scenarios between start and end circles."""
91-
points = pd.DataFrame(points_data)
92-
start_circle = models.CircleTarget(**start_params)
93-
end_circle = models.CircleTarget(**end_params)
94-
95-
ink_start, ink_end = trails_utils.valid_ink_trajectory(
96-
points, start_circle, end_circle
97-
)
98-
99-
assert ink_start == expected_start, f"Failed on {test_id}: ink_start"
100-
assert ink_end == expected_end, f"Failed on {test_id}: ink_end"
101-
102-
10312
@pytest.fixture
10413
def circles() -> Dict[str, Dict[str, models.CircleTarget]]:
10514
"""Return a minimal set of CircleTarget dictionaries for each trail."""

0 commit comments

Comments
 (0)