Skip to content

Commit 14b932e

Browse files
authored
Merge pull request #22 from childmindresearch/alperkent/issue21
feature/issue-21/implement-velocity-features
2 parents 1a95356 + 2d24ef8 commit 14b932e

11 files changed

Lines changed: 254 additions & 35 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
- --autofix
2121
- --indent=2
2222
- id: pretty-format-toml
23-
exclude: ^poetry.lock$
23+
exclude: ^uv.lock$
2424
args:
2525
- --autofix
2626
- --indent=2

src/graphomotor/core/config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Configuration module for Graphomotor."""
2+
3+
import logging
4+
5+
import numpy as np
6+
7+
8+
class _SpiralConfig:
9+
"""Configuration for the reference spiral generation."""
10+
11+
SPIRAL_CENTER_X = 50
12+
SPIRAL_CENTER_Y = 50
13+
SPIRAL_START_RADIUS = 0
14+
SPIRAL_GROWTH_RATE = 1.075
15+
SPIRAL_START_ANGLE = 0
16+
SPIRAL_END_ANGLE = 8 * np.pi
17+
SPIRAL_NUM_POINTS = 10000
18+
19+
20+
def get_logger() -> logging.Logger:
21+
"""Get the Graphomotor logger."""
22+
logger = logging.getLogger("graphomotor")
23+
if logger.handlers:
24+
return logger
25+
logger.setLevel(logging.INFO)
26+
formatter = logging.Formatter(
27+
"%(asctime)s - %(name)s - %(levelname)s - "
28+
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
29+
)
30+
handler = logging.StreamHandler()
31+
handler.setFormatter(formatter)
32+
logger.addHandler(handler)
33+
return logger

src/graphomotor/features/time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from graphomotor.core import models
44

55

6-
def get_task_duration(spiral: models.Spiral) -> dict:
6+
def get_task_duration(spiral: models.Spiral) -> dict[str, float]:
77
"""Calculate the total duration of a spiral drawing task.
88
99
Args:
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Feature extraction module for velocity-based metrics in spiral drawing data."""
2+
3+
import numpy as np
4+
from scipy import stats
5+
6+
from graphomotor.core import models
7+
from graphomotor.utils import center_spiral
8+
9+
10+
def _calculate_statistics(values: np.ndarray, name: str) -> dict[str, float]:
11+
"""Helper function to calculate statistics for a given array.
12+
13+
Args:
14+
values: 1-D Numpy array of numerical values.
15+
name: Name prefix for the statistics (e.g., "linear_velocity").
16+
17+
Returns:
18+
Dictionary containing calculated metrics (sum, median, variation, skewness,
19+
kurtosis) with keys prefixed by the provided name.
20+
"""
21+
return {
22+
f"{name}_sum": np.sum(np.abs(values)),
23+
f"{name}_median": np.median(np.abs(values)),
24+
f"{name}_variation": stats.variation(values),
25+
f"{name}_skewness": stats.skew(values),
26+
f"{name}_kurtosis": stats.kurtosis(values),
27+
}
28+
29+
30+
def calculate_velocity_metrics(spiral: models.Spiral) -> dict[str, float]:
31+
"""Calculate velocity-based metrics from spiral drawing data.
32+
33+
This function computes three types of velocity metrics by calculating the difference
34+
between consecutive points in the spiral drawing data. The three types of velocity
35+
are:
36+
1. Linear velocity: The magnitude of change of Euclidean distance in pixels
37+
per second. This is calculated as the square root of the sum of squares of
38+
the differences in x and y coordinates divided by the difference in time.
39+
2. Radial velocity: The magnitude of change of distance from center (radius) in
40+
pixels per second. Radius is calculated as the square root of the sum of
41+
squares of x and y coordinates.
42+
3. Angular velocity: The magnitude of change of angle in radians per second.
43+
Angle is calculated using the arctangent of y coordinates divided by x
44+
coordinates, and then unwrapped to maintain continuity across the -π to π
45+
boundary.
46+
47+
For each velocity type, the following metrics are calculated:
48+
- Sum: Sum of absolute velocity values
49+
- Median: Median of absolute velocity values
50+
- Variation: Coefficient of variation
51+
- Skewness: Asymmetry of the velocity distribution
52+
- Kurtosis: Tailedness of the velocity distribution
53+
54+
Args:
55+
spiral: Spiral object containing drawing data.
56+
57+
Returns:
58+
Dictionary containing calculated velocity metrics.
59+
"""
60+
spiral = center_spiral.center_spiral(spiral)
61+
x = spiral.data["x"].values
62+
y = spiral.data["y"].values
63+
time = spiral.data["seconds"].values
64+
radius = np.sqrt(x**2 + y**2)
65+
theta = np.unwrap(np.arctan2(y, x))
66+
67+
dx = np.diff(x)
68+
dy = np.diff(y)
69+
dt = np.diff(time)
70+
dr = np.diff(radius)
71+
dtheta = np.diff(theta)
72+
73+
linear_velocity = np.sqrt(dx**2 + dy**2) / dt
74+
radial_velocity = dr / dt
75+
angular_velocity = dtheta / dt
76+
77+
return {
78+
**_calculate_statistics(linear_velocity, "linear_velocity"),
79+
**_calculate_statistics(radial_velocity, "radial_velocity"),
80+
**_calculate_statistics(angular_velocity, "angular_velocity"),
81+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Utility functions for centering a spiral."""
2+
3+
from graphomotor.core import config, models
4+
5+
6+
def center_spiral(spiral: models.Spiral) -> models.Spiral:
7+
"""Center a spiral by translating it to the origin.
8+
9+
Args:
10+
spiral: Spiral object containing spiral data.
11+
12+
Returns:
13+
Spiral object with centered spiral data.
14+
"""
15+
spiral.data["x"] -= config._SpiralConfig.SPIRAL_CENTER_X
16+
spiral.data["y"] -= config._SpiralConfig.SPIRAL_CENTER_Y
17+
18+
return spiral

