-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
318 lines (259 loc) · 10.7 KB
/
models.py
File metadata and controls
318 lines (259 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""Internal data classes for drawing data."""
import dataclasses
import datetime
from typing import Callable, List, Optional, Tuple
import numpy as np
import pandas as pd
import pydantic
import scipy.spatial.distance as dist
class Drawing(pydantic.BaseModel):
    """Class representing a drawing task, encapsulating both raw data and metadata.

    Attributes:
        data: DataFrame containing drawing data with required columns (line_number, x,
            y, UTC_Timestamp, seconds). Only non-emptiness is validated here; the
            column set is not checked by this model.
        task_name: Name of the drawing task (e.g., 'spiral', 'trails', etc.).
        metadata: Dictionary containing metadata about the drawing:
            - id: Unique identifier for the participant; validated to be a
              7-character string of ASCII digits starting with '5',
            - hand: Hand used ('Dom' for dominant, 'NonDom' for non-dominant),
            - task: Task name,
            - start_time: Start time of drawing,
            - source_path: Path to the source CSV file.
    """

    # pd.DataFrame is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    data: pd.DataFrame
    task_name: str
    metadata: dict[str, str | datetime.datetime]

    @pydantic.field_validator("data")
    @classmethod
    def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
        """Validate that DataFrame is not empty.

        Args:
            cls: The class.
            v: The dataframe to validate.

        Returns:
            The dataframe if it is not empty.

        Raises:
            ValueError: If the dataframe is empty.
        """
        if v.empty:
            raise ValueError("DataFrame is empty")
        return v

    @pydantic.field_validator("metadata")
    @classmethod
    def validate_metadata(cls, v: dict) -> dict:
        """Validate the participant 'id' entry of the metadata dictionary.

        Only ValueError/AssertionError raised inside a pydantic validator are
        reported as a ValidationError, so every failure mode here (missing key,
        non-string value, malformed id) is expressed as a ValueError rather than
        letting KeyError/AttributeError escape.

        Args:
            cls: The class.
            v: The metadata dictionary to validate.

        Returns:
            The metadata dictionary if it is valid.

        Raises:
            ValueError: If 'id' is missing, not a string, does not start with '5',
                is not 7 characters long, or contains non-digit characters.
        """
        participant_id = v.get("id")
        # The field type admits datetime values (and pydantic may coerce ints to
        # datetime), so the id must be checked to be a string before str methods.
        if not isinstance(participant_id, str):
            raise ValueError("'id' must be present in metadata and be a string")
        if not participant_id.startswith("5"):
            raise ValueError("'id' must start with digit 5")
        if len(participant_id) != 7:
            raise ValueError("'id' must be 7 digits long")
        # isascii() guards against non-ASCII characters that str.isdigit accepts.
        if not (participant_id.isascii() and participant_id.isdigit()):
            raise ValueError("'id' must contain only digits")
        return v
class SpiralFeatureCategories:
    """Registry of the spiral feature category names understood by Graphomotor."""

    DURATION = "duration"
    VELOCITY = "velocity"
    HAUSDORFF = "hausdorff"
    AUC = "AUC"

    @classmethod
    def all(cls) -> set[str]:
        """Return a fresh set containing every supported category name."""
        categories = (cls.DURATION, cls.VELOCITY, cls.HAUSDORFF, cls.AUC)
        return set(categories)

    @classmethod
    def get_extractors(
        cls, spiral: Drawing, reference_spiral: np.ndarray
    ) -> dict[str, Callable[[], dict[str, float]]]:
        """Build lazily-evaluated feature extractors for one spiral.

        No feature is computed here; each returned zero-argument callable runs
        its extractor only when invoked.

        Args:
            spiral: The drawing whose features will be extracted.
            reference_spiral: Reference spiral used by the comparison-based
                metrics (Hausdorff distance and area under curve).

        Returns:
            Mapping from category name to the callable computing that category.
        """
        # Imported inside the method to avoid circular imports with the
        # feature modules.
        from graphomotor.features import shared_features
        from graphomotor.features.spiral import distance, drawing_error, velocity

        def extract_duration():
            return shared_features.get_task_duration(spiral)

        def extract_velocity():
            return velocity.calculate_velocity_metrics(spiral)

        def extract_hausdorff():
            return distance.calculate_hausdorff_metrics(spiral, reference_spiral)

        def extract_auc():
            return drawing_error.calculate_area_under_curve(spiral, reference_spiral)

        extractors = {}
        extractors[cls.DURATION] = extract_duration
        extractors[cls.VELOCITY] = extract_velocity
        extractors[cls.HAUSDORFF] = extract_hausdorff
        extractors[cls.AUC] = extract_auc
        return extractors
@dataclasses.dataclass
class CircleTarget:
    """A circular target within a drawing task.

    Attributes:
        order: Position of this circle within the task's sequence.
        label: Label attached to the circle.
        center_x: Horizontal coordinate of the circle's center.
        center_y: Vertical coordinate of the circle's center.
        radius: Radius of the circle, in the same units as the coordinates.
    """

    order: int
    label: str
    center_x: float
    center_y: float
    radius: float

    def contains_point(self, x: float, y: float, tolerance: float = 1.5) -> bool:
        """Report whether (x, y) lies within an enlarged version of this circle.

        The acceptance boundary is ``radius * tolerance``; the boundary itself
        counts as inside. The computation uses NumPy, so array-valued ``x`` and
        ``y`` are evaluated element-wise.

        Args:
            x: Horizontal coordinate of the point.
            y: Vertical coordinate of the point.
            tolerance: Factor applied to the radius to widen (or shrink) the
                acceptance boundary.

        Returns:
            True when the point's distance from the center does not exceed
            ``radius * tolerance``; False otherwise.
        """
        offset_x = x - self.center_x
        offset_y = y - self.center_y
        separation = np.sqrt(offset_x**2 + offset_y**2)
        allowed = self.radius * tolerance
        return separation <= allowed
@dataclasses.dataclass
class LineSegment:
"""Represents a line drawn between two circles.
Attributes:
start_label: Label of the starting circle.
end_label: Label of the ending circle.
points: DataFrame containing the points in the line segment.
is_error: Whether the line segment is an error (missed target).
line_number: The line number of the segment.
Calculated features:
ink_time: Time spent drawing the line segment.
think_time: Time spent thinking before drawing the line segment.
think_circle_label: Label of the circle associated with think time.
distance: Total distance drawn outside circles.
mean_speed: Mean speed of drawing the line segment.
speed_variance: Variance of speed during the line segment.
path_optimality: Ratio of actual path length to optimal path length.
smoothness: Smoothness of the line segment based on curvature changes.
hesitation_count: Number of hesitations during the line segment.
hesitation_duration: Total duration of hesitations during the line segment.
velocities: List of velocities at each point in the line segment.
accelerations: List of accelerations at each point in the line segment.
"""
start_label: str
end_label: str
points: pd.DataFrame
is_error: bool
line_number: int
ink_points: np.ndarray = dataclasses.field(default_factory=lambda: np.array([]))
ink_time: float = 0.0
think_time: float = 0.0
think_circle_label: str = ""
distance: float = 0.0
mean_speed: float = 0.0
speed_variance: float = 0.0
path_optimality: float = 0.0
smoothness: float = 0.0
hesitation_count: int = 0
hesitation_duration: float = 0.0
velocities: List[float] = dataclasses.field(default_factory=list)
accelerations: List[float] = dataclasses.field(default_factory=list)
def valid_ink_trajectory(
self,
start_circle: CircleTarget,
end_circle: CircleTarget,
) -> Tuple[Optional[int], Optional[int]]:
"""Determine whether an ink trajectory exists from a start to end circle.
An "ink trajectory" is defined as the first contiguous sequence of
points that:
1. Begins **after** the pen leaves the start circle, and
2. Ends when the pen first enters the end circle.
The function scans point-by-point in order. The ink start index is the
first point whose (x, y) location is *outside* the start circle. The
ink end index is the first subsequent point whose (x, y) location falls
*inside* the end circle. If either of these conditions never occurs,
the trajectory is considered invalid. If a valid trajectory is found,
the ink_points attribute is updated to contain only the points within
this trajectory.
Args:
points: DataFrame of points with 'x' and 'y' columns.
start_circle: CircleTarget representing the start circle.
end_circle: CircleTarget representing the end circle.
Returns:
Tuple of (ink_start_idx: int, ink_end_idx: int) if valid
trajectory exists, else (None, None).
"""
ink_start_idx = None
ink_end_idx = None
for idx, row in self.points.iterrows():
if (
not start_circle.contains_point(row["x"], row["y"])
and ink_start_idx is None
):
ink_start_idx = idx
if ink_start_idx is not None and end_circle.contains_point(
row["x"], row["y"]
):
ink_end_idx = idx
break
if (
ink_start_idx is not None
and ink_end_idx is not None
and ink_end_idx > ink_start_idx
):
self.ink_points = self.points.iloc[ink_start_idx : ink_end_idx + 1].copy()
return ink_start_idx, ink_end_idx
def calculate_path_optimality(
self,
start_circle: CircleTarget,
end_circle: CircleTarget,
) -> None:
"""Calculate path optimality ratio.
The default value for path optimality in the LineSegment object is 0.0. This
function updates the path_optimality attribute of the LineSegment object based
on the optimal distance between the start and end circles, adjusted for their
radii. If the optimal distance is less than or equal to zero, the path
optimality remains 0.0.
Args:
segment: LineSegment object for which to calculate path optimality.
start_circle: CircleTarget representing the start circle.
end_circle: CircleTarget representing the end circle.
Returns:
Path optimality ratio.
"""
optimal_distance = (
dist.euclidean(
[start_circle.center_x, start_circle.center_y],
[end_circle.center_x, end_circle.center_y],
)
- start_circle.radius
- end_circle.radius
)
if optimal_distance > 0:
self.path_optimality = optimal_distance / self.distance
return
def calculate_velocity_metrics(self, ink_points: pd.DataFrame) -> None:
"""Get distance, velocity, and acceleration metrics of a LineSegment.
Args:
self: LineSegment object to calculate velocities for.
ink_points: DataFrame of ink points with 'x', 'y', and 'seconds' columns.
"""
dx = np.diff(ink_points["x"].values)
dy = np.diff(ink_points["y"].values)
dt = np.diff(ink_points["seconds"].values)
distances = np.sqrt(dx**2 + dy**2)
self.distance = np.sum(distances)
velocities = distances / dt
self.velocities = velocities.tolist()
self.mean_speed = np.mean(velocities)
self.speed_variance = np.var(velocities)
if len(velocities) >= 2:
self.accelerations = np.diff(velocities).tolist()
return