Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install the project
run: uv sync --only-dev --resolution=${{ matrix.dependency-mode }}
run: uv sync --resolution=${{ matrix.dependency-mode }}
- name: Run tests
id: run-tests
shell: bash
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# copilot
.github/**/*.md

# data
data/

Expand Down
56 changes: 30 additions & 26 deletions src/graphomotor/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@

import dataclasses
import logging
import warnings
from typing import Any

import numpy as np


@dataclasses.dataclass
def get_logger() -> logging.Logger:
"""Get the Graphomotor logger."""
logger = logging.getLogger("graphomotor")
if logger.handlers:
return logger
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - "
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger


logger = get_logger()


@dataclasses.dataclass(frozen=True)
class SpiralConfig:
"""Class for the parameters of anticipated spiral drawing."""

Expand All @@ -29,32 +48,17 @@ def add_custom_params(cls, config_dict: dict[str, float | int]) -> "SpiralConfig
Returns:
SpiralConfig instance with updated parameters.
"""
config = cls()
valid_params = {f.name for f in cls.__dataclass_fields__.values()}
filtered_params: dict[str, Any] = {}

for key, value in config_dict.items():
if hasattr(config, key):
setattr(config, key, value)
if key in valid_params:
filtered_params[key] = value
else:
valid_params = ", ".join(
f.name for f in cls.__dataclass_fields__.values()
)
warnings.warn(
valid_param_names = ", ".join(valid_params)
logger.warning(
f"Unknown configuration parameters will be ignored: {key}. "
f"Valid parameters are: {valid_params}"
f"Valid parameters are: {valid_param_names}"
)
return config


def get_logger() -> logging.Logger:
"""Get the Graphomotor logger."""
logger = logging.getLogger("graphomotor")
if logger.handlers:
return logger
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - "
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
return cls(**filtered_params)
2 changes: 1 addition & 1 deletion src/graphomotor/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def extract_features(
logger.debug("Generating reference spiral to calculate features")
config_to_use = spiral_config or config.SpiralConfig()
reference_spiral = generate_reference_spiral.generate_reference_spiral(
config=config_to_use
spiral_config=config_to_use
)
centered_reference_spiral = center_spiral.center_spiral(reference_spiral)

Expand Down
16 changes: 10 additions & 6 deletions src/graphomotor/utils/center_spiral.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def center_spiral(spiral):
spiral_config = config.SpiralConfig()

if isinstance(spiral, models.Spiral):
spiral.data["x"] -= spiral_config.center_x
spiral.data["y"] -= spiral_config.center_y
return spiral
centered_spiral = models.Spiral(
data=spiral.data.copy(), metadata=spiral.metadata.copy()
)
centered_spiral.data["x"] -= spiral_config.center_x
centered_spiral.data["y"] -= spiral_config.center_y
return centered_spiral
elif isinstance(spiral, np.ndarray):
spiral[:, 0] -= spiral_config.center_x
spiral[:, 1] -= spiral_config.center_y
return spiral
centered_spiral = spiral.copy()
centered_spiral[:, 0] -= spiral_config.center_x
centered_spiral[:, 1] -= spiral_config.center_y
return centered_spiral
else:
raise TypeError(
f"Expected models.Spiral or np.ndarray, got {type(spiral).__name__}"
Expand Down
91 changes: 58 additions & 33 deletions src/graphomotor/utils/generate_reference_spiral.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,82 @@
"""Utility functions for generating an equidistant reference spiral."""

import functools

import numpy as np
from scipy import integrate, optimize

from graphomotor.core.config import SpiralConfig
from graphomotor.core import config


def _arc_length_integrand(t: float, config: SpiralConfig) -> float:
def _arc_length_integrand(t: float, spiral_config: config.SpiralConfig) -> float:
"""Calculate the differential arc length at angle t for an Archimedean spiral.

Args:
t: Angle parameter.
config: Spiral configuration.
spiral_config: Spiral configuration.

Returns:
Differential arc length value.
"""
r_t = config.start_radius + config.growth_rate * t
return np.sqrt(r_t**2 + config.growth_rate**2)
r_t = spiral_config.start_radius + spiral_config.growth_rate * t
return np.sqrt(r_t**2 + spiral_config.growth_rate**2)


def _calculate_arc_length(theta: float, config: SpiralConfig) -> float:
"""Calculate the arc length of the spiral from start_angle to theta.
def _calculate_arc_length_between(
theta_start: float, theta_end: float, spiral_config: config.SpiralConfig
) -> float:
"""Calculate the arc length of the spiral between two theta values.

Args:
theta: The angle in radians.
config: Spiral configuration.
theta_start: Starting angle in radians.
theta_end: Ending angle in radians.
spiral_config: Spiral configuration.

Returns:
The arc length of the spiral from start_angle to theta.
The arc length of the spiral from theta_start to theta_end.
"""
return integrate.quad(
lambda t: _arc_length_integrand(t, config), config.start_angle, theta
lambda t: _arc_length_integrand(t, spiral_config),
theta_start,
theta_end,
)[0]


def _find_theta_for_arc_length(target_arc_length: float, config: SpiralConfig) -> float:
"""Find the theta value for a given arc length using numerical root finding.
def _find_theta_for_incremental_arc_length(
target_increment: float,
current_theta: float,
spiral_config: config.SpiralConfig,
) -> float:
"""Find the theta value for a given incremental arc length from current position.

Args:
target_arc_length: Target arc length.
config: Spiral configuration.
target_increment: Target arc length increment from current position.
current_theta: Current theta position.
spiral_config: Spiral configuration.

