Skip to content

Commit 33abc0f

Browse files
committed
Refactor orchestrator module to add Literal type hinting for config_params argument of run_pipeline; update imports across multiple modules for consistency of importing only at module-level.
1 parent ec77e5c commit 33abc0f

6 files changed

Lines changed: 49 additions & 35 deletions

File tree

src/graphomotor/core/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Configuration module for Graphomotor."""
22

3+
import dataclasses
34
import logging
45
import warnings
5-
from dataclasses import dataclass
66

77
import numpy as np
88

99

10-
@dataclass
10+
@dataclasses.dataclass
1111
class SpiralConfig:
1212
"""Class for the parameters of anticipated spiral drawing."""
1313

src/graphomotor/core/models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Internal data class for spiral drawing data."""
22

3-
from datetime import datetime
4-
from typing import Callable
3+
import datetime
4+
import typing
55

66
import numpy as np
77
import pandas as pd
8-
from pydantic import BaseModel, ConfigDict, field_validator
8+
import pydantic
99

1010

11-
class Spiral(BaseModel):
11+
class Spiral(pydantic.BaseModel):
1212
"""Class representing a spiral drawing, encapsulating both raw data and metadata.
1313
1414
Attributes:
@@ -21,12 +21,12 @@ class Spiral(BaseModel):
2121
- start_time: Start time of drawing.
2222
"""
2323

24-
model_config = ConfigDict(arbitrary_types_allowed=True)
24+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
2525

2626
data: pd.DataFrame
27-
metadata: dict[str, str | datetime]
27+
metadata: dict[str, str | datetime.datetime]
2828

29-
@field_validator("data")
29+
@pydantic.field_validator("data")
3030
@classmethod
3131
def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
3232
"""Validate that DataFrame contains required columns and correct data types.
@@ -46,7 +46,7 @@ def validate_dataframe(cls, v: pd.DataFrame) -> pd.DataFrame:
4646

4747
return v
4848

49-
@field_validator("metadata")
49+
@pydantic.field_validator("metadata")
5050
@classmethod
5151
def validate_metadata(cls, v: dict) -> dict:
5252
"""Validate metadata dictionary for required keys and correct data types.
@@ -102,7 +102,7 @@ def all(cls) -> set[str]:
102102
@classmethod
103103
def get_extractors(
104104
cls, spiral: Spiral, reference_spiral: np.ndarray
105-
) -> dict[str, Callable[[], dict[str, float]]]:
105+
) -> dict[str, typing.Callable[[], dict[str, float]]]:
106106
"""Get all feature extractors with appropriate inputs.
107107
108108
Args:

src/graphomotor/core/orchestrator.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Runner for the Graphomotor pipeline."""
22

3+
import datetime
34
import os
45
import pathlib
5-
from datetime import datetime
6-
from typing import Literal
6+
import typing
77

88
import numpy as np
99
import pandas as pd
@@ -14,7 +14,7 @@
1414

1515
logger = config.get_logger()
1616

17-
FeatureCategories = Literal["duration", "velocity", "hausdorff", "AUC"]
17+
FeatureCategories = typing.Literal["duration", "velocity", "hausdorff", "AUC"]
1818

1919

