Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install the project
run: uv sync --only-dev --resolution=${{ matrix.dependency-mode }}
run: uv sync --resolution=${{ matrix.dependency-mode }}
- name: Run tests
id: run-tests
shell: bash
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# copilot
.github/**/*.md

# data
data/

Expand Down
56 changes: 30 additions & 26 deletions src/graphomotor/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,31 @@

import dataclasses
import logging
import warnings
from typing import Any

import numpy as np


@dataclasses.dataclass
def get_logger() -> logging.Logger:
"""Get the Graphomotor logger."""
logger = logging.getLogger("graphomotor")
if logger.handlers:
return logger
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - "
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger


logger = get_logger()


@dataclasses.dataclass(frozen=True)
class SpiralConfig:
"""Class for the parameters of anticipated spiral drawing."""

Expand All @@ -29,32 +48,17 @@ def add_custom_params(cls, config_dict: dict[str, float | int]) -> "SpiralConfig
Returns:
SpiralConfig instance with updated parameters.
"""
config = cls()
valid_params = {f.name for f in cls.__dataclass_fields__.values()}
filtered_params: dict[str, Any] = {}

for key, value in config_dict.items():
if hasattr(config, key):
setattr(config, key, value)
if key in valid_params:
filtered_params[key] = value
else:
valid_params = ", ".join(
f.name for f in cls.__dataclass_fields__.values()
)
warnings.warn(
valid_param_names = ", ".join(valid_params)
logger.warning(
f"Unknown configuration parameters will be ignored: {key}. "
f"Valid parameters are: {valid_params}"
f"Valid parameters are: {valid_param_names}"
)
return config


def get_logger() -> logging.Logger:
"""Get the Graphomotor logger."""
logger = logging.getLogger("graphomotor")
if logger.handlers:
return logger
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - "
"%(filename)s:%(lineno)s - %(funcName)s - %(message)s",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
return cls(**filtered_params)
2 changes: 1 addition & 1 deletion src/graphomotor/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def extract_features(
logger.debug("Generating reference spiral to calculate features")
config_to_use = spiral_config or config.SpiralConfig()
reference_spiral = generate_reference_spiral.generate_reference_spiral(
config=config_to_use
spiral_config=config_to_use
)
centered_reference_spiral = center_spiral.center_spiral(reference_spiral)

Expand Down
16 changes: 10 additions & 6 deletions src/graphomotor/utils/center_spiral.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def center_spiral(spiral):
spiral_config = config.SpiralConfig()

if isinstance(spiral, models.Spiral):
spiral.data["x"] -= spiral_config.center_x
spiral.data["y"] -= spiral_config.center_y
return spiral
centered_spiral = models.Spiral(
data=spiral.data.copy(), metadata=spiral.metadata.copy()
)
centered_spiral.data["x"] -= spiral_config.center_x
centered_spiral.data["y"] -= spiral_config.center_y
return centered_spiral
elif isinstance(spiral, np.ndarray):
spiral[:, 0] -= spiral_config.center_x
spiral[:, 1] -= spiral_config.center_y
return spiral
centered_spiral = spiral.copy()
centered_spiral[:, 0] -= spiral_config.center_x
centered_spiral[:, 1] -= spiral_config.center_y
return centered_spiral
else:
raise TypeError(
f"Expected models.Spiral or np.ndarray, got {type(spiral).__name__}"
Expand Down
91 changes: 58 additions & 33 deletions src/graphomotor/utils/generate_reference_spiral.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,82 @@
"""Utility functions for generating an equidistant reference spiral."""

import functools

import numpy as np
from scipy import integrate, optimize

from graphomotor.core.config import SpiralConfig
from graphomotor.core import config


def _arc_length_integrand(t: float, config: SpiralConfig) -> float:
def _arc_length_integrand(t: float, spiral_config: config.SpiralConfig) -> float:
"""Calculate the differential arc length at angle t for an Archimedean spiral.

Args:
t: Angle parameter.
config: Spiral configuration.
spiral_config: Spiral configuration.

Returns:
Differential arc length value.
"""
r_t = config.start_radius + config.growth_rate * t
return np.sqrt(r_t**2 + config.growth_rate**2)
r_t = spiral_config.start_radius + spiral_config.growth_rate * t
return np.sqrt(r_t**2 + spiral_config.growth_rate**2)


def _calculate_arc_length(theta: float, config: SpiralConfig) -> float:
"""Calculate the arc length of the spiral from start_angle to theta.
def _calculate_arc_length_between(
theta_start: float, theta_end: float, spiral_config: config.SpiralConfig
) -> float:
"""Calculate the arc length of the spiral between two theta values.

Args:
theta: The angle in radians.
config: Spiral configuration.
theta_start: Starting angle in radians.
theta_end: Ending angle in radians.
spiral_config: Spiral configuration.

Returns:
The arc length of the spiral from start_angle to theta.
The arc length of the spiral from theta_start to theta_end.
"""
return integrate.quad(
lambda t: _arc_length_integrand(t, config), config.start_angle, theta
lambda t: _arc_length_integrand(t, spiral_config),
theta_start,
theta_end,
)[0]


def _find_theta_for_arc_length(target_arc_length: float, config: SpiralConfig) -> float:
"""Find the theta value for a given arc length using numerical root finding.
def _find_theta_for_incremental_arc_length(
target_increment: float,
current_theta: float,
spiral_config: config.SpiralConfig,
) -> float:
"""Find the theta value for a given incremental arc length from current position.

