11"""Internal data class for spiral drawing data."""
22
3- from datetime import datetime
3+ import datetime
4+ import typing
45
6+ import numpy as np
57import pandas as pd
6- from pydantic import BaseModel , ConfigDict , field_validator
8+ import pydantic
79
810
9- class Spiral (BaseModel ):
10- """A class representing a spiral drawing, encapsulating both raw data and metadata.
11+ class Spiral (pydantic . BaseModel ):
12+ """Class representing a spiral drawing, encapsulating both raw data and metadata.
1113
1214 Attributes:
1315 data: DataFrame containing drawing data with required columns (line_number, x,
@@ -19,12 +21,12 @@ class Spiral(BaseModel):
1921 - start_time: Start time of drawing.
2022 """
2123
22- model_config = ConfigDict (arbitrary_types_allowed = True )
24+ model_config = pydantic . ConfigDict (arbitrary_types_allowed = True )
2325
2426 data : pd .DataFrame
25- metadata : dict [str , str | datetime ]
27+ metadata : dict [str , str | datetime . datetime ]
2628
27- @field_validator ("data" )
29+ @pydantic . field_validator ("data" )
2830 @classmethod
2931 def validate_dataframe (cls , v : pd .DataFrame ) -> pd .DataFrame :
3032 """Validate that DataFrame contains required columns and correct data types.
@@ -44,7 +46,7 @@ def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
4446
4547 return v
4648
47- @field_validator ("metadata" )
49+ @pydantic . field_validator ("metadata" )
4850 @classmethod
4951 def validate_metadata (cls , v : dict ) -> dict :
5052 """Validate metadata dictionary for required keys and correct data types.
@@ -77,3 +79,49 @@ def validate_metadata(cls, v: dict) -> dict:
7779 )
7880
7981 return v
82+
83+
84+ class FeatureCategories :
85+ """Class to hold valid feature categories for Graphomotor."""
86+
87+ DURATION = "duration"
88+ VELOCITY = "velocity"
89+ HAUSDORFF = "hausdorff"
90+ AUC = "AUC"
91+
92+ @classmethod
93+ def all (cls ) -> set [str ]:
94+ """Return all valid feature categories."""
95+ return {
96+ cls .DURATION ,
97+ cls .VELOCITY ,
98+ cls .HAUSDORFF ,
99+ cls .AUC ,
100+ }
101+
102+ @classmethod
103+ def get_extractors (
104+ cls , spiral : Spiral , reference_spiral : np .ndarray
105+ ) -> dict [str , typing .Callable [[], dict [str , float ]]]:
106+ """Get all feature extractors with appropriate inputs.
107+
108+ Args:
109+ spiral: The spiral data to extract features from.
110+ reference_spiral: Reference spiral for comparison-based metrics.
111+
112+ Returns:
113+ Dictionary mapping category names to their feature extractor functions.
114+ """
115+ # Importing feature modules here to avoid circular imports.
116+ from graphomotor .features import distance , drawing_error , time , velocity
117+
118+ return {
119+ cls .DURATION : lambda : time .get_task_duration (spiral ),
120+ cls .VELOCITY : lambda : velocity .calculate_velocity_metrics (spiral ),
121+ cls .HAUSDORFF : lambda : distance .calculate_hausdorff_metrics (
122+ spiral , reference_spiral
123+ ),
124+ cls .AUC : lambda : drawing_error .calculate_area_under_curve (
125+ spiral , reference_spiral
126+ ),
127+ }
0 commit comments