Skip to content

Commit 94675c2

Browse files
authored
Refactor/issue 53/refactor spiral specific to generic (#54)
* Initial changes * Update pyproject.toml * More fixes * Typos * Update test_models.py * Change to task specific orchestrator * Update reader.py Adding all possible columns * Refactor reader for any drawing columns * Lots of changes * Update reader.py mypy fix * Create test_io_utils.py Added tests for missing columns * Renaming for deidentification * Renaming and fixing tests * Addressing PR comments * Update spiral_orchestrator.py
1 parent 8a009f2 commit 94675c2

31 files changed

Lines changed: 2386 additions & 271 deletions

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
[project]
22
name = "graphomotor"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "A Python toolkit for analysis of graphomotor data collected via Curious"
55
authors = [
66
{name = "Alp Erkent", email = "alp.erkent@childmind.org"},
7+
{name = "Celia Maiorano", email = "celia.maiorano@childmind.org"},
8+
{name = "Iktae Kim", email = "iktae.kim@childmind.org"},
79
{name = "Adam Santorelli", email = "adam.santorelli@childmind.org"}
810
]
911
license = "LGPL-2.1"

src/graphomotor/core/cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import typer
88

9-
from graphomotor.core import config, orchestrator
9+
from graphomotor.core import config, spiral_orchestrator
1010
from graphomotor.plot import feature_plots, spiral_plots
1111

1212
logger = config.get_logger()
@@ -85,7 +85,7 @@ def main(
8585

8686

8787
@app.command(
88-
name="extract",
88+
name="extract-spiral",
8989
help=(
9090
"Extract features from spiral drawing data. "
9191
"Supports both single-file and batch (directory) processing."
@@ -95,7 +95,7 @@ def main(
9595
"https://github.com/childmindresearch/graphomotor?tab=readme-ov-file#feature-extraction."
9696
),
9797
)
98-
def extract(
98+
def extract_spiral(
9999
input_path: typing.Annotated[
100100
pathlib.Path,
101101
typer.Argument(
@@ -196,7 +196,7 @@ def extract(
196196
"""Extract features from spiral drawing data."""
197197
logger.debug(f"Running Graphomotor pipeline with these arguments: {locals()}")
198198

199-
config_params: dict[orchestrator.ConfigParams, float | int] = {
199+
config_params: dict[spiral_orchestrator.ConfigParams, float | int] = {
200200
"center_x": center_x,
201201
"center_y": center_y,
202202
"start_radius": start_radius,
@@ -207,11 +207,11 @@ def extract(
207207
}
208208

209209
try:
210-
orchestrator.run_pipeline(
210+
spiral_orchestrator.run_pipeline(
211211
input_path=input_path,
212212
output_path=output_path,
213213
feature_categories=typing.cast(
214-
list[orchestrator.FeatureCategories], features
214+
list[spiral_orchestrator.FeatureCategories], features
215215
),
216216
config_params=config_params,
217217
)

src/graphomotor/core/models.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
import pydantic
99

1010

11-
class Spiral(pydantic.BaseModel):
12-
"""Class representing a spiral drawing, encapsulating both raw data and metadata.
11+
class Drawing(pydantic.BaseModel):
12+
"""Class representing a drawing task, encapsulating both raw data and metadata.
1313
1414
Attributes:
1515
data: DataFrame containing drawing data with required columns (line_number, x,
1616
y, UTC_Timestamp, seconds).
17+
task_name: Name of the drawing task (e.g., 'spiral', 'trails', etc.).
1718
metadata: Dictionary containing metadata about the spiral:
1819
- id: Unique identifier for the participant,
1920
- hand: Hand used ('Dom' for dominant, 'NonDom' for non-dominant),
@@ -25,6 +26,7 @@ class Spiral(pydantic.BaseModel):
2526
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
2627

2728
data: pd.DataFrame
29+
task_name: str
2830
metadata: dict[str, str | datetime.datetime]
2931

3032
@pydantic.field_validator("data")
@@ -67,22 +69,10 @@ def validate_metadata(cls, v: dict) -> dict:
6769
if len(v["id"]) != 7:
6870
raise ValueError("'id' must be 7 digits long")
6971

70-
if v["hand"] not in ["Dom", "NonDom"]:
71-
raise ValueError("'hand' must be either 'Dom' or 'NonDom'")
72-
73-
valid_tasks = ["spiral_trace", "spiral_recall"]
74-
valid_tasks_trials = [
75-
f"{prefix}{i}" for prefix in valid_tasks for i in range(1, 6)
76-
]
77-
if v["task"] not in valid_tasks_trials:
78-
raise ValueError(
79-
"'task' must be either 'spiral_trace' or 'spiral_recall', numbered 1-5"
80-
)
81-
8272
return v
8373

8474

85-
class FeatureCategories:
75+
class SpiralFeatureCategories:
8676
"""Class to hold valid feature categories for Graphomotor."""
8777

8878
DURATION = "duration"
@@ -102,7 +92,7 @@ def all(cls) -> set[str]:
10292

10393
@classmethod
10494
def get_extractors(
105-
cls, spiral: Spiral, reference_spiral: np.ndarray
95+
cls, spiral: Drawing, reference_spiral: np.ndarray
10696
) -> dict[str, typing.Callable[[], dict[str, float]]]:
10797
"""Get all feature extractors with appropriate inputs.
10898

src/graphomotor/core/orchestrator.py renamed to src/graphomotor/core/spiral_orchestrator.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
"""Runner for graphomotor."""
1+
"""Orchestrator for the Graphomotor spiral feature extraction pipeline."""
22

33
import dataclasses
44
import datetime
55
import pathlib
6+
import re
67
import time
78
import typing
89

@@ -28,6 +29,64 @@
2829
]
2930

3031

32+
def _parse_spiral_metadata(filename: str) -> dict[str, str]:
    """Extract participant, task and hand metadata from a spiral filename.

    Curious exports of spiral drawings are typically named like
    '[5123456]curious-ID-spiral_trace2_NonDom'. The name is matched against a
    regular expression that captures:

    - the digits inside the leading square brackets (participant ID),
    - the two underscore-separated components after the last usable hyphen
      (task name and task detail, e.g. 'spiral' and 'trace2'),
    - the trailing run of word characters (hand used).

    Only the shape of the name is checked here; the extracted values themselves
    are not validated. Because the final component must consist of word
    characters up to the end of the string, a name that still carries a file
    extension such as '.csv' does not match, so pass the stem.

    Args:
        filename: Filename (without extension) of the spiral drawing CSV file
            from a Curious export.

    Returns:
        Dictionary containing the extracted metadata:
            - id: Participant ID (e.g., '5123456')
            - hand: Hand used for drawing (e.g., 'Dom' or 'NonDom')
            - task: Task name and trial number (e.g., 'spiral_trace2')

    Raises:
        ValueError: If filename does not match the expected pattern.
    """
    pattern = r"\[(\d+)\].*-([^_]+)_([^_]+)_(\w+)$"
    match = re.match(pattern, filename)

    # Guard clause: fail fast so the success path below stays flat.
    if match is None:
        raise ValueError(f"Filename does not match expected pattern: {filename}")

    # 'participant_id' rather than 'id' to avoid shadowing the builtin.
    participant_id, task_name, task_detail, hand = match.groups()
    return {
        "id": participant_id,
        "hand": hand,
        "task": f"{task_name}_{task_detail}",
    }
68+
69+
70+
def _validate_spiral_metadata(metadata: dict[str, str | datetime.datetime]) -> None:
71+
"""Validate metadata extracted from spiral drawing filename.
72+
73+
Args:
74+
metadata: Dictionary containing extracted metadata.
75+
76+
Raises:
77+
ValueError: If metadata is invalid.
78+
"""
79+
if metadata["hand"] not in ["Dom", "NonDom"]:
80+
raise ValueError("'hand' must be either 'Dom' or 'NonDom'")
81+
82+
valid_tasks = ["spiral_trace", "spiral_recall"]
83+
valid_tasks_trials = [f"{prefix}{i}" for prefix in valid_tasks for i in range(1, 6)]
84+
if metadata["task"] not in valid_tasks_trials:
85+
raise ValueError(
86+
"'task' must be either 'spiral_trace' or 'spiral_recall', numbered 1-5"
87+
)
88+
89+
3190
def _validate_feature_categories(
3291
feature_categories: list[FeatureCategories],
3392
) -> set[str]:
@@ -43,7 +102,7 @@ def _validate_feature_categories(
43102
ValueError: If no valid feature categories are provided.
44103
"""
45104
feature_categories_set: set[str] = set(feature_categories)
46-
supported_categories_set = models.FeatureCategories.all()
105+
supported_categories_set = models.SpiralFeatureCategories.all()
47106
unknown_categories = feature_categories_set - supported_categories_set
48107
valid_requested_categories = feature_categories_set & supported_categories_set
49108

@@ -65,7 +124,7 @@ def _validate_feature_categories(
65124

66125

67126
def extract_features(
68-
spiral: models.Spiral,
127+
spiral: models.Drawing,
69128
feature_categories: list[str],
70129
reference_spiral: np.ndarray,
71130
) -> dict[str, str]:
@@ -83,7 +142,7 @@ def extract_features(
83142
Returns:
84143
Dictionary containing the extracted features with metadata.
85144
"""
86-
feature_extractors = models.FeatureCategories.get_extractors(
145+
feature_extractors = models.SpiralFeatureCategories.get_extractors(
87146
spiral, reference_spiral
88147
)
89148

@@ -168,7 +227,9 @@ def _run_file(
168227
Returns:
169228
Dictionary containing the extracted features with metadata.
170229
"""
171-
spiral = reader.load_spiral(input_path)
230+
spiral = reader.load_drawing_data(input_path)
231+
spiral.metadata.update(_parse_spiral_metadata(input_path.stem))
232+
_validate_spiral_metadata(spiral.metadata)
172233
centered_spiral = center_spiral.center_spiral(spiral)
173234
reference_spiral = generate_reference_spiral.generate_reference_spiral(
174235
spiral_config
@@ -326,7 +387,7 @@ def run_pipeline(
326387
valid_categories = sorted(_validate_feature_categories(feature_categories))
327388
logger.debug(f"Requested feature categories: {valid_categories}")
328389
else:
329-
valid_categories = [*models.FeatureCategories.all()]
390+
valid_categories = [*models.SpiralFeatureCategories.all()]
330391
logger.debug(f"Using default feature categories: {valid_categories}")
331392

332393
if config_params and config_params != dataclasses.asdict(config.SpiralConfig()):

src/graphomotor/features/distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _segment_data(data: np.ndarray, start_prop: float, end_prop: float) -> np.nd
3030

3131

3232
def calculate_hausdorff_metrics(
33-
spiral: models.Spiral, reference_spiral: np.ndarray
33+
spiral: models.Drawing, reference_spiral: np.ndarray
3434
) -> dict[str, float]:
3535
"""Calculate Hausdorff distance metrics for a spiral object.
3636

src/graphomotor/features/drawing_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def calculate_area_under_curve(
10-
drawn_spiral: models.Spiral, reference_spiral: np.ndarray
10+
drawn_spiral: models.Drawing, reference_spiral: np.ndarray
1111
) -> dict[str, float]:
1212
"""Calculate the area between drawn and reference spirals.
1313

src/graphomotor/features/time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from graphomotor.core import models
44

55

6-
def get_task_duration(spiral: models.Spiral) -> dict[str, float]:
6+
def get_task_duration(spiral: models.Drawing) -> dict[str, float]:
77
"""Calculate the total duration of a spiral drawing task.
88
99
Args:

src/graphomotor/features/velocity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _calculate_statistics(values: np.ndarray, name: str) -> dict[str, float]:
2626
}
2727

2828

29-
def calculate_velocity_metrics(spiral: models.Spiral) -> dict[str, float]:
29+
def calculate_velocity_metrics(spiral: models.Drawing) -> dict[str, float]:
3030
"""Calculate velocity-based metrics from spiral drawing data.
3131
3232
This function computes three types of velocity metrics by calculating the difference

src/graphomotor/io/io_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""This module contains utility functions for reading data files."""
2+
3+
import pandas as pd
4+
5+
# Column name -> dtype name for drawing data files (presumably consumed as the
# dtype mapping when the CSV is read; confirm against the reader).
# NOTE: key order is significant. _check_missing_columns treats the first six
# entries as the columns required of every task and the remaining entries as
# required only for trail tasks, so new columns must be added with care.
DTYPE_MAP = {
    "line_number": "int",
    "x": "float",
    "y": "float",
    "UTC_Timestamp": "float",
    "seconds": "float",
    "epoch_time_in_seconds_start": "float",
    "error": "str",
    "correct_path": "str",
    "actual_path": "str",
    "total_time": "float",
    "total_number_of_errors": "int",
}


def _check_missing_columns(data: pd.DataFrame, task_name: str) -> None:
    """Check that the DataFrame has every column required for the task.

    Tasks whose name contains 'trail' (case-insensitive) require every column
    listed in DTYPE_MAP; all other tasks require only the first six.

    Args:
        data: DataFrame containing drawing data.
        task_name: Name of the drawing task.

    Raises:
        KeyError: If any required columns are missing. The message lists the
            missing columns in DTYPE_MAP order.
    """
    all_columns = list(DTYPE_MAP)
    if "trail" in task_name.lower():
        required_columns = all_columns
    else:
        # The first six DTYPE_MAP entries are the task-independent columns.
        required_columns = all_columns[:6]

    # Build an ordered list rather than a set difference so the error message
    # is identical from run to run.
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        raise KeyError(f"Missing required columns: {', '.join(missing_columns)}")

0 commit comments

Comments
 (0)