Skip to content

Commit 850224c

Browse files
committed
requested changes
1 parent 46c68d1 commit 850224c

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

src/graphomotor/features/trails/drawing_metrics.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def percent_accurate_paths(drawing: models.Drawing) -> dict[str, float]:
4646

4747

4848
def calculate_smoothness(points: pd.DataFrame) -> float:
49-
"""Calculate path smoothness based on RMS curvature.
49+
"""Calculate path smoothness based on Root Mean Square (RMS) curvature.
5050
5151
Represants the curvature per unit arc length.
5252
Lower values indicate smoother drawings. Penalizes sharp corners (e.g., 90° turns)
@@ -63,27 +63,29 @@ def calculate_smoothness(points: pd.DataFrame) -> float:
6363

6464
xy = points[["x", "y"]].to_numpy()
6565

66-
v1 = xy[1:-1] - xy[:-2]
67-
v2 = xy[2:] - xy[1:-1]
66+
forward_vector = xy[1:-1] - xy[:-2]
67+
backward_vector = xy[2:] - xy[1:-1]
6868

69-
l1 = np.linalg.norm(v1, axis=1)
70-
l2 = np.linalg.norm(v2, axis=1)
69+
forward_norm = np.linalg.norm(forward_vector, axis=1)
70+
backward_norm = np.linalg.norm(backward_vector, axis=1)
7171

72-
valid = (l1 > 0) & (l2 > 0)
72+
valid = (forward_norm > 0) & (backward_norm > 0)
7373
if not np.any(valid):
7474
return 0.0
7575

76-
v1 = v1[valid]
77-
v2 = v2[valid]
78-
l1 = l1[valid]
79-
l2 = l2[valid]
76+
valid_forward_vector = forward_vector[valid]
77+
valid_backward_vector = backward_vector[valid]
78+
valid_forward_norm = forward_norm[valid]
79+
valid_backward_norm = backward_norm[valid]
8080

81-
cos_angle = np.einsum("ij,ij->i", v1, v2) / (l1 * l2)
81+
cos_angle = (valid_forward_vector * valid_backward_vector).sum(axis=1) / (
82+
valid_forward_norm * valid_backward_norm
83+
)
8284
cos_angle = np.clip(cos_angle, -1.0, 1.0)
8385

8486
angles = np.arccos(cos_angle)
8587

86-
arc_len = (l1 + l2) / 2.0
88+
arc_len = (valid_forward_norm + valid_backward_norm) / 2.0
8789
curvatures = angles / arc_len
8890

8991
return float(np.sqrt(np.mean(curvatures**2)))

0 commit comments

Comments
 (0)