Skip to content

Commit 2d24ef8

Browse files
authored
Refactor logger setup in config module to prevent upward propagation of handler search, add unit tests for logger functionality, refactor velocity module to use center spiral module and change the private function to be more generic.
1 parent 3309d85 commit 2d24ef8

3 files changed

Lines changed: 58 additions & 23 deletions

File tree

src/graphomotor/core/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ class _SpiralConfig:
2020
def get_logger() -> logging.Logger:
2121
"""Get the Graphomotor logger."""
2222
logger = logging.getLogger("graphomotor")
23-
if logger.hasHandlers():
23+
if logger.handlers:
2424
return logger
2525
logger.setLevel(logging.INFO)
26-
handler = logging.StreamHandler()
2726
formatter = logging.Formatter(
2827
"%(asctime)s - %(name)s - %(levelname)s - "
2928
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
3029
)
30+
handler = logging.StreamHandler()
3131
handler.setFormatter(formatter)
3232
logger.addHandler(handler)
3333
return logger

src/graphomotor/features/velocity.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,26 @@
44
from scipy import stats
55

66
from graphomotor.core import models
7+
from graphomotor.utils import center_spiral
78

89

9-
def _get_velocity_statistics(velocity: np.ndarray, type_: str) -> dict[str, float]:
10-
"""Calculate velocity metrics for a given type of velocity.
10+
def _calculate_statistics(values: np.ndarray, name: str) -> dict[str, float]:
11+
"""Helper function to calculate statistics for a given array.
1112
1213
Args:
13-
velocity: Numpy array of velocity values.
14-
type_: Type of velocity (e.g., "linear_velocity", "radial_velocity",
15-
"angular_velocity").
14+
values: 1-D Numpy array of numerical values.
15+
name: Name prefix for the statistics (e.g., "linear_velocity").
1616
1717
Returns:
18-
Dictionary containing calculated metrics for the specified type of velocity.
18+
Dictionary containing calculated metrics (sum, median, variation, skewness,
19+
kurtosis) with keys prefixed by the provided name.
1920
"""
2021
return {
21-
f"{type_}_sum": np.sum(np.abs(velocity)),
22-
f"{type_}_median": np.median(np.abs(velocity)),
23-
f"{type_}_variation": stats.variation(velocity),
24-
f"{type_}_skewness": stats.skew(velocity),
25-
f"{type_}_kurtosis": stats.kurtosis(velocity),
22+
f"{name}_sum": np.sum(np.abs(values)),
23+
f"{name}_median": np.median(np.abs(values)),
24+
f"{name}_variation": stats.variation(values),
25+
f"{name}_skewness": stats.skew(values),
26+
f"{name}_kurtosis": stats.kurtosis(values),
2627
}
2728

2829

@@ -44,7 +45,7 @@ def calculate_velocity_metrics(spiral: models.Spiral) -> dict[str, float]:
4445
boundary.
4546
4647
For each velocity type, the following metrics are calculated:
47-
- Sum: Total absolute velocity over the entire drawing
48+
- Sum: Sum of absolute velocity values
4849
- Median: Median of absolute velocity values
4950
- Variation: Coefficient of variation
5051
- Skewness: Asymmetry of the velocity distribution
@@ -56,14 +57,15 @@ def calculate_velocity_metrics(spiral: models.Spiral) -> dict[str, float]:
5657
Returns:
5758
Dictionary containing calculated velocity metrics.
5859
"""
59-
x_coord = spiral.data["x"].values - 50
60-
y_coord = spiral.data["y"].values - 50
60+
spiral = center_spiral.center_spiral(spiral)
61+
x = spiral.data["x"].values
62+
y = spiral.data["y"].values
6163
time = spiral.data["seconds"].values
62-
radius = np.sqrt(x_coord**2 + y_coord**2)
63-
theta = np.unwrap(np.arctan2(y_coord, x_coord))
64+
radius = np.sqrt(x**2 + y**2)
65+
theta = np.unwrap(np.arctan2(y, x))
6466

65-
dx = np.diff(x_coord)
66-
dy = np.diff(y_coord)
67+
dx = np.diff(x)
68+
dy = np.diff(y)
6769
dt = np.diff(time)
6870
dr = np.diff(radius)
6971
dtheta = np.diff(theta)
@@ -73,7 +75,7 @@ def calculate_velocity_metrics(spiral: models.Spiral) -> dict[str, float]:
7375
angular_velocity = dtheta / dt
7476

7577
return {
76-
**_get_velocity_statistics(linear_velocity, "linear_velocity"),
77-
**_get_velocity_statistics(radial_velocity, "radial_velocity"),
78-
**_get_velocity_statistics(angular_velocity, "angular_velocity"),
78+
**_calculate_statistics(linear_velocity, "linear_velocity"),
79+
**_calculate_statistics(radial_velocity, "radial_velocity"),
80+
**_calculate_statistics(angular_velocity, "angular_velocity"),
7981
}

tests/unit/test_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Test cases for config.py functions."""
2+
3+
import logging
4+
5+
import pytest
6+
7+
from graphomotor.core import config
8+
9+
10+
def test_get_logger(caplog: pytest.LogCaptureFixture) -> None:
11+
"""Test the graphomotor logger with level set to INFO (20)."""
12+
if logging.getLogger("graphomotor").handlers:
13+
logging.getLogger("graphomotor").handlers.clear()
14+
logger = config.get_logger()
15+
16+
logger.debug("Debug message here.")
17+
logger.info("Info message here.")
18+
logger.warning("Warning message here.")
19+
20+
assert logger.getEffectiveLevel() == logging.INFO
21+
assert "Debug message here" not in caplog.text
22+
assert "Info message here." in caplog.text
23+
assert "Warning message here." in caplog.text
24+
25+
26+
def test_get_logger_second_call() -> None:
27+
"""Test get logger when a handler already exists."""
28+
logger = config.get_logger()
29+
second_logger = config.get_logger()
30+
31+
assert len(logger.handlers) == len(second_logger.handlers) == 1
32+
assert logger.handlers[0] is second_logger.handlers[0]
33+
assert logger is second_logger

0 commit comments

Comments
 (0)