src/graphomotor/utils/reference_spiral.py renamed to src/graphomotor/utils/generate_reference_spiral.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
1-
"""Generate a reference spiral with equidistant points along its arc length."""
1+
"""Utility functions for generating an equidistant reference spiral."""
22

33
import numpy as np
44
from scipy import integrate, optimize
55

6-
_SPIRAL_CENTER_X = 50
7-
_SPIRAL_CENTER_Y = 50
8-
_SPIRAL_START_RADIUS = 0
9-
_SPIRAL_GROWTH_RATE = 1.075
10-
_SPIRAL_START_ANGLE = 0
11-
_SPIRAL_END_ANGLE = 8 * np.pi
12-
_SPIRAL_NUM_POINTS = 10000
6+
from graphomotor.core.config import _SpiralConfig
137

148

159
def _arc_length_integrand(t: float) -> float:
@@ -21,8 +15,8 @@ def _arc_length_integrand(t: float) -> float:
2115
Returns:
2216
Differential arc length value.
2317
"""
24-
r_t = _SPIRAL_START_RADIUS + _SPIRAL_GROWTH_RATE * t
25-
return np.sqrt(r_t**2 + _SPIRAL_GROWTH_RATE**2)
18+
r_t = _SpiralConfig.SPIRAL_START_RADIUS + _SpiralConfig.SPIRAL_GROWTH_RATE * t
19+
return np.sqrt(r_t**2 + _SpiralConfig.SPIRAL_GROWTH_RATE**2)
2620

2721

2822
def _calculate_arc_length(theta: float) -> float:
@@ -35,7 +29,7 @@ def _calculate_arc_length(theta: float) -> float:
3529
The arc length of the spiral from _SPIRAL_START_ANGLE to theta.
3630
"""
3731
return integrate.quad(
38-
lambda t: _arc_length_integrand(t), _SPIRAL_START_ANGLE, theta
32+
lambda t: _arc_length_integrand(t), _SpiralConfig.SPIRAL_START_ANGLE, theta
3933
)[0]
4034

4135

@@ -50,7 +44,7 @@ def _find_theta_for_arc_length(target_arc_length: float) -> float:
5044
"""
5145
solution = optimize.root_scalar(
5246
lambda theta: _calculate_arc_length(theta) - target_arc_length,
53-
bracket=[_SPIRAL_START_ANGLE, _SPIRAL_END_ANGLE],
47+
bracket=[_SpiralConfig.SPIRAL_START_ANGLE, _SpiralConfig.SPIRAL_END_ANGLE],
5448
)
5549
return solution.root
5650

@@ -86,14 +80,19 @@ def generate_reference_spiral() -> np.ndarray:
8680
Returns:
8781
Array with shape (N, 2) containing Cartesian coordinates of the spiral points.
8882
"""
89-
total_arc_length = _calculate_arc_length(_SPIRAL_END_ANGLE)
83+
total_arc_length = _calculate_arc_length(_SpiralConfig.SPIRAL_END_ANGLE)
9084

91-
arc_length_values = np.linspace(0, total_arc_length, _SPIRAL_NUM_POINTS)
85+
arc_length_values = np.linspace(
86+
0, total_arc_length, _SpiralConfig.SPIRAL_NUM_POINTS
87+
)
9288

9389
theta_values = np.array([_find_theta_for_arc_length(s) for s in arc_length_values])
9490

95-
r_values = _SPIRAL_START_RADIUS + _SPIRAL_GROWTH_RATE * theta_values
96-
x_values = _SPIRAL_CENTER_X + r_values * np.cos(theta_values)
97-
y_values = _SPIRAL_CENTER_Y + r_values * np.sin(theta_values)
91+
r_values = (
92+
_SpiralConfig.SPIRAL_START_RADIUS
93+
+ _SpiralConfig.SPIRAL_GROWTH_RATE * theta_values
94+
)
95+
x_values = _SpiralConfig.SPIRAL_CENTER_X + r_values * np.cos(theta_values)
96+
y_values = _SpiralConfig.SPIRAL_CENTER_Y + r_values * np.sin(theta_values)
9897

