Skip to content

Commit c68e01e

Browse files
cgmaioranoAsanto32
andauthored
74 task write linesegment class in trails utils (#75)
* created trail_utils file * wrote LineSegment in config, updated venv with pandas * Update test.yaml Trying to remove the --only-dev that seems to be causing issues in ruff and mypy * Bump shapely version * Trying to limit to 3.11-3.13 * Update pyproject.toml * 77 task write circletarget class in config (#78) * wrote class and method * wrote 4 unit tests * Move LineSegment and CircleTarget from config to models * Move unit tests from test_config to test_models * editing docstrings * parameterized 3 tests into 1 --------- Co-authored-by: Adam Santorelli <148909356+Asanto32@users.noreply.github.com>
1 parent 94675c2 commit c68e01e

6 files changed

Lines changed: 161 additions & 195 deletions

File tree

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ jobs:
7878
with:
7979
python-version-file: pyproject.toml
8080
- name: Install the project
81-
run: uv sync --only-dev
81+
run: uv sync
8282
- name: Ruff format
8383
run: uv run ruff format --check
8484
- name: Ruff check
@@ -97,6 +97,6 @@ jobs:
9797
with:
9898
python-version-file: pyproject.toml
9999
- name: Install the project
100-
run: uv sync --only-dev
100+
run: uv sync
101101
- run: |
102102
uv run mypy .

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,19 @@ authors = [
88
{name = "Iktae Kim", email = "iktae.kim@childmind.org"},
99
{name = "Adam Santorelli", email = "adam.santorelli@childmind.org"}
1010
]
11-
license = "LGPL-2.1"
11+
license = "LGPL-2.1-only"
1212
readme = "README.md"
13-
requires-python = ">=3.11"
13+
requires-python = ">=3.11,<3.14"
14+
1415
dependencies = [
1516
"pandas>=2.2.3",
1617
"pydantic>=2.11.1",
1718
"scipy>=1.15.2",
18-
"shapely>=2.1.0",
19+
"shapely>=2.1.2",
1920
"tqdm>=4.66.0",
2021
"typer>=0.12.0",
2122
"matplotlib>=3.10.0",
22-
"seaborn>=0.13.0"
23+
"seaborn>=0.13.0",
2324
]
2425

2526
[project.scripts]

src/graphomotor/core/models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Internal data class for spiral drawing data."""
22

3+
import dataclasses
34
import datetime
45
import typing
56

@@ -116,3 +117,82 @@ def get_extractors(
116117
spiral, reference_spiral
117118
),
118119
}
120+
121+
122+
@dataclasses.dataclass
123+
class LineSegment:
124+
"""Represents a line drawn between two circles.
125+
126+
Attributes:
127+
start_label: Label of the starting circle.
128+
end_label: Label of the ending circle.
129+
points: DataFrame containing the points in the line segment.
130+
is_error: Whether the line segment is an error (missed target).
131+
line_number: The line number of the segment.
132+
133+
Calculated features:
134+
ink_time: Time spent drawing the line segment.
135+
think_time: Time spent thinking before drawing the line segment.
136+
think_circle_label: Label of the circle associated with think time.
137+
distance: Total distance drawn outside circles.
138+
mean_speed: Mean speed of drawing the line segment.
139+
speed_variance: Variance of speed during the line segment.
140+
path_optimality: Ratio of actual path length to optimal path length.
141+
smoothness: Smoothness of the line segment based on curvature changes.
142+
hesitation_count: Number of hesitations during the line segment.
143+
hesitation_duration: Total duration of hesitations during the line segment.
144+
velocities: List of velocities at each point in the line segment.
145+
accelerations: List of accelerations at each point in the line segment.
146+
"""
147+
148+
start_label: str
149+
end_label: str
150+
points: pd.DataFrame
151+
is_error: bool
152+
line_number: int
153+
154+
ink_time: float = 0.0
155+
think_time: float = 0.0
156+
think_circle_label: str = ""
157+
distance: float = 0.0
158+
mean_speed: float = 0.0
159+
speed_variance: float = 0.0
160+
path_optimality: float = 0.0
161+
smoothness: float = 0.0
162+
hesitation_count: int = 0
163+
hesitation_duration: float = 0.0
164+
velocities: typing.List[float] = dataclasses.field(default_factory=list)
165+
accelerations: typing.List[float] = dataclasses.field(default_factory=list)
166+
167+
168+
@dataclasses.dataclass
169+
class CircleTarget:
170+
"""Represents a target circle in the drawing task.
171+
172+
Attributes:
173+
order: The order of the circle in the sequence.
174+
label: The label of the circle.
175+
center_x: The x-coordinate of the circle's center.
176+
center_y: The y-coordinate of the circle's center.
177+
radius: The radius of the circle.
178+
"""
179+
180+
order: int
181+
label: str
182+
center_x: float
183+
center_y: float
184+
radius: float
185+
186+
def contains_point(self, x: float, y: float, tolerance: float = 1.5) -> bool:
187+
"""Check if a point is within the circle (with tolerance multiplier).
188+
189+
Args:
190+
x: X coordinate of the point.
191+
y: Y coordinate of the point.
192+
tolerance: Multiplier for the radius to define tolerance boundary.
193+
194+
Returns:
195+
True if the point is within the circle (with tolerance), False otherwise.
196+
"""
197+
distance = np.sqrt((x - self.center_x) ** 2 + (y - self.center_y) ** 2)
198+
return distance <= (self.radius * tolerance)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Utility functions for trails management."""

tests/unit/test_models.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,39 @@ def test_invalid_metadata_values(
5959
models.Drawing(
6060
data=valid_spiral_data, task_name="spiral", metadata=invalid_metadata
6161
)
62+
63+
64+
@pytest.fixture
65+
def circle() -> models.CircleTarget:
66+
"""Create a standard circle at origin with radius 10."""
67+
return models.CircleTarget(
68+
order=1, label="test_circle", center_x=0.0, center_y=0.0, radius=10.0
69+
)
70+
71+
72+
@pytest.mark.parametrize(
73+
"x,y,description",
74+
[
75+
(0.0, 0.0, "center"),
76+
(10.0, 0.0, "right edge"),
77+
(0.0, 10.0, "top edge"),
78+
(-10.0, 0.0, "left edge"),
79+
(0.0, -10.0, "bottom edge"),
80+
(5.0, 0.0, "inside horizontally"),
81+
(0.0, 5.0, "inside vertically"),
82+
],
83+
)
84+
def test_point_inside_circle(
85+
circle: models.CircleTarget,
86+
x: float,
87+
y: float,
88+
description: str,
89+
) -> None:
90+
"""Point at center, on edge, or just inside should be contained."""
91+
assert circle.contains_point(x, y)
92+
93+
94+
def test_point_outside_with_default_tolerance(circle: models.CircleTarget) -> None:
95+
"""Point outside default tolerance boundary should not be contained."""
96+
assert not circle.contains_point(16.0, 0.0)
97+
assert not circle.contains_point(0.0, 16.0)

0 commit comments

Comments
 (0)