Skip to content

Commit e6c7536

Browse files
authored
Merge branch 'main' into 96-task-write-trails-velocity-feature-functions-calculate_segment_velocity_metrics
2 parents a3a1918 + da2c270 commit e6c7536

8 files changed

Lines changed: 400 additions & 143 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
import datetime
5-
import typing
5+
from typing import Callable, List, Optional, Tuple
66

77
import numpy as np
88
import pandas as pd
@@ -95,7 +95,7 @@ def all(cls) -> set[str]:
9595
@classmethod
9696
def get_extractors(
9797
cls, spiral: Drawing, reference_spiral: np.ndarray
98-
) -> dict[str, typing.Callable[[], dict[str, float]]]:
98+
) -> dict[str, Callable[[], dict[str, float]]]:
9999
"""Get all feature extractors with appropriate inputs.
100100
101101
Args:
@@ -189,6 +189,7 @@ class LineSegment:
189189
points: pd.DataFrame
190190
is_error: bool
191191
line_number: int
192+
ink_points: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
192193

193194
ink_time: float = 0.0
194195
think_time: float = 0.0
@@ -200,8 +201,62 @@ class LineSegment:
200201
smoothness: float = 0.0
201202
hesitation_count: int = 0
202203
hesitation_duration: float = 0.0
203-
velocities: typing.List[float] = dataclasses.field(default_factory=list)
204-
accelerations: typing.List[float] = dataclasses.field(default_factory=list)
204+
velocities: List[float] = dataclasses.field(default_factory=list)
205+
accelerations: List[float] = dataclasses.field(default_factory=list)
206+
207+
def valid_ink_trajectory(
    self,
    start_circle: CircleTarget,
    end_circle: CircleTarget,
) -> Tuple[Optional[int], Optional[int]]:
    """Locate the ink trajectory running from a start circle to an end circle.

    An "ink trajectory" is the first contiguous sequence of points that:
    1. Begins **after** the pen leaves the start circle, and
    2. Ends when the pen first enters the end circle.

    The rows of ``self.points`` are scanned in order. The ink start is the
    first point whose (x, y) location is *outside* the start circle. The ink
    end is the first point, at or after the ink start, whose (x, y) location
    is *inside* the end circle.

    When both are found and the end row comes strictly after the start row,
    ``self.ink_points`` is replaced by a copy of the rows from the start row
    to the end row (inclusive). Otherwise ``self.ink_points`` is left as is.

    Args:
        start_circle: CircleTarget representing the start circle.
        end_circle: CircleTarget representing the end circle.

    Returns:
        Tuple ``(ink_start_idx, ink_end_idx)`` holding the index labels of
        ``self.points`` at which the trajectory starts and ends. Each entry
        is None when the corresponding condition never occurs, so the
        result may be ``(None, None)`` or ``(start, None)``.
    """
    ink_start_idx = None
    ink_end_idx = None
    # Row positions are tracked separately from the index labels: labels are
    # what the caller gets back, positions are what ``iloc`` needs. Mixing
    # the two breaks as soon as ``self.points`` carries a non-default index.
    start_pos = None
    end_pos = None

    for pos, (idx, row) in enumerate(self.points.iterrows()):
        if ink_start_idx is None and not start_circle.contains_point(
            row["x"], row["y"]
        ):
            ink_start_idx, start_pos = idx, pos

        if ink_start_idx is not None and end_circle.contains_point(
            row["x"], row["y"]
        ):
            ink_end_idx, end_pos = idx, pos
            break

    if start_pos is not None and end_pos is not None and end_pos > start_pos:
        # NOTE(review): the ``ink_points`` field is annotated as np.ndarray
        # but receives a DataFrame slice here - confirm the intended type.
        self.ink_points = self.points.iloc[start_pos : end_pos + 1].copy()

    return ink_start_idx, ink_end_idx
205260

206261
def calculate_path_optimality(
207262
self,

src/graphomotor/features/trails/drawing_metrics.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Feature extraction module for drawing error-based metrics in trails drawing data."""
22

3+
import numpy as np
4+
import pandas as pd
5+
36
from graphomotor.core import models
47

58

