Skip to content

Commit ecd4ea1

Browse files
committed
all my functions disappeared during merge from main so added back in functions and fixed tests
1 parent 3b5907a commit ecd4ea1

2 files changed

Lines changed: 255 additions & 162 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,16 @@ def calculate_path_optimality(
285285
self.path_optimality = optimal_distance / self.distance
286286
return
287287

288-
def calculate_velocity_metrics(self, ink_points: pd.DataFrame) -> None:
289-
"""Get distance, velocity, and acceleration metrics of a LineSegment.
288+
def calculate_velocity_metrics(self) -> None:
289+
"""Get velocity metrics of a LineSegment.
290290
291291
Args:
292292
self: LineSegment object to calculate velocities for.
293-
ink_points: DataFrame of ink points with 'x', 'y', and 'seconds' columns.
293+
ink_points: DataFrame containing the ink points for the line segment.
294294
"""
295-
dx = np.diff(ink_points["x"].values)
296-
dy = np.diff(ink_points["y"].values)
297-
dt = np.diff(ink_points["seconds"].values)
295+
dx = np.diff(self.ink_points["x"].values)
296+
dy = np.diff(self.ink_points["y"].values)
297+
dt = np.diff(self.ink_points["seconds"].values)
298298

299299
distances = np.sqrt(dx**2 + dy**2)
300300
self.distance = np.sum(distances)
@@ -309,3 +309,143 @@ def calculate_velocity_metrics(self, ink_points: pd.DataFrame) -> None:
309309
self.accelerations = np.diff(velocities).tolist()
310310

311311
return
312+
313+
def detect_hesitations(self, threshold_percentile: int = 20) -> None:
314+
"""Detect hesitations as periods of significantly reduced velocity.
315+
316+
This function defines a hesitation as any period where the velocity falls below
317+
a certain threshold, which is determined by the specified percentile of the
318+
velocity distribution. It counts the number of distinct hesitation periods and
319+
adds 1 if the line starts with a hesitation. It also calculates the total
320+
duration of hesitations based on the number of points that fall below the
321+
threshold and the time interval between points.
322+
323+
Args:
324+
ink_points: DataFrame containing the ink points for the line segment.
325+
threshold_percentile: Percentile to determine the velocity threshold for
326+
hesitations (default is 20, meaning the bottom 20% of velocities are
327+
considered hesitations).
328+
"""
329+
if len(self.velocities) < 3:
330+
return
331+
332+
dt = np.diff(self.ink_points["seconds"].values)
333+
334+
threshold = np.percentile(self.velocities, threshold_percentile)
335+
hesitations = self.velocities < threshold
336+
337+
hesitation_changes = np.diff(hesitations.astype(int))
338+
hesitation_starts = np.where(hesitation_changes == 1)[0] + 1
339+
hesitation_count = len(hesitation_starts)
340+
341+
if hesitations[0]:
342+
hesitation_count += 1
343+
344+
hesitation_duration = np.sum(hesitations) * dt[0]
345+
346+
self.hesitation_count = hesitation_count
347+
self.hesitation_duration = hesitation_duration
348+
349+
return
350+
351+
def calculate_smoothness(self) -> None:
352+
"""Calculate path smoothness based on Root Mean Square (RMS) curvature.
353+
354+
Represents the curvature per unit arc length.
355+
Lower values indicate smoother drawings. Penalizes sharp corners (e.g.,
356+
90° turns) and noisy corrections. Normalized by arc length to reduce
357+
sampling-rate dependence.
358+
"""
359+
if len(self.ink_points) < 3:
360+
return 0.0
361+
return
362+
363+
xy = self.ink_points[["x", "y"]].to_numpy()
364+
365+
forward_vector = xy[1:-1] - xy[:-2]
366+
backward_vector = xy[2:] - xy[1:-1]
367+
368+
forward_norm = np.linalg.norm(forward_vector, axis=1)
369+
backward_norm = np.linalg.norm(backward_vector, axis=1)
370+
371+
valid = (forward_norm > 0) & (backward_norm > 0)
372+
if not np.any(valid):
373+
return
374+
375+
valid_forward_vector = forward_vector[valid]
376+
valid_backward_vector = backward_vector[valid]
377+
valid_forward_norm = forward_norm[valid]
378+
valid_backward_norm = backward_norm[valid]
379+
380+
cos_angle = (valid_forward_vector * valid_backward_vector).sum(axis=1) / (
381+
valid_forward_norm * valid_backward_norm
382+
)
383+
cos_angle = np.clip(cos_angle, -1.0, 1.0)
384+
385+
angles = np.arccos(cos_angle)
386+
387+
avg_segment_length = (valid_forward_norm + valid_backward_norm) / 2.0
388+
curvatures = angles / avg_segment_length
389+
390+
self.smoothness = float(np.sqrt(np.mean(curvatures**2)))
391+
392+
return
393+
394+
def compute_segment_metrics(
395+
self, circles: dict[str, dict[str, CircleTarget]], trail_id: str
396+
) -> None:
397+
"""Compute all metrics for a line segment.
398+
399+
This function computes various metrics for the line segment, including ink time,
400+
velocity metrics, path optimality, smoothness, and hesitation detection. It
401+
first determines the valid ink trajectory between the start and end circles. If
402+
a valid trajectory is found, it updates the ink_points attribute and calculates
403+
the metrics.
404+
405+
Args:
406+
circles: A dictionary mapping each trail type to dictionaries of
407+
CircleTarget instances (output of load_scaled_circles in config).
408+
trail_id: Trail identifier for circle lookup.
409+
"""
410+
trail_circles = circles[trail_id]
411+
points = self.points.copy()
412+
413+
if len(points) < 2:
414+
return
415+
416+
if self.start_label not in trail_circles or self.end_label not in trail_circles:
417+
return
418+
419+
start_circle = trail_circles[self.start_label]
420+
end_circle = trail_circles[self.end_label]
421+
422+
ink_start_idx, ink_end_idx = self.valid_ink_trajectory(start_circle, end_circle)
423+
424+
if (
425+
ink_start_idx is not None
426+
and ink_end_idx is not None
427+
and ink_end_idx > ink_start_idx
428+
):
429+
self.ink_points = self.points.iloc[ink_start_idx : ink_end_idx + 1].copy()
430+
431+
if len(self.ink_points) >= 2:
432+
ink_start = self.ink_points.iloc[0]["seconds"]
433+
ink_end = self.ink_points.iloc[-1]["seconds"]
434+
self.ink_time = ink_end - ink_start
435+
436+
self.calculate_velocity_metrics()
437+
438+
self.calculate_path_optimality(start_circle, end_circle)
439+
440+
self.calculate_smoothness()
441+
442+
self.detect_hesitations()
443+
444+
elif ink_start_idx is not None:
445+
self.ink_points = points.iloc[ink_start_idx:].copy()
446+
if len(self.ink_points) >= 2:
447+
self.ink_time = (
448+
self.ink_points.iloc[-1]["seconds"]
449+
- self.ink_points.iloc[0]["seconds"]
450+
)
451+
return

0 commit comments

Comments
 (0)