Skip to content

Commit 2cee2aa

Browse files
authored
Merge pull request #27 from childmindresearch/alperkent/issue23
feature/issue-23/implement-orchestrator
2 parents 4af292f + a84ccf7 commit 2cee2aa

22 files changed

Lines changed: 6899 additions & 413 deletions

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
fail-fast: false
1313
matrix:
1414
os: [ubuntu-latest, windows-latest, macos-latest]
15-
python-version: [3.12, 3.13]
15+
python-version: [3.11, 3.12, 3.13]
1616
dependency-mode: [lowest-direct, highest]
1717
runs-on: ${{ matrix.os }}
1818
steps:

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# experiments
2+
experiments/
3+
14
# vscode
25
.vscode/
36

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = [
77
]
88
license = "LGPL-2.1"
99
readme = "README.md"
10-
requires-python = ">=3.12"
10+
requires-python = ">=3.11"
1111
dependencies = [
1212
"pandas>=2.2.3",
1313
"pydantic>=2.11.1",

src/graphomotor/core/config.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,47 @@
11
"""Configuration module for Graphomotor."""
22

3+
import dataclasses
34
import logging
5+
import warnings
46

57
import numpy as np
68

79

8-
class _SpiralConfig:
9-
"""Configuration for the reference spiral generation."""
10+
@dataclasses.dataclass
11+
class SpiralConfig:
12+
"""Class for the parameters of anticipated spiral drawing."""
1013

11-
SPIRAL_CENTER_X = 50
12-
SPIRAL_CENTER_Y = 50
13-
SPIRAL_START_RADIUS = 0
14-
SPIRAL_GROWTH_RATE = 1.075
15-
SPIRAL_START_ANGLE = 0
16-
SPIRAL_END_ANGLE = 8 * np.pi
17-
SPIRAL_NUM_POINTS = 10000
14+
center_x: float = 50
15+
center_y: float = 50
16+
start_radius: float = 0
17+
growth_rate: float = 1.075
18+
start_angle: float = 0
19+
end_angle: float = 8 * np.pi
20+
num_points: int = 10000
21+
22+
@classmethod
23+
def add_custom_params(cls, config_dict: dict[str, float | int]) -> "SpiralConfig":
24+
"""Update the SpiralConfig instance with custom parameters.
25+
26+
Args:
27+
config_dict: Dictionary with configuration parameters.
28+
29+
Returns:
30+
SpiralConfig instance with updated parameters.
31+
"""
32+
config = cls()
33+
for key, value in config_dict.items():
34+
if hasattr(config, key):
35+
setattr(config, key, value)
36+
else:
37+
valid_params = ", ".join(
38+
f.name for f in cls.__dataclass_fields__.values()
39+
)
40+
warnings.warn(
41+
f"Unknown configuration parameters will be ignored: {key}. "
42+
f"Valid parameters are: {valid_params}"
43+
)
44+
return config
1845

1946

2047
def get_logger() -> logging.Logger:

src/graphomotor/core/models.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Internal data class for spiral drawing data."""
22

3-
from datetime import datetime
3+
import datetime
4+
import typing
45

6+
import numpy as np
57
import pandas as pd
6-
from pydantic import BaseModel, ConfigDict, field_validator
8+
import pydantic
79

810

9-
class Spiral(BaseModel):
10-
"""A class representing a spiral drawing, encapsulating both raw data and metadata.
11+
class Spiral(pydantic.BaseModel):
12+
"""Class representing a spiral drawing, encapsulating both raw data and metadata.
1113
1214
Attributes:
1315
data: DataFrame containing drawing data with required columns (line_number, x,
@@ -19,12 +21,12 @@ class Spiral(BaseModel):
1921
- start_time: Start time of drawing.
2022
"""
2123

22-
model_config = ConfigDict(arbitrary_types_allowed=True)
24+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
2325

2426
data: pd.DataFrame
25-
metadata: dict[str, str | datetime]
27+
metadata: dict[str, str | datetime.datetime]
2628

27-
@field_validator("data")
29+
@pydantic.field_validator("data")
2830
@classmethod
2931
def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
3032
"""Validate that DataFrame contains required columns and correct data types.
@@ -44,7 +46,7 @@ def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
4446

4547
return v
4648

47-
@field_validator("metadata")
49+
@pydantic.field_validator("metadata")
4850
@classmethod
4951
def validate_metadata(cls, v: dict) -> dict:
5052
"""Validate metadata dictionary for required keys and correct data types.
@@ -77,3 +79,49 @@ def validate_metadata(cls, v: dict) -> dict:
7779
)
7880

7981
return v
82+
83+
84+
class FeatureCategories:
85+
"""Class to hold valid feature categories for Graphomotor."""
86+
87+
DURATION = "duration"
88+
VELOCITY = "velocity"
89+
HAUSDORFF = "hausdorff"
90+
AUC = "AUC"
91+
92+
@classmethod
93+
def all(cls) -> set[str]:
94+
"""Return all valid feature categories."""
95+
return {
96+
cls.DURATION,
97+
cls.VELOCITY,
98+
cls.HAUSDORFF,
99+
cls.AUC,
100+
}
101+
102+
@classmethod
103+
def get_extractors(
104+
cls, spiral: Spiral, reference_spiral: np.ndarray
105+
) -> dict[str, typing.Callable[[], dict[str, float]]]:
106+
"""Get all feature extractors with appropriate inputs.
107+
108+
Args:
109+
spiral: The spiral data to extract features from.
110+
reference_spiral: Reference spiral for comparison-based metrics.
111+
112+
Returns:
113+
Dictionary mapping category names to their feature extractor functions.
114+
"""
115+
# Importing feature modules here to avoid circular imports.
116+
from graphomotor.features import distance, drawing_error, time, velocity
117+
118+
return {
119+
cls.DURATION: lambda: time.get_task_duration(spiral),
120+
cls.VELOCITY: lambda: velocity.calculate_velocity_metrics(spiral),
121+
cls.HAUSDORFF: lambda: distance.calculate_hausdorff_metrics(
122+
spiral, reference_spiral
123+
),
124+
cls.AUC: lambda: drawing_error.calculate_area_under_curve(
125+
spiral, reference_spiral
126+
),
127+
}

0 commit comments

Comments
 (0)