Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.12, 3.13]
python-version: [3.11, 3.12, 3.13]
dependency-mode: [lowest-direct, highest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# experiments
experiments/

# vscode
.vscode/

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
]
license = "LGPL-2.1"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.11"
dependencies = [
"pandas>=2.2.3",
"pydantic>=2.11.1",
Expand Down
92 changes: 83 additions & 9 deletions src/graphomotor/core/config.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,94 @@
"""Configuration module for Graphomotor."""

import logging
import warnings
from dataclasses import dataclass
from typing import Callable

import numpy as np

from graphomotor.core import models
from graphomotor.features import distance, drawing_error, time, velocity

class _SpiralConfig:
    """Configuration for the reference spiral generation."""

    # NOTE(review): the SpiralConfig dataclass in this file carries the same
    # seven values as field defaults; presumably this constants-only class is
    # the version it replaces — confirm before relying on it.

    # Presumably the coordinates of the spiral's center — TODO confirm units.
    SPIRAL_CENTER_X = 50
    SPIRAL_CENTER_Y = 50
    # Radius at the start angle (0 here).
    SPIRAL_START_RADIUS = 0
    # NOTE(review): how the growth rate enters the radius formula is not
    # visible here; check the reference spiral generator.
    SPIRAL_GROWTH_RATE = 1.075
    # Angular range 0 .. 8*pi (presumably radians, i.e. four full turns).
    SPIRAL_START_ANGLE = 0
    SPIRAL_END_ANGLE = 8 * np.pi
    # Presumably the number of samples along the reference spiral.
    SPIRAL_NUM_POINTS = 10000
class FeatureCategories:
    """Namespace of the feature category names Graphomotor accepts.

    Each class attribute holds the string a caller passes to request that
    category of features.
    """

    DURATION = "duration"
    VELOCITY = "velocity"
    HAUSDORFF = "hausdorff"
    AUC = "AUC"

    @classmethod
    def all(cls) -> set[str]:
        """Return the set of every supported category name."""
        return set((cls.DURATION, cls.VELOCITY, cls.HAUSDORFF, cls.AUC))

    @classmethod
    def get_extractors(
        cls, spiral: models.Spiral, reference_spiral: np.ndarray
    ) -> dict[str, Callable[[], dict[str, float]]]:
        """Build a zero-argument extractor callable for each category.

        The callables close over the given inputs, so nothing is computed
        until a caller invokes the extractor for a category.

        Args:
            spiral: The spiral data to extract features from.
            reference_spiral: Reference spiral passed to the comparison-based
                extractors (Hausdorff and AUC).

        Returns:
            Dictionary mapping each category name to its extractor callable.
        """

        def _duration() -> dict[str, float]:
            return time.get_task_duration(spiral)

        def _velocity() -> dict[str, float]:
            return velocity.calculate_velocity_metrics(spiral)

        def _hausdorff() -> dict[str, float]:
            return distance.calculate_hausdorff_metrics(spiral, reference_spiral)

        def _auc() -> dict[str, float]:
            return drawing_error.calculate_area_under_curve(
                spiral, reference_spiral
            )

        extractors: dict[str, Callable[[], dict[str, float]]] = {}
        extractors[cls.DURATION] = _duration
        extractors[cls.VELOCITY] = _velocity
        extractors[cls.HAUSDORFF] = _hausdorff
        extractors[cls.AUC] = _auc
        return extractors


@dataclass
class SpiralConfig:
"""Class for the parameters of anticipated spiral drawing."""

center_x: float = 50
center_y: float = 50
start_radius: float = 0
growth_rate: float = 1.075
start_angle: float = 0
end_angle: float = 8 * np.pi
num_points: int = 10000

@classmethod
def add_custom_params(cls, config_dict: dict[str, float | int]) -> "SpiralConfig":
"""Update the SpiralConfig instance with custom parameters.

Args:
config_dict: Dictionary with configuration parameters.

Returns:
SpiralConfig instance with updated parameters.
"""
config = cls()
for key, value in config_dict.items():
if hasattr(config, key):
setattr(config, key, value)
else:
valid_params = ", ".join(
f.name for f in cls.__dataclass_fields__.values()
)
warnings.warn(
f"Unknown configuration parameters will be ignored: {key}. "
f"Valid parameters are: {valid_params}"
)
return config


def get_logger() -> logging.Logger:
Expand Down
2 changes: 1 addition & 1 deletion src/graphomotor/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class Spiral(BaseModel):
"""A class representing a spiral drawing, encapsulating both raw data and metadata.
"""Class representing a spiral drawing, encapsulating both raw data and metadata.

Attributes:
data: DataFrame containing drawing data with required columns (line_number, x,
Expand Down
251 changes: 251 additions & 0 deletions src/graphomotor/core/orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
"""Runner for the Graphomotor pipeline."""

import os
import pathlib
from datetime import datetime

import numpy as np
import pandas as pd

from graphomotor.core import config, models
from graphomotor.io import reader
from graphomotor.utils import center_spiral, generate_reference_spiral

# Module-level logger obtained from the package configuration and shared by
# the pipeline functions in this module.
logger = config.get_logger()


def _ensure_path(path: pathlib.Path | str) -> pathlib.Path:
"""Ensure that the input is a Path object.

Args:
path: Input path, can be string or Path

Returns:
Path object
"""
return pathlib.Path(path) if isinstance(path, str) else path


def _validate_feature_categories(feature_categories: list[str]) -> set[str]:
    """Split requested categories into supported and unknown ones.

    Unknown names are reported with a warning and dropped. If nothing
    supported remains, the request is rejected.

    Args:
        feature_categories: List of feature categories to validate.

    Returns:
        Set of the requested categories that are supported.

    Raises:
        ValueError: If none of the requested categories is supported.
    """
    supported = config.FeatureCategories.all()
    requested = set(feature_categories)
    unrecognized = requested - supported
    accepted = requested & supported

    if unrecognized:
        logger.warning(
            "Unknown feature categories requested, these categories will be ignored: "
            f"{unrecognized}"
        )

    if accepted:
        return accepted

    error_msg = (
        "No valid feature categories provided. "
        f"Supported categories: {supported}"
    )
    logger.error(error_msg)
    raise ValueError(error_msg)


def _get_feature_categories(
    spiral: models.Spiral,
    reference_spiral: np.ndarray,
    feature_categories: list[str],
) -> dict[str, float]:
    """Feature categories dispatcher.

    Validates the requested category names, runs the extractor for each valid
    category and merges the results into one dictionary. Categories are
    processed in the order they were requested (duplicates collapsed), so the
    key order of the result is the same on every run.

    Args:
        spiral: The spiral data to extract features from.
        reference_spiral: The reference spiral used for calculating features.
        feature_categories: List of feature categories to extract.

    Returns:
        Dictionary containing the extracted features.

    Raises:
        ValueError: If none of the requested categories is supported
            (propagated from the validation step).
    """
    valid_categories = _validate_feature_categories(feature_categories)

    # Iterating the validated set directly would make the extraction order
    # depend on string hashing, which varies between interpreter runs. Walk
    # the caller's list instead, de-duplicated while preserving order.
    ordered_categories = [
        category
        for category in dict.fromkeys(feature_categories)
        if category in valid_categories
    ]

    feature_extractors = config.FeatureCategories.get_extractors(
        spiral, reference_spiral
    )

    features: dict[str, float] = {}
    for category in ordered_categories:
        logger.debug(f"Extracting {category} features")
        category_features = feature_extractors[category]()
        features.update(category_features)
        logger.debug(f"{category.capitalize()} features extracted: {category_features}")

    return features


def _export_features_to_csv(
    spiral: models.Spiral,
    features: dict[str, str],
    input_path: pathlib.Path,
    output_path: pathlib.Path,
) -> None:
    """Export extracted features to a CSV file.

    The file has two columns (variable name, value) and no header or index.
    The metadata rows (participant_id, task, hand, source_file) come first,
    followed by one row per feature.

    ``output_path`` is treated as a directory when it already exists as a
    directory or has no file extension; an auto-generated filename
    ``{participant_id}_{task}_{hand}_features_{YYYYMMDD}.csv`` is then used
    inside it. Otherwise it is used as the exact output file. Missing
    directories are created and an existing file is overwritten.

    Args:
        spiral: The spiral data used for feature extraction.
        features: Dictionary containing the extracted features.
        input_path: Path to the input CSV file.
        output_path: Path to the output CSV file or output directory.

    Raises:
        Exception: Any error raised while writing the CSV is logged and
            re-raised unchanged.
    """
    logger.info(f"Saving extracted features to {output_path}")

    participant_id = spiral.metadata.get("id")
    task = spiral.metadata.get("task")
    hand = spiral.metadata.get("hand")

    # An existing directory is always a directory target, even when its name
    # contains a dot (e.g. "results.v2") and therefore has a non-empty suffix.
    if output_path.is_dir() or not output_path.suffix:
        if not output_path.exists():
            logger.info(f"Creating directory that doesn't exist: {output_path}")
            output_path.mkdir(parents=True, exist_ok=True)
        filename = (
            f"{participant_id}_{task}_{hand}_features_"
            f"{datetime.today().strftime('%Y%m%d')}.csv"
        )
        output_file = output_path / filename
    else:
        parent_dir = output_path.parent
        if not parent_dir.exists():
            logger.info(f"Creating parent directory that doesn't exist: {parent_dir}")
            parent_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_path

    if output_file.exists():
        logger.info(f"Overwriting existing file: {output_file}")

    metadata = {
        "participant_id": participant_id,
        "task": task,
        "hand": hand,
        "source_file": str(input_path),
    }

    features_df = pd.DataFrame(
        {
            "variable": [*metadata, *features],
            "value": [*metadata.values(), *features.values()],
        }
    )

    try:
        features_df.to_csv(output_file, index=False, header=False)
    except Exception as e:
        # Boundary handler: record the failure, then let the caller see it.
        logger.error(f"Failed to save features to {output_file}: {str(e)}")
        raise
    else:
        logger.debug(f"Features saved successfully to {output_file}")


def extract_features(
    input_path: pathlib.Path | str,
    output_path: pathlib.Path | str | None,
    feature_categories: list[str],
    spiral_config: config.SpiralConfig | None = None,
) -> dict[str, str]:
    """Extract features from spiral drawing data.

    Loads and centers the drawing, generates and centers a reference spiral,
    runs the requested feature extractors and optionally writes the result to
    a CSV file.

    Args:
        input_path: Path to the input CSV file containing spiral drawing data.
        output_path: Output file path or output directory for saving the
            extracted features. If None (or empty), features are not saved.
        feature_categories: List of feature categories to extract. Valid options are:
            - "duration": Extract task duration.
            - "velocity": Extract velocity-based metrics.
            - "hausdorff": Extract Hausdorff distance metrics.
            - "AUC": Extract area under the curve metric.
        spiral_config: Optional configuration for spiral parameters. If None, default
            parameters are used.

    Returns:
        Dictionary mapping feature names to their values, formatted as strings
        with 15 decimal places.
    """
    logger.debug(f"Loading spiral data from {input_path}")
    input_path = _ensure_path(input_path)
    spiral = reader.load_spiral(input_path)
    spiral = center_spiral.center_spiral(spiral)

    logger.debug("Generating reference spiral to calculate features")
    config_to_use = config.SpiralConfig() if spiral_config is None else spiral_config
    reference_spiral = generate_reference_spiral.generate_reference_spiral(
        config=config_to_use
    )
    # The reference is centered with the same helper as the drawing.
    reference_spiral = center_spiral.center_spiral(reference_spiral)

    features = _get_feature_categories(spiral, reference_spiral, feature_categories)
    logger.info(f"Feature extraction complete. Extracted {len(features)} features")

    formatted_features = {k: f"{v:.15f}" for k, v in features.items()}

    # Truthiness on purpose: both None and an empty string skip the export.
    if output_path:
        output_path = _ensure_path(output_path)
        _export_features_to_csv(spiral, formatted_features, input_path, output_path)

    return formatted_features


def run_pipeline(
    input_path: pathlib.Path | str,
    output_path: pathlib.Path | str | None,
    feature_categories: list[str],
    config_params: dict[str, float | int] | None = None,
) -> dict[str, str]:
    """Entry point of the Graphomotor pipeline for a single spiral drawing.

    Logs the run parameters, builds a spiral configuration from any custom
    parameters and delegates the work to ``extract_features``.

    Args:
        input_path: Path to the input CSV file containing spiral drawing data.
        output_path: Path where extracted features will be saved. Three behaviors are
            possible:
            - If None: Features are calculated but not saved to disk.
            - If path includes file extension (e.g., "/path/to/output.csv"): This exact
              file path is used.
            - If path has no file extension (e.g., "/path/to/output"): A file is
              created in that directory with auto-generated filename:
              "/path/to/output/{participant_id}_{task}_{hand}_features_{date}.csv".
        feature_categories: List of feature categories to extract. Options are:
            - "duration": Extract task duration.
            - "velocity": Extract velocity-based metrics.
            - "hausdorff": Extract Hausdorff distance metrics.
            - "AUC": Extract area under the curve metric.
        config_params: Optional configuration parameters for spiral drawing. If None
            (or empty), default parameters are used.

    Returns:
        Dictionary containing the extracted features.
    """
    logger.info("Starting Graphomotor pipeline")
    run_parameters = (
        ("Input path", input_path),
        ("Output path", output_path),
        ("Feature categories", feature_categories),
    )
    for label, value in run_parameters:
        logger.info(f"{label}: {value}")

    if config_params:
        logger.info(f"Custom spiral configuration: {config_params}")
        custom_config = config.SpiralConfig.add_custom_params(config_params)
    else:
        # No overrides: extract_features falls back to the default config.
        custom_config = None

    extracted = extract_features(
        input_path, output_path, feature_categories, custom_config
    )

    logger.info("Graphomotor pipeline completed successfully")
    return extracted
Loading