Args:
target_arc_length: Target arc length.
config: Spiral configuration.
target_increment: Target arc length increment from current position.
current_theta: Current theta position.
spiral_config: Spiral configuration.

Returns:
Angle theta corresponding to the arc length.
Angle theta that results in the target arc length increment from current_theta.
"""
solution = optimize.root_scalar(
lambda theta: _calculate_arc_length(theta, config) - target_arc_length,
bracket=[config.start_angle, config.end_angle],
lambda theta: _calculate_arc_length_between(current_theta, theta, spiral_config)
- target_increment,
bracket=(current_theta, spiral_config.end_angle),
)
return solution.root


def generate_reference_spiral(config: SpiralConfig) -> np.ndarray:
@functools.lru_cache(maxsize=48)
def generate_reference_spiral(spiral_config: config.SpiralConfig) -> np.ndarray:
"""Generate a reference spiral with equidistant points along its arc length.

This function creates an Archimedean spiral with points distributed at equal arc
length intervals. The generated spiral serves as a standardized reference template
for feature extraction algorithms that compare user-drawn spirals with an ideal
form.

This function is decorated with an LRU cache to store pre-computed spirals for
faster retrieval on subsequent calls with the same configuration.

The algorithm works by:
1. Computing the total arc length for the entire spiral,
2. Creating equidistant target arc length values,
Expand All @@ -74,28 +91,36 @@ def generate_reference_spiral(config: SpiralConfig) -> np.ndarray:
- Cartesian coordinates: x = cx + r·cos(θ), y = cy + r·sin(θ)

Parameters are defined in the SpiralConfig class:
- Center coordinates: (cx, cy) = (config.center_x, config.center_y)
- Start radius: a = config.start_radius
- Growth rate: b = config.growth_rate
- Total rotation: θ = config.end_angle - config.start_angle
- Number of points: N = config.num_points
- Center coordinates: cx, cy = spiral_config.center_x, spiral_config.center_y
Comment thread
alperkent marked this conversation as resolved.
- Start radius: a = spiral_config.start_radius
- Growth rate: b = spiral_config.growth_rate
- Total rotation: θ = spiral_config.end_angle - spiral_config.start_angle
- Number of points: N = spiral_config.num_points

Args:
config: Configuration parameters for the spiral.
spiral_config: Configuration parameters for the spiral.

Returns:
Array with shape (N, 2) containing Cartesian coordinates of the spiral points.
"""
total_arc_length = _calculate_arc_length(config.end_angle, config)
total_arc_length = _calculate_arc_length_between(
spiral_config.start_angle, spiral_config.end_angle, spiral_config
)

arc_length_values = np.linspace(0, total_arc_length, config.num_points)
arc_length_increment = total_arc_length / (spiral_config.num_points - 1)

theta_values = np.array(
[_find_theta_for_arc_length(s, config) for s in arc_length_values]
)
theta_values = np.zeros(spiral_config.num_points)
Comment thread
alperkent marked this conversation as resolved.
theta_values[0] = spiral_config.start_angle

for i in range(1, spiral_config.num_points):
theta_values[i] = _find_theta_for_incremental_arc_length(
arc_length_increment,
theta_values[i - 1],
spiral_config,
)

r_values = config.start_radius + config.growth_rate * theta_values
x_values = config.center_x + r_values * np.cos(theta_values)
y_values = config.center_y + r_values * np.sin(theta_values)
r_values = spiral_config.start_radius + spiral_config.growth_rate * theta_values
x_values = spiral_config.center_x + r_values * np.cos(theta_values)
y_values = spiral_config.center_y + r_values * np.sin(theta_values)

return np.column_stack((x_values, y_values))
10 changes: 5 additions & 5 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
def test_spiral_config_add_custom_params_valid(
custom_params: dict[str, int | float],
expected_params: dict[str, int | float],
recwarn: pytest.WarningsRecorder,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that SpiralConfig.add_custom_params correctly sets parameter values."""
spiral_config = config.SpiralConfig.add_custom_params(custom_params)

for key, value in expected_params.items():
assert getattr(spiral_config, key) == value
assert len(recwarn) == 0
assert len(caplog.records) == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -71,17 +71,17 @@ def test_spiral_config_add_custom_params_warnings(
custom_params: dict[str, int | float],
expected_params: dict[str, int | float],
expected_warnings: list[str],
recwarn: pytest.WarningsRecorder,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that SpiralConfig.add_custom_params issues warnings appropriately."""
spiral_config = config.SpiralConfig.add_custom_params(custom_params)

assert len(recwarn) == len(expected_warnings)
assert len(caplog.records) == len(expected_warnings)
for key, value in expected_params.items():
assert getattr(spiral_config, key) == value
for i, param in enumerate(expected_warnings):
assert f"Unknown configuration parameters will be ignored: {param}" in str(
recwarn[i].message
caplog.records[i].message
)


Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_generate_reference_spiral.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
def test_generate_reference_spiral() -> None:
"""Test the generation of a reference spiral."""
Comment thread
alperkent marked this conversation as resolved.
spiral_config = config.SpiralConfig()
expected_mean_arc_length = generate_reference_spiral._calculate_arc_length(
spiral_config.end_angle, spiral_config
expected_mean_arc_length = generate_reference_spiral._calculate_arc_length_between(
spiral_config.start_angle, spiral_config.end_angle, spiral_config
) / (spiral_config.num_points - 1)

spiral = generate_reference_spiral.generate_reference_spiral(spiral_config)
Expand Down
Loading