Returns:
Angle theta corresponding to the arc length.
Angle theta that results in the target arc length increment from current_theta.
"""
solution = optimize.root_scalar(
lambda theta: _calculate_arc_length(theta, config) - target_arc_length,
bracket=[config.start_angle, config.end_angle],
lambda theta: _calculate_arc_length_between(current_theta, theta, spiral_config)
- target_increment,
bracket=(current_theta, spiral_config.end_angle),
)
return solution.root


def generate_reference_spiral(config: SpiralConfig) -> np.ndarray:
@functools.lru_cache(maxsize=48)
def generate_reference_spiral(spiral_config: config.SpiralConfig) -> np.ndarray:
"""Generate a reference spiral with equidistant points along its arc length.

This function creates an Archimedean spiral with points distributed at equal arc
length intervals. The generated spiral serves as a standardized reference template
for feature extraction algorithms that compare user-drawn spirals with an ideal
form.

This function is decorated with an LRU cache to store pre-computed spirals for
faster retrieval on subsequent calls with the same configuration.

The algorithm works by:
1. Computing the total arc length for the entire spiral,
2. Creating equidistant target arc length values,
Expand All @@ -74,28 +91,36 @@ def generate_reference_spiral(config: SpiralConfig) -> np.ndarray:
- Cartesian coordinates: x = cx + r·cos(θ), y = cy + r·sin(θ)

Parameters are defined in the SpiralConfig class:
- Center coordinates: (cx, cy) = (config.center_x, config.center_y)
- Start radius: a = config.start_radius
- Growth rate: b = config.growth_rate
- Total rotation: θ = config.end_angle - config.start_angle
- Number of points: N = config.num_points
- Center coordinates: cx, cy = spiral_config.center_x, spiral_config.center_y
Comment thread
alperkent marked this conversation as resolved.
- Start radius: a = spiral_config.start_radius
- Growth rate: b = spiral_config.growth_rate
- Total rotation: θ = spiral_config.end_angle - spiral_config.start_angle
- Number of points: N = spiral_config.num_points

Args:
config: Configuration parameters for the spiral.
spiral_config: Configuration parameters for the spiral.

Returns:
Array with shape (N, 2) containing Cartesian coordinates of the spiral points.
"""
total_arc_length = _calculate_arc_length(config.end_angle, config)
total_arc_length = _calculate_arc_length_between(
spiral_config.start_angle, spiral_config.end_angle, spiral_config
)

arc_length_values = np.linspace(0, total_arc_length, config.num_points)
arc_length_increment = total_arc_length / (spiral_config.num_points - 1)

theta_values = np.array(
[_find_theta_for_arc_length(s, config) for s in arc_length_values]
)
theta_values = np.zeros(spiral_config.num_points)
Comment thread
alperkent marked this conversation as resolved.
theta_values[0] = spiral_config.start_angle

for i in range(1, spiral_config.num_points):
theta_values[i] = _find_theta_for_incremental_arc_length(
arc_length_increment,
theta_values[i - 1],
spiral_config,
)

r_values = config.start_radius + config.growth_rate * theta_values
x_values = config.center_x + r_values * np.cos(theta_values)
y_values = config.center_y + r_values * np.sin(theta_values)
r_values = spiral_config.start_radius + spiral_config.growth_rate * theta_values
x_values = spiral_config.center_x + r_values * np.cos(theta_values)
y_values = spiral_config.center_y + r_values * np.sin(theta_values)

return np.column_stack((x_values, y_values))
10 changes: 5 additions & 5 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
def test_spiral_config_add_custom_params_valid(
custom_params: dict[str, int | float],
expected_params: dict[str, int | float],
recwarn: pytest.WarningsRecorder,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that SpiralConfig.add_custom_params correctly sets parameter values."""
spiral_config = config.SpiralConfig.add_custom_params(custom_params)

for key, value in expected_params.items():
assert getattr(spiral_config, key) == value
assert len(recwarn) == 0
assert len(caplog.records) == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -71,17 +71,17 @@ def test_spiral_config_add_custom_params_warnings(
custom_params: dict[str, int | float],
expected_params: dict[str, int | float],
expected_warnings: list[str],
recwarn: pytest.WarningsRecorder,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that SpiralConfig.add_custom_params issues warnings appropriately."""
spiral_config = config.SpiralConfig.add_custom_params(custom_params)

assert len(recwarn) == len(expected_warnings)
assert len(caplog.records) == len(expected_warnings)
for key, value in expected_params.items():
assert getattr(spiral_config, key) == value
for i, param in enumerate(expected_warnings):
assert f"Unknown configuration parameters will be ignored: {param}" in str(
recwarn[i].message
caplog.records[i].message
)


Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_generate_reference_spiral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from graphomotor.utils import generate_reference_spiral


def test_generate_reference_spiral() -> None:
def test_compute_reference_spiral() -> None:
"""Test the generation of a reference spiral."""
Comment thread
alperkent marked this conversation as resolved.
spiral_config = config.SpiralConfig()
expected_mean_arc_length = generate_reference_spiral._calculate_arc_length(
spiral_config.end_angle, spiral_config
expected_mean_arc_length = generate_reference_spiral._calculate_arc_length_between(
spiral_config.start_angle, spiral_config.end_angle, spiral_config
) / (spiral_config.num_points - 1)

spiral = generate_reference_spiral.generate_reference_spiral(spiral_config)
Expand Down
Loading