Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.12, 3.13]
python-version: [3.11, 3.12, 3.13]
dependency-mode: [lowest-direct, highest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# experiments
experiments/

# vscode
.vscode/

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
]
license = "LGPL-2.1"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.11"
dependencies = [
"pandas>=2.2.3",
"pydantic>=2.11.1",
Expand Down
45 changes: 36 additions & 9 deletions src/graphomotor/core/config.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,47 @@
"""Configuration module for Graphomotor."""

import dataclasses
import logging
import warnings

import numpy as np


class _SpiralConfig:
"""Configuration for the reference spiral generation."""
@dataclasses.dataclass
class SpiralConfig:
"""Class for the parameters of anticipated spiral drawing."""

SPIRAL_CENTER_X = 50
SPIRAL_CENTER_Y = 50
SPIRAL_START_RADIUS = 0
SPIRAL_GROWTH_RATE = 1.075
SPIRAL_START_ANGLE = 0
SPIRAL_END_ANGLE = 8 * np.pi
SPIRAL_NUM_POINTS = 10000
center_x: float = 50
center_y: float = 50
start_radius: float = 0
growth_rate: float = 1.075
start_angle: float = 0
end_angle: float = 8 * np.pi
num_points: int = 10000

@classmethod
def add_custom_params(cls, config_dict: dict[str, float | int]) -> "SpiralConfig":
"""Update the SpiralConfig instance with custom parameters.

Args:
config_dict: Dictionary with configuration parameters.

Returns:
SpiralConfig instance with updated parameters.
"""
config = cls()
for key, value in config_dict.items():
if hasattr(config, key):
setattr(config, key, value)
else:
valid_params = ", ".join(
f.name for f in cls.__dataclass_fields__.values()
)
warnings.warn(
f"Unknown configuration parameters will be ignored: {key}. "
f"Valid parameters are: {valid_params}"
)
return config


def get_logger() -> logging.Logger:
Expand Down
64 changes: 56 additions & 8 deletions src/graphomotor/core/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Internal data class for spiral drawing data."""

from datetime import datetime
import datetime
import typing

import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict, field_validator
import pydantic


class Spiral(BaseModel):
"""A class representing a spiral drawing, encapsulating both raw data and metadata.
class Spiral(pydantic.BaseModel):
"""Class representing a spiral drawing, encapsulating both raw data and metadata.

Attributes:
data: DataFrame containing drawing data with required columns (line_number, x,
Expand All @@ -19,12 +21,12 @@ class Spiral(BaseModel):
- start_time: Start time of drawing.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

data: pd.DataFrame
metadata: dict[str, str | datetime]
metadata: dict[str, str | datetime.datetime]

@field_validator("data")
@pydantic.field_validator("data")
@classmethod
def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
"""Validate that DataFrame contains required columns and correct data types.
Expand All @@ -44,7 +46,7 @@ def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:

return v

@field_validator("metadata")
@pydantic.field_validator("metadata")
@classmethod
def validate_metadata(cls, v: dict) -> dict:
"""Validate metadata dictionary for required keys and correct data types.
Expand Down Expand Up @@ -77,3 +79,49 @@ def validate_metadata(cls, v: dict) -> dict:
)

return v


class FeatureCategories:
"""Class to hold valid feature categories for Graphomotor."""

DURATION = "duration"
VELOCITY = "velocity"
HAUSDORFF = "hausdorff"
AUC = "AUC"

@classmethod
def all(cls) -> set[str]:
"""Return all valid feature categories."""
return {
cls.DURATION,
cls.VELOCITY,
cls.HAUSDORFF,
cls.AUC,
}

@classmethod
def get_extractors(
cls, spiral: Spiral, reference_spiral: np.ndarray
) -> dict[str, typing.Callable[[], dict[str, float]]]:
"""Get all feature extractors with appropriate inputs.

Args:
spiral: The spiral data to extract features from.
reference_spiral: Reference spiral for comparison-based metrics.

Returns:
Dictionary mapping category names to their feature extractor functions.
"""
# Importing feature modules here to avoid circular imports.
from graphomotor.features import distance, drawing_error, time, velocity

return {
cls.DURATION: lambda: time.get_task_duration(spiral),
cls.VELOCITY: lambda: velocity.calculate_velocity_metrics(spiral),
cls.HAUSDORFF: lambda: distance.calculate_hausdorff_metrics(
spiral, reference_spiral
),
cls.AUC: lambda: drawing_error.calculate_area_under_curve(
spiral, reference_spiral
),
}
Loading