Skip to content

Commit ca8b78a

Browse files
committed
Merge branch 'main' into 73-task-write-trails-time-feature-functions-think_time_
2 parents c7cf159 + d4a7571 commit ca8b78a

2 files changed

Lines changed: 113 additions & 0 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class GridCell:
144144
y_max: Top boundary of the cell.
145145
index: Position of the cell in the grid (0-based).
146146
label: Display label for the cell (e.g., 'A', 'B', '1').
147+
strokes: List of Stroke objects assigned to this cell.
147148
"""
148149

149150
x_min: float
@@ -152,6 +153,7 @@ class GridCell:
152153
y_max: float
153154
index: int = 0
154155
label: str = ""
156+
strokes: List["Stroke"] = dataclasses.field(default_factory=list)
155157

156158
def __post_init__(self) -> None:
157159
"""Validate that min bounds are strictly less than max bounds.
@@ -189,6 +191,41 @@ def contains_points(self, points: pd.DataFrame) -> bool:
189191
)
190192

191193

194+
@dataclasses.dataclass
195+
class Stroke:
196+
"""Represents a single stroke in an Alphabet or DSYM task.
197+
198+
This class holds stroke data and computed features. Features are populated by
199+
utility functions after initialization.
200+
201+
Attributes:
202+
points: DataFrame with columns including 'x', 'y', and 'seconds'.
203+
line_number: The line number identifying this stroke in the raw data.
204+
duration: Total time (s) spent drawing the stroke.
205+
distance: Total distance (px) of the stroke path.
206+
mean_speed: Mean drawing speed (px/s).
207+
speed_variance: Variance of drawing speed.
208+
smoothness: Smoothness of the stroke based on curvature changes.
209+
hesitation_count: Number of hesitations during the stroke.
210+
hesitation_duration: Total duration of hesitations (s).
211+
velocities: List of velocities at each point in the stroke (px/s).
212+
accelerations: List of accelerations at each point in the stroke (px/s²).
213+
"""
214+
215+
points: pd.DataFrame
216+
line_number: int
217+
218+
duration: float = 0.0
219+
distance: float = 0.0
220+
mean_speed: float = 0.0
221+
speed_variance: float = 0.0
222+
smoothness: float = 0.0
223+
hesitation_count: int = 0
224+
hesitation_duration: float = 0.0
225+
velocities: List[float] = dataclasses.field(default_factory=list)
226+
accelerations: List[float] = dataclasses.field(default_factory=list)
227+
228+
192229
@dataclasses.dataclass
193230
class CircleTarget:
194231
"""Represents a target circle in the drawing task.

tests/unit/test_stroke.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Test cases for the Stroke model."""
2+
3+
import pandas as pd
4+
5+
from graphomotor.core import models
6+
7+
8+
def _make_points(line_number: int = 0) -> pd.DataFrame:
9+
"""Create a stroke DataFrame matching real Alphabet CSV structure.
10+
11+
In the real data each stroke (line_number) contains many rows with the same
12+
line_number value, plus x, y, seconds, and timestamp columns.
13+
"""
14+
return pd.DataFrame(
15+
{
16+
"line_number": [line_number] * 5,
17+
"x": [18.35, 18.24, 18.15, 18.12, 17.99],
18+
"y": [92.84, 92.85, 92.88, 92.92, 93.01],
19+
"seconds": [0.0, 0.02, 0.037, 0.046, 0.062],
20+
}
21+
)
22+
23+
24+
def test_stroke_initialization() -> None:
25+
"""Stroke should store points and line_number with default feature values."""
26+
points = _make_points(line_number=0)
27+
stroke = models.Stroke(points=points, line_number=0)
28+
29+
assert stroke.line_number == 0
30+
assert stroke.points.equals(points)
31+
assert list(stroke.points.columns) == ["line_number", "x", "y", "seconds"]
32+
assert len(stroke.points) == 5
33+
assert stroke.duration == 0.0
34+
assert stroke.distance == 0.0
35+
assert stroke.mean_speed == 0.0
36+
assert stroke.speed_variance == 0.0
37+
assert stroke.smoothness == 0.0
38+
assert stroke.hesitation_count == 0
39+
assert stroke.hesitation_duration == 0.0
40+
assert stroke.velocities == []
41+
assert stroke.accelerations == []
42+
43+
44+
def test_stroke_features_are_mutable() -> None:
45+
"""Computed features should be writable after initialization."""
46+
stroke = models.Stroke(points=_make_points(line_number=0), line_number=0)
47+
48+
stroke.duration = 0.5
49+
stroke.distance = 25.0
50+
stroke.mean_speed = 50.0
51+
stroke.velocities = [40.0, 50.0, 60.0]
52+
53+
assert stroke.duration == 0.5
54+
assert stroke.distance == 25.0
55+
assert stroke.mean_speed == 50.0
56+
assert stroke.velocities == [40.0, 50.0, 60.0]
57+
58+
59+
def test_grid_cell_stores_strokes() -> None:
60+
"""GridCell should hold Stroke objects in its strokes list."""
61+
cell = models.GridCell(x_min=0.0, x_max=30.0, y_min=70.0, y_max=100.0)
62+
stroke_a = models.Stroke(points=_make_points(line_number=0), line_number=0)
63+
stroke_b = models.Stroke(points=_make_points(line_number=1), line_number=1)
64+
65+
cell.strokes.append(stroke_a)
66+
cell.strokes.append(stroke_b)
67+
68+
assert len(cell.strokes) == 2
69+
assert cell.strokes[0].line_number == 0
70+
assert cell.strokes[1].line_number == 1
71+
72+
73+
def test_grid_cell_strokes_default_empty() -> None:
74+
"""GridCell should default to an empty strokes list."""
75+
cell = models.GridCell(x_min=0.0, x_max=10.0, y_min=0.0, y_max=10.0)
76+
assert cell.strokes == []

0 commit comments

Comments
 (0)