@@ -35,9 +38,54 @@ def percent_accurate_paths(drawing: models.Drawing) -> dict[str, float]:
3538
raise ValueError(
3639
"DataFrame must contain 'correct_path' and 'actual_path' columns."
3740
)
38-
3941
return {
4042
"percent_accurate_paths": (
4143
(drawing.data["correct_path"] == drawing.data["actual_path"]).mean() * 100
4244
)
4345
}
46+
47+
48+
def calculate_smoothness(points: pd.DataFrame) -> float:
    """Return the Root Mean Square (RMS) curvature of a drawn path.

    For every interior point the turning angle between the incoming and the
    outgoing step is divided by the mean length of those two steps, giving a
    curvature per unit arc length. The RMS of these values is returned, so
    lower values indicate smoother drawings, while sharp corners and noisy
    corrections raise the score. Interior points with a zero-length incoming
    or outgoing step are left out of the calculation.

    Args:
        points: DataFrame of drawing points with 'x' and 'y' columns.

    Returns:
        Smoothness metric as a float. 0.0 is returned when there are fewer
        than three points or when no interior point has two non-zero steps.
    """
    if len(points) < 3:
        return 0.0

    coords = points[["x", "y"]].to_numpy()

    # One displacement per consecutive pair of points; each interior point
    # pairs the step arriving at it with the step leaving it.
    steps = coords[1:] - coords[:-1]
    incoming, outgoing = steps[:-1], steps[1:]

    incoming_len = np.linalg.norm(incoming, axis=1)
    outgoing_len = np.linalg.norm(outgoing, axis=1)

    # A turning angle is undefined when either step has zero length.
    keep = (incoming_len > 0) & (outgoing_len > 0)
    if not keep.any():
        return 0.0

    incoming, outgoing = incoming[keep], outgoing[keep]
    incoming_len, outgoing_len = incoming_len[keep], outgoing_len[keep]

    dots = np.sum(incoming * outgoing, axis=1)
    # Clip guards arccos against round-off pushing the cosine past +/-1.
    cosines = np.clip(dots / (incoming_len * outgoing_len), -1.0, 1.0)
    turn_angles = np.arccos(cosines)

    local_length = (incoming_len + outgoing_len) / 2.0
    curvature = turn_angles / local_length

    return float(np.sqrt(np.mean(curvature**2)))
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Feature extraction module for time-based metrics in trails drawing data."""
2+
3+
from graphomotor.core import models
4+
5+
6+
def calculate_total_error_time(drawing: models.Drawing) -> dict[str, float]:
    """Calculate the total time spent making errors.

    A contiguous "error chunk" is any run of rows where
    ``drawing.data["error"] != "E0"``. A chunk starts at the midpoint between
    the timestamp of the last non-error row before it and the timestamp of
    its first error row, and ends at the midpoint between the timestamp of
    its last error row and that of the first non-error row after it. A chunk
    touching the first (or last) row of the data starts (or ends) at that
    row's own timestamp. The total error time is the sum of all chunk
    durations.

    Args:
        drawing: Drawing object whose ``data`` DataFrame has 'error' and
            'seconds' columns.

    Returns:
        Dictionary with the single key "total_error_time" holding the total
        time spent in error states, in the units of the 'seconds' column.
    """
    # Chunk boundaries found below are used as positions into the NumPy
    # ``seconds`` array, so the mask is given a positional 0..n-1 index
    # first. Without this, a DataFrame with a non-default index yields
    # labels that point at the wrong rows (or past the end of the array).
    mask = (drawing.data["error"] != "E0").reset_index(drop=True)
    if not mask.any():
        return {"total_error_time": 0.0}

    # +1 marks the first row of an error chunk, -1 the first row after one.
    error_change = mask.astype(int).diff()
    chunk_starts = error_change[error_change == 1].index.tolist()
    chunk_ends = error_change[error_change == -1].index.tolist()

    # diff() cannot see a chunk that is already open on the first row or
    # still open on the last row, so those boundaries are added by hand.
    if mask.iloc[0]:
        chunk_starts = [0, *chunk_starts]

    if mask.iloc[-1]:
        chunk_ends = [*chunk_ends, len(mask)]

    seconds = drawing.data["seconds"].to_numpy()
    total_error_time = 0.0

    for start_idx, end_idx in zip(chunk_starts, chunk_ends):
        start_time = (
            (seconds[start_idx - 1] + seconds[start_idx]) / 2
            if start_idx > 0
            else seconds[0]
        )

        end_time = (
            (seconds[end_idx - 1] + seconds[end_idx]) / 2
            if end_idx < len(seconds)
            else seconds[-1]
        )

        total_error_time += end_time - start_time

    return {"total_error_time": float(total_error_time)}

src/graphomotor/utils/trails_utils.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,12 @@
11
"""Utility functions for trails management."""
22

3-
from typing import Dict, List, Optional, Tuple
3+
from typing import Dict, List
44

55
import pandas as pd
66

77
from graphomotor.core import models
88

99

10-
def valid_ink_trajectory(
    points: pd.DataFrame,
    start_circle: models.CircleTarget,
    end_circle: models.CircleTarget,
) -> Tuple[Optional[int], Optional[int]]:
    """Find where an ink trajectory from a start circle to an end circle begins and ends.

    The rows of ``points`` are walked in order. The trajectory begins at the
    first point lying *outside* the start circle and ends at the first point,
    at or after that one, lying *inside* the end circle.

    Args:
        points: DataFrame of points with 'x' and 'y' columns.
        start_circle: CircleTarget representing the start circle.
        end_circle: CircleTarget representing the end circle.

    Returns:
        Tuple of (ink_start_idx, ink_end_idx) taken from the index of
        ``points``. An entry is None when its condition is never met.
    """
    exit_idx = None

    for label, point in points.iterrows():
        px, py = point["x"], point["y"]

        if exit_idx is None:
            # Still waiting for the pen to leave the start circle.
            if start_circle.contains_point(px, py):
                continue
            exit_idx = label

        # The pen has left the start circle; stop at the first hit on the end.
        if end_circle.contains_point(px, py):
            return exit_idx, label

    return exit_idx, None
50-
51-
5210
def segment_lines(
5311
trail_data: pd.DataFrame,
5412
trail_id: str,

tests/unit/test_models.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test cases for the Spiral model."""
22