2020
def _ensure_path(path: pathlib.Path | str) -> pathlib.Path:
@@ -122,7 +122,7 @@ def _export_features_to_csv(
122122

123123
filename = (
124124
f"{participant_id}_{task}_{hand}_features_"
125-
f"{datetime.today().strftime('%Y%m%d')}.csv"
125+
f"{datetime.datetime.today().strftime('%Y%m%d')}.csv"
126126
)
127127

128128
if not output_path.suffix:
@@ -220,7 +220,19 @@ def run_pipeline(
220220
"hausdorff",
221221
"AUC",
222222
],
223-
config_params: dict[str, float | int] | None = None,
223+
config_params: dict[
224+
typing.Literal[
225+
"center_x",
226+
"center_y",
227+
"start_radius",
228+
"growth_rate",
229+
"start_angle",
230+
"end_angle",
231+
"num_points",
232+
],
233+
float | int,
234+
]
235+
| None = None,
224236
) -> dict[str, str]:
225237
"""Run the Graphomotor pipeline to extract features from spiral drawings.
226238
@@ -238,14 +250,14 @@ def run_pipeline(
238250
- "AUC": Area under the curve metric
239251
config_params: Optional dictionary with custom spiral configuration parameters.
240252
These parameters control reference spiral generation and spiral centering.
241-
If None, default parameters are used. Supported parameters are:
242-
- "center_x": X-coordinate of the spiral center. Default is 50.
243-
- "center_y": Y-coordinate of the spiral center. Default is 50.
244-
- "start_radius": Starting radius of the spiral. Default is 0.
245-
- "growth_rate": Growth rate of the spiral. Default is 1.075.
246-
- "start_angle": Starting angle of the spiral. Default is 0.
247-
- "end_angle": Ending angle of the spiral. Default is 8π.
248-
- "num_points": Number of points in the spiral. Default is 10000.
253+
If None, default configuration is used. Supported parameters are:
254+
- "center_x" (float): X-coordinate of the spiral center. Default is 50.
255+
- "center_y" (float): Y-coordinate of the spiral center. Default is 50.
256+
- "start_radius" (float): Starting radius of the spiral. Default is 0.
257+
- "growth_rate" (float): Growth rate of the spiral. Default is 1.075.
258+
- "start_angle" (float): Starting angle of the spiral. Default is 0.
259+
- "end_angle" (float): Ending angle of the spiral. Default is 8π.
260+
- "num_points" (int): Number of points in the spiral. Default is 10000.
249261
250262
Returns:
251263
Dictionary of extracted features.
@@ -258,7 +270,9 @@ def run_pipeline(
258270
spiral_config = None
259271
if config_params:
260272
logger.info(f"Custom spiral configuration: {config_params}")
261-
spiral_config = config.SpiralConfig.add_custom_params(config_params)
273+
spiral_config = config.SpiralConfig.add_custom_params(
274+
typing.cast(dict, config_params)
275+
)
262276

263277
features = extract_features(
264278
input_path, output_path, feature_categories, spiral_config

src/graphomotor/utils/center_spiral.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Utility functions for centering a spiral."""
22

3-
from typing import overload
3+
import typing
44

55
import numpy as np
66

77
from graphomotor.core import config, models
88

99

10-
@overload
10+
@typing.overload
1111
def center_spiral(spiral: models.Spiral) -> models.Spiral: ...
12-
@overload
12+
@typing.overload
1313
def center_spiral(spiral: np.ndarray) -> np.ndarray: ...
1414
def center_spiral(spiral):
1515
"""Center a spiral by translating it to the origin.

tests/unit/test_center_spiral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Test cases for center_spiral.py functions."""
22

3-
from typing import cast
3+
import typing
44

55
import numpy as np
66
import pytest
@@ -42,4 +42,4 @@ def test_center_spiral_invalid_type(
4242
TypeError,
4343
match=f"Expected models.Spiral or np.ndarray, got {expected_type_name}",
4444
):
45-
center_spiral.center_spiral(cast(models.Spiral, invalid_input))
45+
center_spiral.center_spiral(typing.cast(models.Spiral, invalid_input))

tests/unit/test_orchestrator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Tests for the orchestrator module."""
22

3+
import datetime
34
import pathlib
4-
from datetime import datetime
5-
from typing import cast
5+
import typing
66

77
import numpy as np
88
import pytest
@@ -48,7 +48,7 @@ def test_validate_feature_categories_only_invalid(
4848

4949
def test_validate_feature_categories_mixed(caplog: pytest.LogCaptureFixture) -> None:
5050
"""Test _validate_feature_categories with mix of valid and invalid categories."""
51-
feature_categories = cast(
51+
feature_categories = typing.cast(
5252
list[orchestrator.FeatureCategories],
5353
[
5454
"duration",
@@ -136,7 +136,7 @@ def test_export_features_to_csv_no_extension_dir_exists(
136136

137137
expected_filename = (
138138
f"{valid_spiral.metadata['id']}_{valid_spiral.metadata['task']}_{valid_spiral.metadata['hand']}_"
139-
f"features_{datetime.today().strftime('%Y%m%d')}.csv"
139+
f"features_{datetime.datetime.today().strftime('%Y%m%d')}.csv"
140140
)
141141

142142
expected_output_path = output_path / expected_filename
@@ -162,7 +162,7 @@ def test_export_features_to_csv_no_extension_no_dir(
162162

163163
expected_filename = (
164164
f"{valid_spiral.metadata['id']}_{valid_spiral.metadata['task']}_{valid_spiral.metadata['hand']}_"
165-
f"features_{datetime.today().strftime('%Y%m%d')}.csv"
165+
f"features_{datetime.datetime.today().strftime('%Y%m%d')}.csv"
166166
)
167167

168168
expected_output_path = output_path / expected_filename

0 commit comments

Comments
 (0)