feature/issue-23/implement-orchestrator #27
Merged
Changes from 7 commits
11 commits
5daa951 Update Python version requirements to include 3.11 and add orchestrat… (alperkent)
7e6b9bf Refactor config.py to enable customizability of SpiralConfig and add … (alperkent-cmi)
70396dd Refactor center_spiral.py, orchestrator.py, and other various changes (alperkent-cmi)
349aa8c Add experiments directory to .gitignore (alperkent-cmi)
a19972d Refactor orchestrator functions to improve naming conventions and enh… (alperkent-cmi)
72c9a56 Fix type casting in test_center_spiral_invalid_type to ensure proper … (alperkent-cmi)
b2f52c3 Refactor test_export_features_to_csv to include exception handling (alperkent-cmi)
dac5c1b Move FeatureCategories class to models.py, update type hints in orche… (alperkent-cmi)
ec77e5c Refactor run_pipeline function to improve docstring clarity and set d… (alperkent-cmi)
33abc0f Refactor orchestrator module to add `Literal` type hinting for `confi… (alperkent-cmi)
a84ccf7 Refactor smoke test for `run_pipeline` to be more comprehensive (alperkent-cmi)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,6 @@ | ||
| # experiments | ||
| experiments/ | ||
|
|
||
| # vscode | ||
| .vscode/ | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,251 @@ | ||
| """Runner for the Graphomotor pipeline.""" | ||
|
|
||
| import os | ||
| import pathlib | ||
| from datetime import datetime | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
|
|
||
| from graphomotor.core import config, models | ||
| from graphomotor.io import reader | ||
| from graphomotor.utils import center_spiral, generate_reference_spiral | ||
|
|
||
| logger = config.get_logger() | ||
|
|
||
|
|
||
| def _ensure_path(path: pathlib.Path | str) -> pathlib.Path: | ||
| """Ensure that the input is a Path object. | ||
|
|
||
| Args: | ||
| path: Input path, can be string or Path | ||
|
|
||
| Returns: | ||
| Path object | ||
| """ | ||
| return pathlib.Path(path) if isinstance(path, str) else path | ||
|
|
||
|
|
||
| def _validate_feature_categories(feature_categories: list[str]) -> set[str]: | ||
| """Validate requested feature categories and return valid ones. | ||
|
|
||
| Args: | ||
| feature_categories: List of feature categories to validate. | ||
|
|
||
| Returns: | ||
| Set of valid feature categories. | ||
|
|
||
| Raises: | ||
| ValueError: If no valid feature categories are provided. | ||
| """ | ||
| unknown_categories = set(feature_categories) - config.FeatureCategories.all() | ||
| valid_requested_categories = ( | ||
| set(feature_categories) & config.FeatureCategories.all() | ||
| ) | ||
|
|
||
| if unknown_categories: | ||
| logger.warning( | ||
| "Unknown feature categories requested, these categories will be ignored: " | ||
| f"{unknown_categories}" | ||
| ) | ||
|
|
||
| if not valid_requested_categories: | ||
| error_msg = ( | ||
| "No valid feature categories provided. " | ||
| f"Supported categories: {config.FeatureCategories.all()}" | ||
| ) | ||
| logger.error(error_msg) | ||
| raise ValueError(error_msg) | ||
|
|
||
| return valid_requested_categories | ||
|
|
||
|
|
||
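The validation step above is a set difference (unknown names) plus a set intersection (usable names). A minimal standalone sketch of the same idea follows. `SUPPORTED_CATEGORIES` and `validate_categories` are hypothetical stand-ins, not names from this PR; the category strings are assumed from the docstrings in this diff, and the real `config.FeatureCategories.all()` may return something different.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical stand-in for config.FeatureCategories.all(); the names are
# assumed from the docstrings in this diff.
SUPPORTED_CATEGORIES = frozenset({"duration", "velocity", "hausdorff", "AUC"})


def validate_categories(requested: list[str]) -> set[str]:
    """Return the supported subset of `requested`, warning about the rest."""
    requested_set = set(requested)
    unknown = requested_set - SUPPORTED_CATEGORIES  # names we cannot handle
    valid = requested_set & SUPPORTED_CATEGORIES  # names we can extract

    if unknown:
        logger.warning("Ignoring unknown feature categories: %s", sorted(unknown))

    if not valid:
        # Nothing usable was requested, so fail loudly instead of returning
        # an empty result.
        raise ValueError(
            "No valid feature categories provided. "
            f"Supported categories: {sorted(SUPPORTED_CATEGORIES)}"
        )

    return valid
```

In this sketch the match is exact and case-sensitive, so `"auc"` counts as unknown while `"AUC"` is accepted. Duplicates in the request collapse because the result is a set, which also means the extraction order is not guaranteed.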
| def _get_feature_categories( | ||
| spiral: models.Spiral, | ||
| reference_spiral: np.ndarray, | ||
| feature_categories: list[str], | ||
| ) -> dict[str, float]: | ||
| """Feature categories dispatcher. | ||
|
|
||
| This function chooses which feature categories to extract based on the provided | ||
| sequence of valid category names and returns a dictionary containing the extracted | ||
| features. | ||
|
|
||
| Args: | ||
| spiral: The spiral data to extract features from. | ||
| reference_spiral: The reference spiral used for calculating features. | ||
| feature_categories: List of feature categories to extract. | ||
|
|
||
| Returns: | ||
| Dictionary containing the extracted features. | ||
| """ | ||
| valid_categories = _validate_feature_categories(feature_categories) | ||
|
|
||
| feature_extractors = config.FeatureCategories.get_extractors( | ||
| spiral, reference_spiral | ||
| ) | ||
|
|
||
| features = {} | ||
| for category in valid_categories: | ||
| logger.debug(f"Extracting {category} features") | ||
| category_features = feature_extractors[category]() | ||
| features.update(category_features) | ||
| logger.debug(f"{category.capitalize()} features extracted: {category_features}") | ||
|
|
||
| return features | ||
|
|
||
|
|
||
| def _export_features_to_csv( | ||
| spiral: models.Spiral, | ||
| features: dict[str, str], | ||
| input_path: pathlib.Path, | ||
| output_path: pathlib.Path, | ||
| ) -> None: | ||
| """Export extracted features to a CSV file. | ||
|
|
||
| Args: | ||
| spiral: The spiral data used for feature extraction. | ||
| features: Dictionary containing the extracted features. | ||
| input_path: Path to the input CSV file. | ||
| output_path: Path to the output CSV file. | ||
| """ | ||
| logger.info(f"Saving extracted features to {output_path}") | ||
|
|
||
| participant_id = spiral.metadata.get("id") | ||
| task = spiral.metadata.get("task") | ||
| hand = spiral.metadata.get("hand") | ||
|
|
||
| filename = ( | ||
| f"{participant_id}_{task}_{hand}_features_" | ||
| f"{datetime.today().strftime('%Y%m%d')}.csv" | ||
| ) | ||
|
|
||
| if not output_path.suffix: | ||
| if not os.path.exists(output_path): | ||
| logger.info(f"Creating directory that doesn't exist: {output_path}") | ||
| os.makedirs(output_path, exist_ok=True) | ||
| output_file = output_path / filename | ||
| else: | ||
| parent_dir = output_path.parent | ||
| if not os.path.exists(parent_dir): | ||
| logger.info(f"Creating parent directory that doesn't exist: {parent_dir}") | ||
| os.makedirs(parent_dir, exist_ok=True) | ||
| output_file = output_path | ||
|
|
||
| if os.path.exists(output_file): | ||
| logger.info(f"Overwriting existing file: {output_file}") | ||
|
|
||
| metadata = { | ||
| "participant_id": participant_id, | ||
| "task": task, | ||
| "hand": hand, | ||
| "source_file": str(input_path), | ||
| } | ||
|
|
||
| features_df = pd.DataFrame( | ||
| { | ||
| "variable": list(metadata.keys()) + list(features.keys()), | ||
| "value": list(metadata.values()) + list(features.values()), | ||
| } | ||
| ) | ||
|
|
||
| try: | ||
| features_df.to_csv(output_file, index=False, header=False) | ||
| logger.debug(f"Features saved successfully to {output_file}") | ||
| except Exception as e: | ||
| logger.error(f"Failed to save features to {output_file}: {str(e)}") | ||
| raise | ||
|
|
||
|
|
||
| def extract_features( | ||
| input_path: pathlib.Path | str, | ||
| output_path: pathlib.Path | str | None, | ||
| feature_categories: list[str], | ||
| spiral_config: config.SpiralConfig | None, | ||
| ) -> dict[str, str]: | ||
| """Extract features from spiral drawing data. | ||
|
|
||
| Args: | ||
| input_path: Path to the input CSV file containing spiral drawing data. | ||
| output_path: Path to the output file or directory for saving extracted | ||
| features. If None, features are not saved. | ||
| feature_categories: List of feature categories to extract. Valid options are: | ||
| - "duration": Extract task duration. | ||
| - "velocity": Extract velocity-based metrics. | ||
| - "hausdorff": Extract Hausdorff distance metrics. | ||
| - "AUC": Extract area under the curve metric. | ||
| spiral_config: Optional configuration for spiral parameters. If None, default | ||
| parameters are used. | ||
|
|
||
| Returns: | ||
| Dictionary containing the extracted features. | ||
| """ | ||
| logger.debug(f"Loading spiral data from {input_path}") | ||
| input_path = _ensure_path(input_path) | ||
| spiral = reader.load_spiral(input_path) | ||
| spiral = center_spiral.center_spiral(spiral) | ||
|
|
||
| logger.debug("Generating reference spiral to calculate features") | ||
| config_to_use = spiral_config or config.SpiralConfig() | ||
| reference_spiral = generate_reference_spiral.generate_reference_spiral( | ||
| config=config_to_use | ||
| ) | ||
| reference_spiral = center_spiral.center_spiral(reference_spiral) | ||
|
|
||
| features = _get_feature_categories(spiral, reference_spiral, feature_categories) | ||
| logger.info(f"Feature extraction complete. Extracted {len(features)} features") | ||
|
|
||
| formatted_features = {k: f"{v:.15f}" for k, v in features.items()} | ||
|
|
||
| if output_path: | ||
| output_path = _ensure_path(output_path) | ||
| _export_features_to_csv(spiral, formatted_features, input_path, output_path) | ||
|
|
||
| return formatted_features | ||
|
|
||
|
|
||
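`extract_features` formats every value with `f"{v:.15f}"` and the export step writes metadata and features as headerless `variable,value` rows. The sketch below reproduces that shape with the standard-library `csv` module in place of pandas, so it runs without third-party packages. `format_features`, `to_long_csv`, and the feature names in the usage note are hypothetical and only illustrate the layout.

```python
import csv
import io
from itertools import chain


def format_features(features: dict[str, float]) -> dict[str, str]:
    """Render each value as fixed-point text with 15 decimal places."""
    return {name: f"{value:.15f}" for name, value in features.items()}


def to_long_csv(metadata: dict[str, str], features: dict[str, str]) -> str:
    """Write metadata rows, then feature rows, as headerless name,value lines."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    # Metadata first, features second, mirroring the concatenated lists used
    # to build the DataFrame in the diff.
    for name, value in chain(metadata.items(), features.items()):
        writer.writerow([name, value])
    return buffer.getvalue()
```

For example, `format_features({"duration": 12.5})` gives `{"duration": "12.500000000000000"}`. Fixed-point formatting keeps the column width predictable, but values much smaller than 1e-15 would be written as zero, so whether 15 decimals is enough depends on the scale of each feature.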
| def run_pipeline( | ||
| input_path: pathlib.Path | str, | ||
| output_path: pathlib.Path | str | None, | ||
| feature_categories: list[str], | ||
| config_params: dict[str, float | int] | None = None, | ||
| ) -> dict[str, str]: | ||
| """Run the Graphomotor pipeline to extract features from spiral drawing data. | ||
|
|
||
| Args: | ||
| input_path: Path to the input CSV file containing spiral drawing data. | ||
| output_path: Path where extracted features will be saved. Three behaviors are | ||
| possible: | ||
| - If None: Features are calculated but not saved to disk. | ||
| - If path includes file extension (e.g., "/path/to/output.csv"): This exact | ||
| file path is used. | ||
| - If path has no file extension (e.g., "/path/to/output"): A file is | ||
| created in that directory with auto-generated filename: | ||
| "/path/to/output/{participant_id}_{task}_{hand}_features_{date}.csv". | ||
| feature_categories: List of feature categories to extract. Options are: | ||
| - "duration": Extract task duration. | ||
| - "velocity": Extract velocity-based metrics. | ||
| - "hausdorff": Extract Hausdorff distance metrics. | ||
| - "AUC": Extract area under the curve metric. | ||
| config_params: Optional configuration parameters for spiral drawing. If None, | ||
| default parameters are used. | ||
|
|
||
| Returns: | ||
| Dictionary containing the extracted features. | ||
| """ | ||
| logger.info("Starting Graphomotor pipeline") | ||
| logger.info(f"Input path: {input_path}") | ||
| logger.info(f"Output path: {output_path}") | ||
| logger.info(f"Feature categories: {feature_categories}") | ||
|
|
||
| spiral_config = None | ||
| if config_params: | ||
| logger.info(f"Custom spiral configuration: {config_params}") | ||
| spiral_config = config.SpiralConfig.add_custom_params(config_params) | ||
|
|
||
| features = extract_features( | ||
| input_path, output_path, feature_categories, spiral_config | ||
| ) | ||
|
|
||
| logger.info("Graphomotor pipeline completed successfully") | ||
| return features | ||