33
import datetime
4+
from typing import Dict, cast
45

56
import pandas as pd
67
import pytest
@@ -137,6 +138,108 @@ def test_path_optimality_non_positive_distance() -> None:
137138
assert segment.path_optimality == 0.0
138139

139140

141+
@pytest.mark.parametrize(
    "points_data,start_params,end_params,expected_start,expected_end,test_id",
    [
        # Pen leaves A at row 1 and reaches B at the last row.
        (
            {"x": [0, 1, 2, 3, 4, 5], "y": [0, 1, 2, 3, 4, 5]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
            {"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
            1,
            5,
            "valid_trajectory",
        ),
        # Every point stays inside the start circle.
        (
            {"x": [0.1, 0.2, 0.3], "y": [0.1, 0.2, 0.3]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
            {"order": 2, "label": "B", "center_x": 10, "center_y": 10, "radius": 1.0},
            None,
            None,
            "no_exit_from_start",
        ),
        # Pen leaves A but the end circle is out of reach.
        (
            {"x": [0, 1, 2, 3], "y": [0, 1, 2, 3]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
            {"order": 2, "label": "B", "center_x": 10, "center_y": 10, "radius": 0.5},
            1,
            None,
            "never_reaches_end",
        ),
        # No points at all.
        (
            {"x": [], "y": []},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
            {"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 1.0},
            None,
            None,
            "empty_dataframe",
        ),
        # A lone point that never leaves the start circle.
        (
            {"x": [0.1], "y": [0.1]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 1.0},
            {"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 1.0},
            None,
            None,
            "single_point_in_start",
        ),
        # A lone point outside the start circle and outside the end circle.
        (
            {"x": [2], "y": [2]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
            {"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
            0,
            None,
            "single_point_outside_start",
        ),
        # The first point outside A is already inside B.
        (
            {"x": [0, 2.5, 5], "y": [0, 2.5, 5]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
            {"order": 2, "label": "B", "center_x": 2.5, "center_y": 2.5, "radius": 1.0},
            1,
            1,
            "immediate_transition",
        ),
        # The recording starts outside the start circle.
        (
            {"x": [3, 4, 5], "y": [3, 4, 5]},
            {"order": 1, "label": "A", "center_x": 0, "center_y": 0, "radius": 0.5},
            {"order": 2, "label": "B", "center_x": 5, "center_y": 5, "radius": 0.5},
            0,
            2,
            "first_point_outside_start",
        ),
    ],
    ids=lambda x: x if isinstance(x, str) else "",
)
def test_valid_ink_trajectory(
    points_data: Dict[str, list[float]],
    start_params: dict[str, int | str | float],
    end_params: dict[str, int | str | float],
    expected_start: int | None,
    expected_end: int | None,
    test_id: str,
) -> None:
    """Check the indices returned by LineSegment.valid_ink_trajectory.

    Each case pairs a set of points with a start and an end circle and states
    the start and end indices the method is expected to return.
    """
    segment = models.LineSegment(
        start_label=cast(str, start_params["label"]),
        end_label=cast(str, end_params["label"]),
        points=pd.DataFrame(points_data),
        is_error=False,
        line_number=1,
    )
    origin = models.CircleTarget(**start_params)  # type: ignore[arg-type]
    target = models.CircleTarget(**end_params)  # type: ignore[arg-type]

    found_start, found_end = segment.valid_ink_trajectory(origin, target)

    assert found_start == expected_start, f"Start index mismatch for {test_id}"
    assert found_end == expected_end, f"End index mismatch for {test_id}"
242+
140243
def test_uniform_motion() -> None:
141244
"""Test with points moving at constant velocity."""
142245
points = pd.DataFrame({
@@ -148,10 +251,6 @@ def test_uniform_motion() -> None:
148251
start_label="1",
149252
end_label="2",
150253
points=points,
151-
is_error=False,
152-
line_number=1,
153-
)
154-
155254
segment.calculate_velocity_metrics(points)
156255

157256
assert segment.distance == pytest.approx(3.0)
@@ -270,3 +369,4 @@ def test_stationary_motion() -> None:
270369
assert all(v == pytest.approx(0.0) for v in segment.velocities)
271370
assert len(segment.accelerations) == 1
272371
assert segment.accelerations[0] == pytest.approx(0.0)
372+

0 commit comments

Comments
 (0)