1- """Runner for graphomotor ."""
1+ """Orchestrator for the Graphomotor spiral feature extraction pipeline ."""
22
33import dataclasses
44import datetime
55import pathlib
6+ import re
67import time
78import typing
89
2829]
2930
3031
32+ def _parse_spiral_metadata (filename : str ) -> dict [str , str ]:
33+ """Extract metadata from spiral drawing filename.
34+
35+ The function parses filenames of Curious exports of drawing data that are
36+ typically formatted as '[5123456]curious-ID-spiral_trace2_NonDom'. It extracts
37+ the participant ID (the value within the brackets), task name ('spiral_trace' or
38+ 'spiral_recall', followed by the trial number from 1 to 5), and hand used (dominant
39+ or non-dominant). Regular expressions are used to match the expected pattern
40+ and extract the relevant components.
41+
42+
43+ Args:
44+ filename: Filename of the spiral drawing CSV file from Curious export.
45+
46+ Returns:
47+ Dictionary containing extracted metadata:
48+ - id: Participant ID (e.g., '5123456')
49+ - hand: Hand used for drawing ('Dom' or 'NonDom')
50+ - task: Task name and trial number (e.g., 'spiral_trace2')
51+
52+ Raises:
53+ ValueError: If filename does not match expected pattern.
54+ """
55+ pattern = r"\[(\d+)\].*-([^_]+)_([^_]+)_(\w+)$"
56+ match = re .match (pattern , filename )
57+
58+ if match :
59+ id , task_name , task_detail , hand = match .groups ()
60+ metadata = {
61+ "id" : id ,
62+ "hand" : hand ,
63+ "task" : f"{ task_name } _{ task_detail } " ,
64+ }
65+ return metadata
66+
67+ raise ValueError (f"Filename does not match expected pattern: { filename } " )
68+
69+
70+ def _validate_spiral_metadata (metadata : dict [str , str | datetime .datetime ]) -> None :
71+ """Validate metadata extracted from spiral drawing filename.
72+
73+ Args:
74+ metadata: Dictionary containing extracted metadata.
75+
76+ Raises:
77+ ValueError: If metadata is invalid.
78+ """
79+ if metadata ["hand" ] not in ["Dom" , "NonDom" ]:
80+ raise ValueError ("'hand' must be either 'Dom' or 'NonDom'" )
81+
82+ valid_tasks = ["spiral_trace" , "spiral_recall" ]
83+ valid_tasks_trials = [f"{ prefix } { i } " for prefix in valid_tasks for i in range (1 , 6 )]
84+ if metadata ["task" ] not in valid_tasks_trials :
85+ raise ValueError (
86+ "'task' must be either 'spiral_trace' or 'spiral_recall', numbered 1-5"
87+ )
88+
89+
3190def _validate_feature_categories (
3291 feature_categories : list [FeatureCategories ],
3392) -> set [str ]:
@@ -43,7 +102,7 @@ def _validate_feature_categories(
43102 ValueError: If no valid feature categories are provided.
44103 """
45104 feature_categories_set : set [str ] = set (feature_categories )
46- supported_categories_set = models .FeatureCategories .all ()
105+ supported_categories_set = models .SpiralFeatureCategories .all ()
47106 unknown_categories = feature_categories_set - supported_categories_set
48107 valid_requested_categories = feature_categories_set & supported_categories_set
49108
@@ -65,7 +124,7 @@ def _validate_feature_categories(
65124
66125
67126def extract_features (
68- spiral : models .Spiral ,
127+ spiral : models .Drawing ,
69128 feature_categories : list [str ],
70129 reference_spiral : np .ndarray ,
71130) -> dict [str , str ]:
@@ -83,7 +142,7 @@ def extract_features(
83142 Returns:
84143 Dictionary containing the extracted features with metadata.
85144 """
86- feature_extractors = models .FeatureCategories .get_extractors (
145+ feature_extractors = models .SpiralFeatureCategories .get_extractors (
87146 spiral , reference_spiral
88147 )
89148
@@ -168,7 +227,9 @@ def _run_file(
168227 Returns:
169228 Dictionary containing the extracted features with metadata.
170229 """
171- spiral = reader .load_spiral (input_path )
230+ spiral = reader .load_drawing_data (input_path )
231+ spiral .metadata .update (_parse_spiral_metadata (input_path .stem ))
232+ _validate_spiral_metadata (spiral .metadata )
172233 centered_spiral = center_spiral .center_spiral (spiral )
173234 reference_spiral = generate_reference_spiral .generate_reference_spiral (
174235 spiral_config
@@ -326,7 +387,7 @@ def run_pipeline(
326387 valid_categories = sorted (_validate_feature_categories (feature_categories ))
327388 logger .debug (f"Requested feature categories: { valid_categories } " )
328389 else :
329- valid_categories = [* models .FeatureCategories .all ()]
390+ valid_categories = [* models .SpiralFeatureCategories .all ()]
330391 logger .debug (f"Using default feature categories: { valid_categories } " )
331392
332393 if config_params and config_params != dataclasses .asdict (config .SpiralConfig ()):
0 commit comments