9998
return np.column_stack((x_values, y_values))

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
from graphomotor.core import models
11-
from graphomotor.utils import reference_spiral
11+
from graphomotor.utils import generate_reference_spiral
1212

1313

1414
@pytest.fixture
@@ -56,4 +56,4 @@ def valid_spiral(
5656
@pytest.fixture
5757
def ref_spiral() -> np.ndarray:
5858
"""Create a reference spiral for testing."""
59-
return reference_spiral.generate_reference_spiral()
59+
return generate_reference_spiral.generate_reference_spiral()

tests/unit/test_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Test cases for config.py functions."""
2+
3+
import logging
4+
5+
import pytest
6+
7+
from graphomotor.core import config
8+
9+
10+
def test_get_logger(caplog: pytest.LogCaptureFixture) -> None:
11+
"""Test the graphomotor logger with level set to INFO (20)."""
12+
if logging.getLogger("graphomotor").handlers:
13+
logging.getLogger("graphomotor").handlers.clear()
14+
logger = config.get_logger()
15+
16+
logger.debug("Debug message here.")
17+
logger.info("Info message here.")
18+
logger.warning("Warning message here.")
19+
20+
assert logger.getEffectiveLevel() == logging.INFO
21+
assert "Debug message here" not in caplog.text
22+
assert "Info message here." in caplog.text
23+
assert "Warning message here." in caplog.text
24+
25+
26+
def test_get_logger_second_call() -> None:
27+
"""Test get logger when a handler already exists."""
28+
logger = config.get_logger()
29+
second_logger = config.get_logger()
30+
31+
assert len(logger.handlers) == len(second_logger.handlers) == 1
32+
assert logger.handlers[0] is second_logger.handlers[0]
33+
assert logger is second_logger

tests/unit/test_drawing_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def test_calculate_area_under_curve(valid_spiral: models.Spiral) -> None:
2020
valid_spiral, np.column_stack((x, y2))
2121
)["area_under_curve"]
2222

23-
assert np.isclose(calculated_area, expected_area, rtol=1e-3)
23+
assert np.isclose(calculated_area, expected_area, atol=0, rtol=1e-3)

tests/unit/test_reference_spiral.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,36 @@
22

33
import numpy as np
44

5-
from graphomotor.utils import reference_spiral
5+
from graphomotor.core import config
6+
from graphomotor.utils import generate_reference_spiral
67

78

89
def test_generate_reference_spiral() -> None:
910
"""Test the generation of a reference spiral."""
10-
expected_mean_arc_length = reference_spiral._calculate_arc_length(
11-
reference_spiral._SPIRAL_END_ANGLE
12-
) / (reference_spiral._SPIRAL_NUM_POINTS - 1)
11+
expected_mean_arc_length = generate_reference_spiral._calculate_arc_length(
12+
config._SpiralConfig.SPIRAL_END_ANGLE
13+
) / (config._SpiralConfig.SPIRAL_NUM_POINTS - 1)
1314

14-
spiral = reference_spiral.generate_reference_spiral()
15+
spiral = generate_reference_spiral.generate_reference_spiral()
1516
arc_lengths = np.linalg.norm(spiral[1:] - spiral[:-1], axis=1)
1617
mean_arc_length = np.mean(arc_lengths)
1718

1819
assert isinstance(spiral, np.ndarray)
19-
assert spiral.shape == (reference_spiral._SPIRAL_NUM_POINTS, 2)
20+
assert spiral.shape == (config._SpiralConfig.SPIRAL_NUM_POINTS, 2)
2021
assert np.array_equal(
2122
spiral[0],
22-
[reference_spiral._SPIRAL_CENTER_X, reference_spiral._SPIRAL_CENTER_Y],
23+
[config._SpiralConfig.SPIRAL_CENTER_X, config._SpiralConfig.SPIRAL_CENTER_Y],
2324
)
2425
assert np.allclose(
2526
spiral[-1],
2627
[
27-
reference_spiral._SPIRAL_CENTER_X
28-
+ reference_spiral._SPIRAL_GROWTH_RATE * reference_spiral._SPIRAL_END_ANGLE,
29-
reference_spiral._SPIRAL_CENTER_Y,
28+
config._SpiralConfig.SPIRAL_CENTER_X
29+
+ config._SpiralConfig.SPIRAL_GROWTH_RATE
30+
* config._SpiralConfig.SPIRAL_END_ANGLE,
31+
config._SpiralConfig.SPIRAL_CENTER_Y,
3032
],
31-
atol=1e-8,
33+
atol=0,
34+
rtol=1e-8,
3235
)
33-
assert np.allclose(arc_lengths, mean_arc_length, rtol=1e-3)
34-
assert np.isclose(mean_arc_length, expected_mean_arc_length, rtol=1e-6)
36+
assert np.allclose(arc_lengths, mean_arc_length, atol=0, rtol=1e-3)
37+
assert np.isclose(mean_arc_length, expected_mean_arc_length, atol=0, rtol=1e-6)

0 commit comments

Comments
 (0)