From ef391d4d800040acae2c50606eeea5bb847d4a11 Mon Sep 17 00:00:00 2001 From: Alp Erkent Date: Thu, 10 Apr 2025 13:05:44 -0400 Subject: [PATCH 1/5] Implement Hausdorff distance metrics for spiral analysis - Updated pyproject.toml and uv.lock to include scipy as a dependency. - Implemented Hausdorff distance metrics in distance.py for analyzing spiral drawing data. - Added tests for distance metrics and data segmentation in test_distance.py. - Created reference spiral generation in config.py and corresponding tests in test_config.py. --- pyproject.toml | 3 +- src/graphomotor/core/config.py | 17 ++++ src/graphomotor/features/__init__.py | 1 + src/graphomotor/features/distance.py | 123 ++++++++++++++++++++++++ tests/conftest.py | 21 +++++ tests/unit/test_config.py | 15 +++ tests/unit/test_distance.py | 134 +++++++++++++++++++++++++++ uv.lock | 40 ++++++++ 8 files changed, 353 insertions(+), 1 deletion(-) create mode 100644 src/graphomotor/core/config.py create mode 100644 src/graphomotor/features/__init__.py create mode 100644 src/graphomotor/features/distance.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_distance.py diff --git a/pyproject.toml b/pyproject.toml index 5c49a05..fb5be2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,8 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "pandas>=2.2.3", - "pydantic>=2.11.1" + "pydantic>=2.11.1", + "scipy>=1.15.2" ] [dependency-groups] diff --git a/src/graphomotor/core/config.py b/src/graphomotor/core/config.py new file mode 100644 index 0000000..4da36c4 --- /dev/null +++ b/src/graphomotor/core/config.py @@ -0,0 +1,17 @@ +"""Configuration module for the graphomotor repository.""" + +import numpy as np + + +def generate_reference_spiral() -> np.ndarray: + """Generates a reference spiral for feature extraction purposes.""" + cx, cy = (50, 50) # center of the spiral + a = 0 # starting radius + b = 1.075 # growth rate + num_points = 10000 + spiral_length = 8 * np.pi # spiral makes 4 full rotations + theta = np.linspace(0, spiral_length, num_points) + r = a + b * theta + x = cx + r * np.cos(theta) + y = cy + r * np.sin(theta) + return np.column_stack((x, y)) diff --git a/src/graphomotor/features/__init__.py b/src/graphomotor/features/__init__.py new file mode 100644 index 0000000..7cff8c5 --- /dev/null +++ b/src/graphomotor/features/__init__.py @@ -0,0 +1 @@ +""".. include:: ../../README.md""" # noqa: D415 diff --git a/src/graphomotor/features/distance.py b/src/graphomotor/features/distance.py new file mode 100644 index 0000000..c77cf46 --- /dev/null +++ b/src/graphomotor/features/distance.py @@ -0,0 +1,123 @@ +"""Feature extraction module for distance-based metrics in spiral drawing data.""" + +import numpy as np +from scipy import stats +from scipy.spatial import distance + +from graphomotor.core import models + + +def _segment_data(data: np.ndarray, start_pct: float, end_pct: float) -> np.ndarray: + """Extract segment of data based on percentage range. + + Args: + data: Data to segment + start_pct: Start percentage [0-1) + end_pct: End percentage (0-1] + + Returns: + Segmented data + """ + if not (0 <= start_pct < end_pct <= 1): + raise ValueError( + "Percentages must be between 0 and 1, " + "and start_pct must be less than end_pct" + ) + num_samples = len(data) + start_idx = int(start_pct * num_samples) + end_idx = int(end_pct * num_samples) + return data[start_idx:end_idx] + + +def calculate_hausdorff_metrics( + spiral: models.Spiral, reference_spiral: np.ndarray +) -> dict: + """Calculate Hausdorff distance metrics for a spiral object. + + This function computes multiple features based on the Hausdorff distance between a + drawn spiral and a reference (ideal) spiral, as described in [1]. The Hausdorff + distance measures the maximum distance of a set to the nearest point in the other + set. This metric and its derivatives capture various aspects of the spatial + relationship between the drawn and reference spirals. Calculated features include: + - max_haus_dist: The maximum of the directed Hausdorff distances between the + data points and the reference data points. + - sum_haus_dist: The sum of the directed Hausdorff distances. + - sum_haus_dist_time: The sum of the directed Hausdorff distances divided by + the total drawing duration. + - iqr_haus_dist: The interquartile range of the directed Hausdorff distances. + - max_haus_dist_start: The maximum of the directed Hausdorff distances between + the beginning segment (0% to 25%) of data points and the beginning segment + of reference data points divided by the number of data points in the + beginning segment. + - max_haus_dist_end: The maximum of the directed Hausdorff distances in the + ending segment (75% to 100%) of data points and the ending segment of + reference data points divided by the number of data points in the ending + segment. + - max_haus_dist_mid: The maximum of the directed Hausdorff distances in the + middle segment (15% to 85%) of data points and the ending segment of + reference data points (this metric is not divided by the number of data + points in the middle segment unlike previous ones). + - max_haus_dist_mid_time: The maximum of the directed Hausdorff distances in + the middle segment divided by the total drawing duration. + + Args: + spiral: Spiral object with drawing data + reference_spiral: Reference spiral data for comparison + + Returns: + Dictionary containing Hausdorff distance-based features + + References: + [1] Messan, Komi S et al. “Assessment of Smartphone-Based Spiral Tracing in + Multiple Sclerosis Reveals Intra-Individual Reproducibility as a Major + Determinant of the Clinical Utility of the Digital Test.” Frontiers in + medical technology vol. 3 714682. 1 Feb. 2022, doi:10.3389/fmedt.2021.714682 + """ + spiral_data = np.column_stack((spiral.data["x"].values, spiral.data["y"].values)) + + total_duration = spiral.data["seconds"].iloc[-1] + + start_segment_data = _segment_data(spiral_data, 0.0, 0.25) + end_segment_data = _segment_data(spiral_data, 0.75, 1.0) + mid_segment_data = _segment_data(spiral_data, 0.15, 0.85) + + if ( + len(start_segment_data) == 0 + or len(end_segment_data) == 0 + or len(mid_segment_data) == 0 + ): + raise ValueError( + "Segmented data is empty, check spiral data or segment percentages" + ) + + start_segment_ref = _segment_data(reference_spiral, 0.0, 0.25) + end_segment_ref = _segment_data(reference_spiral, 0.75, 1.0) + mid_segment_ref = _segment_data(reference_spiral, 0.15, 0.85) + + haus_dist = [ + distance.directed_hausdorff(spiral_data, reference_spiral)[0], + distance.directed_hausdorff(reference_spiral, spiral_data)[0], + ] + haus_dist_start = [ + distance.directed_hausdorff(start_segment_data, start_segment_ref)[0], + distance.directed_hausdorff(start_segment_ref, start_segment_data)[0], + ] + haus_dist_end = [ + distance.directed_hausdorff(end_segment_data, end_segment_ref)[0], + distance.directed_hausdorff(end_segment_ref, end_segment_data)[0], + ] + haus_dist_mid = [ + distance.directed_hausdorff(mid_segment_data, mid_segment_ref)[0], + distance.directed_hausdorff(mid_segment_ref, mid_segment_data)[0], + ] + + return { + "max_haus_dist": np.max(haus_dist), + "sum_haus_dist": np.sum(haus_dist), + "sum_haus_dist_time": np.sum(haus_dist) / total_duration, + "iqr_haus_dist": stats.iqr(haus_dist), + "max_haus_dist_start": np.max(haus_dist_start) / len(start_segment_data), + "max_haus_dist_end": np.max(haus_dist_end) / len(end_segment_data), + "max_haus_dist_mid": np.max(haus_dist_mid), + "max_haus_dist_mid_time": np.max(haus_dist_mid) / total_duration, + } diff --git a/tests/conftest.py b/tests/conftest.py index d9fd149..845f847 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,12 @@ import datetime import pathlib +import numpy as np import pandas as pd import pytest +from graphomotor.core import config, models + @pytest.fixture def sample_data() -> pathlib.Path: @@ -35,3 +38,21 @@ def valid_spiral_metadata() -> dict[str, str | datetime.datetime]: tz=datetime.timezone.utc, ), } + + +@pytest.fixture +def valid_spiral( + valid_spiral_data: pd.DataFrame, + valid_spiral_metadata: dict[str, str | datetime.datetime], +) -> models.Spiral: + """Create a valid Spiral object.""" + return models.Spiral( + data=valid_spiral_data, + metadata=valid_spiral_metadata, + ) + + +@pytest.fixture +def reference_spiral() -> np.ndarray: + """Create a reference spiral for testing.""" + return config.generate_reference_spiral() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..b612998 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,15 @@ +"""Test cases for config.py functions.""" + +import numpy as np +import pytest + +from graphomotor.core import config + + +def test_generate_reference_spiral() -> None: + """Test the generation of a reference spiral.""" + spiral = config.generate_reference_spiral() + assert isinstance(spiral, np.ndarray) + assert spiral.shape == (10000, 2) + assert spiral[0] == pytest.approx([50, 50]) + assert spiral[-1] == pytest.approx([50 + 1.075 * 8 * np.pi, 50]) diff --git a/tests/unit/test_distance.py b/tests/unit/test_distance.py new file mode 100644 index 0000000..8a5d63f --- /dev/null +++ b/tests/unit/test_distance.py @@ -0,0 +1,134 @@ +"""Test cases for distance.py functions.""" + +import numpy as np +import pandas as pd +import pytest +import scipy.spatial.distance as dist +from scipy import stats + +from graphomotor.core import models +from graphomotor.features import distance + + +def test_segment_data_valid() -> None: + """Test that the data is segmented correctly.""" + data = np.array([[i, i] for i in range(100)]) + + segment = distance._segment_data(data, 0.1, 0.3) + assert len(segment) == 20 + assert segment[0][0] == 10 + assert segment[-1][0] == 29 + + +@pytest.mark.parametrize( + "start_pct,end_pct", + [ + (-0.1, 0.5), + (0.1, 1.1), + (0.6, 0.5), + (0.5, 0.5), + ], +) +def test_segment_data_invalid(start_pct: float, end_pct: float) -> None: + """Test that invalid percentages raise a ValueError.""" + data = np.array([[i, i] for i in range(100)]) + + with pytest.raises( + ValueError, + match=( + "Percentages must be between 0 and 1, " + "and start_pct must be less than end_pct" + ), + ): + distance._segment_data(data, start_pct, end_pct) + + +def test_calculate_hausdorff_metrics( + valid_spiral: models.Spiral, reference_spiral: np.ndarray +) -> None: + """Test that each Hausdorff metric is calculated.""" + metrics = distance.calculate_hausdorff_metrics(valid_spiral, reference_spiral) + + expected_metrics = [ + "max_haus_dist", + "sum_haus_dist", + "sum_haus_dist_time", + "iqr_haus_dist", + "max_haus_dist_start", + "max_haus_dist_end", + "max_haus_dist_mid", + "max_haus_dist_mid_time", + ] + + for metric in expected_metrics: + assert metric in metrics + assert isinstance(metrics[metric], float) + + +def test_calculate_hausdorff_metrics_empty_segments( + valid_spiral_data: pd.DataFrame, + valid_spiral_metadata: dict, + reference_spiral: np.ndarray, +) -> None: + """Test that empty segments raise a ValueError.""" + small_spiral_data = valid_spiral_data.iloc[:3] + small_spiral = models.Spiral( + data=small_spiral_data, + metadata=valid_spiral_metadata, + ) + with pytest.raises( + ValueError, + match="Segmented data is empty, check spiral data or segment percentages", + ): + distance.calculate_hausdorff_metrics(small_spiral, reference_spiral) + + +def test_hausdorff_metrics_values( + valid_spiral: models.Spiral, reference_spiral: np.ndarray +) -> None: + """Test that Hausdorff metrics are calculated correctly.""" + metrics = distance.calculate_hausdorff_metrics(valid_spiral, reference_spiral) + + data = valid_spiral.data[["x", "y"]].values + ref_data = reference_spiral + + total_duration = valid_spiral.data["seconds"].iloc[-1] + + data_start = data[: int(len(data) * 0.25)] + data_end = data[int(len(data) * 0.75) :] + data_mid = data[int(len(data) * 0.15) : int(len(data) * 0.85)] + + ref_data_start = ref_data[: int(len(ref_data) * 0.25)] + ref_data_end = ref_data[int(len(ref_data) * 0.75) :] + ref_data_mid = ref_data[int(len(ref_data) * 0.15) : int(len(ref_data) * 0.85)] + + dist_matrix = dist.cdist(data, ref_data, "euclidean") + dist_matrix_start = dist.cdist(data_start, ref_data_start, "euclidean") + dist_matrix_end = dist.cdist(data_end, ref_data_end, "euclidean") + dist_matrix_mid = dist.cdist(data_mid, ref_data_mid, "euclidean") + + haus_dist = [ + np.max(np.min(dist_matrix, axis=0)), + np.max(np.min(dist_matrix, axis=1)), + ] + haus_dist_start = [ + np.max(np.min(dist_matrix_start, axis=0)), + np.max(np.min(dist_matrix_start, axis=1)), + ] + haus_dist_end = [ + np.max(np.min(dist_matrix_end, axis=0)), + np.max(np.min(dist_matrix_end, axis=1)), + ] + haus_dist_mid = [ + np.max(np.min(dist_matrix_mid, axis=0)), + np.max(np.min(dist_matrix_mid, axis=1)), + ] + + assert metrics["max_haus_dist"] == np.max(haus_dist) + assert metrics["sum_haus_dist"] == np.sum(haus_dist) + assert metrics["sum_haus_dist_time"] == np.sum(haus_dist) / total_duration + assert metrics["iqr_haus_dist"] == stats.iqr(haus_dist) + assert metrics["max_haus_dist_start"] == np.max(haus_dist_start) / len(data_start) + assert metrics["max_haus_dist_end"] == np.max(haus_dist_end) / len(data_end) + assert metrics["max_haus_dist_mid"] == np.max(haus_dist_mid) + assert metrics["max_haus_dist_mid_time"] == np.max(haus_dist_mid) / total_duration diff --git a/uv.lock b/uv.lock index 8e66b26..bf0402a 100644 --- a/uv.lock +++ b/uv.lock @@ -92,6 +92,7 @@ source = { editable = "." } dependencies = [ { name = "pandas" }, { name = "pydantic" }, + { name = "scipy" }, ] [package.dev-dependencies] @@ -110,6 +111,7 @@ docs = [ requires-dist = [ { name = "pandas", specifier = ">=2.2.3" }, { name = "pydantic", specifier = ">=2.11.1" }, + { name = "scipy", specifier = ">=1.15.2" }, ] [package.metadata.requires-dev] @@ -528,6 +530,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d6/d4/dd813703af8a1e2ac33bf3feb27e8a5ad514c9f219df80c64d69807e7f71/ruff-0.11.2-py3-none-win_arm64.whl", hash = "sha256:52933095158ff328f4c77af3d74f0379e34fd52f175144cefc1b192e7ccd32b4", size = 10441990 }, ] +[[package]] +name = "scipy" +version = "1.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/b9/31ba9cd990e626574baf93fbc1ac61cf9ed54faafd04c479117517661637/scipy-1.15.2.tar.gz", hash = "sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec", size = 59417316 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/5d/3c78815cbab499610f26b5bae6aed33e227225a9fa5290008a733a64f6fc/scipy-1.15.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd", size = 38756184 }, + { url = "https://files.pythonhosted.org/packages/37/20/3d04eb066b471b6e171827548b9ddb3c21c6bbea72a4d84fc5989933910b/scipy-1.15.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301", size = 30163558 }, + { url = "https://files.pythonhosted.org/packages/a4/98/e5c964526c929ef1f795d4c343b2ff98634ad2051bd2bbadfef9e772e413/scipy-1.15.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93", size = 22437211 }, + { url = "https://files.pythonhosted.org/packages/1d/cd/1dc7371e29195ecbf5222f9afeedb210e0a75057d8afbd942aa6cf8c8eca/scipy-1.15.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20", size = 25232260 }, + { url = "https://files.pythonhosted.org/packages/f0/24/1a181a9e5050090e0b5138c5f496fee33293c342b788d02586bc410c6477/scipy-1.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e", size = 35198095 }, + { url = "https://files.pythonhosted.org/packages/c0/53/eaada1a414c026673eb983f8b4a55fe5eb172725d33d62c1b21f63ff6ca4/scipy-1.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8", size = 37297371 }, + { url = "https://files.pythonhosted.org/packages/e9/06/0449b744892ed22b7e7b9a1994a866e64895363572677a316a9042af1fe5/scipy-1.15.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11", size = 36872390 }, + { url = "https://files.pythonhosted.org/packages/6a/6f/a8ac3cfd9505ec695c1bc35edc034d13afbd2fc1882a7c6b473e280397bb/scipy-1.15.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53", size = 39700276 }, + { url = "https://files.pythonhosted.org/packages/f5/6f/e6e5aff77ea2a48dd96808bb51d7450875af154ee7cbe72188afb0b37929/scipy-1.15.2-cp312-cp312-win_amd64.whl", hash = "sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded", size = 40942317 }, + { url = "https://files.pythonhosted.org/packages/53/40/09319f6e0f276ea2754196185f95cd191cb852288440ce035d5c3a931ea2/scipy-1.15.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf", size = 38717587 }, + { url = "https://files.pythonhosted.org/packages/fe/c3/2854f40ecd19585d65afaef601e5e1f8dbf6758b2f95b5ea93d38655a2c6/scipy-1.15.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37", size = 30100266 }, + { url = "https://files.pythonhosted.org/packages/dd/b1/f9fe6e3c828cb5930b5fe74cb479de5f3d66d682fa8adb77249acaf545b8/scipy-1.15.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d", size = 22373768 }, + { url = "https://files.pythonhosted.org/packages/15/9d/a60db8c795700414c3f681908a2b911e031e024d93214f2d23c6dae174ab/scipy-1.15.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb", size = 25154719 }, + { url = "https://files.pythonhosted.org/packages/37/3b/9bda92a85cd93f19f9ed90ade84aa1e51657e29988317fabdd44544f1dd4/scipy-1.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27", size = 35163195 }, + { url = "https://files.pythonhosted.org/packages/03/5a/fc34bf1aa14dc7c0e701691fa8685f3faec80e57d816615e3625f28feb43/scipy-1.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0", size = 37255404 }, + { url = "https://files.pythonhosted.org/packages/4a/71/472eac45440cee134c8a180dbe4c01b3ec247e0338b7c759e6cd71f199a7/scipy-1.15.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32", size = 36860011 }, + { url = "https://files.pythonhosted.org/packages/01/b3/21f890f4f42daf20e4d3aaa18182dddb9192771cd47445aaae2e318f6738/scipy-1.15.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d", size = 39657406 }, + { url = "https://files.pythonhosted.org/packages/0d/76/77cf2ac1f2a9cc00c073d49e1e16244e389dd88e2490c91d84e1e3e4d126/scipy-1.15.2-cp313-cp313-win_amd64.whl", hash = "sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f", size = 40961243 }, + { url = "https://files.pythonhosted.org/packages/4c/4b/a57f8ddcf48e129e6054fa9899a2a86d1fc6b07a0e15c7eebff7ca94533f/scipy-1.15.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9", size = 38870286 }, + { url = "https://files.pythonhosted.org/packages/0c/43/c304d69a56c91ad5f188c0714f6a97b9c1fed93128c691148621274a3a68/scipy-1.15.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f", size = 30141634 }, + { url = "https://files.pythonhosted.org/packages/44/1a/6c21b45d2548eb73be9b9bff421aaaa7e85e22c1f9b3bc44b23485dfce0a/scipy-1.15.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6", size = 22415179 }, + { url = "https://files.pythonhosted.org/packages/74/4b/aefac4bba80ef815b64f55da06f62f92be5d03b467f2ce3668071799429a/scipy-1.15.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af", size = 25126412 }, + { url = "https://files.pythonhosted.org/packages/b1/53/1cbb148e6e8f1660aacd9f0a9dfa2b05e9ff1cb54b4386fe868477972ac2/scipy-1.15.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274", size = 34952867 }, + { url = "https://files.pythonhosted.org/packages/2c/23/e0eb7f31a9c13cf2dca083828b97992dd22f8184c6ce4fec5deec0c81fcf/scipy-1.15.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776", size = 36890009 }, + { url = "https://files.pythonhosted.org/packages/03/f3/e699e19cabe96bbac5189c04aaa970718f0105cff03d458dc5e2b6bd1e8c/scipy-1.15.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828", size = 36545159 }, + { url = "https://files.pythonhosted.org/packages/af/f5/ab3838e56fe5cc22383d6fcf2336e48c8fe33e944b9037fbf6cbdf5a11f8/scipy-1.15.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28", size = 39136566 }, + { url = "https://files.pythonhosted.org/packages/0a/c8/b3f566db71461cabd4b2d5b39bcc24a7e1c119535c8361f81426be39bb47/scipy-1.15.2-cp313-cp313t-win_amd64.whl", hash = "sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db", size = 40477705 }, +] + [[package]] name = "six" version = "1.17.0" From 96136722087c316b56e28f0572a712dbd413780e Mon Sep 17 00:00:00 2001 From: Alp Erkent Date: Wed, 16 Apr 2025 15:41:21 -0400 Subject: [PATCH 2/5] Refactor and move reference spiral generation script and related tests, rename Hausdorff features, remove unnecessary tests --- src/graphomotor/core/config.py | 17 --- src/graphomotor/core/models.py | 12 +-- src/graphomotor/features/distance.py | 101 +++++++++--------- src/graphomotor/utils/reference_spiral.py | 109 +++++++++++++++++++ tests/conftest.py | 7 +- tests/unit/test_config.py | 15 --- tests/unit/test_distance.py | 122 ++++------------------ tests/unit/test_reference_spiral.py | 17 +++ 8 files changed, 205 insertions(+), 195 deletions(-) delete mode 100644 src/graphomotor/core/config.py create mode 100644 src/graphomotor/utils/reference_spiral.py delete mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_reference_spiral.py diff --git a/src/graphomotor/core/config.py b/src/graphomotor/core/config.py deleted file mode 100644 index 4da36c4..0000000 --- a/src/graphomotor/core/config.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Configuration module for the graphomotor repository.""" - -import numpy as np - - -def generate_reference_spiral() -> np.ndarray: - """Generates a reference spiral for feature extraction purposes.""" - cx, cy = (50, 50) # center of the spiral - a = 0 # starting radius - b = 1.075 # growth rate - num_points = 10000 - spiral_length = 8 * np.pi # spiral makes 4 full rotations - theta = np.linspace(0, spiral_length, num_points) - r = a + b * theta - x = cx + r * np.cos(theta) - y = cy + r * np.sin(theta) - return np.column_stack((x, y)) diff --git a/src/graphomotor/core/models.py b/src/graphomotor/core/models.py index 0bd3b5e..116f1bc 100644 --- a/src/graphomotor/core/models.py +++ b/src/graphomotor/core/models.py @@ -10,13 +10,13 @@ class Spiral(BaseModel): """A class representing a spiral drawing, encapsulating both raw data and metadata. Attributes: - data: DataFrame containing drawing data with required columns - (line_number, x, y, UTC_Timestamp, seconds) + data: DataFrame containing drawing data with required columns (line_number, x, + y, UTC_Timestamp, seconds). metadata: Dictionary containing metadata about the spiral: - - id: Unique identifier for the participant - - hand: Hand used ('Dom' for dominant, 'NonDom' for non-dominant) - - task: Task name - - start_time: Start time of drawing + - id: Unique identifier for the participant, + - hand: Hand used ('Dom' for dominant, 'NonDom' for non-dominant), + - task: Task name, + - start_time: Start time of drawing. """ model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/graphomotor/features/distance.py b/src/graphomotor/features/distance.py index c77cf46..186e427 100644 --- a/src/graphomotor/features/distance.py +++ b/src/graphomotor/features/distance.py @@ -7,25 +7,25 @@ from graphomotor.core import models -def _segment_data(data: np.ndarray, start_pct: float, end_pct: float) -> np.ndarray: - """Extract segment of data based on percentage range. +def _segment_data(data: np.ndarray, start_prop: float, end_prop: float) -> np.ndarray: + """Extract segment of data based on given proportion range. Args: - data: Data to segment - start_pct: Start percentage [0-1) - end_pct: End percentage (0-1] + data: Data to segment. + start_prop: Start proportion, [0-1). + end_prop: End proportion, (0-1]. Returns: - Segmented data + Segmented data. """ - if not (0 <= start_pct < end_pct <= 1): + if not (0 <= start_prop < end_prop <= 1): raise ValueError( - "Percentages must be between 0 and 1, " - "and start_pct must be less than end_pct" + "Proportions must be between 0 and 1, " + "and start_prop must be less than end_prop" ) num_samples = len(data) - start_idx = int(start_pct * num_samples) - end_idx = int(end_pct * num_samples) + start_idx = int(start_prop * num_samples) + end_idx = int(end_prop * num_samples) return data[start_idx:end_idx] @@ -35,37 +35,40 @@ def calculate_hausdorff_metrics( """Calculate Hausdorff distance metrics for a spiral object. This function computes multiple features based on the Hausdorff distance between a - drawn spiral and a reference (ideal) spiral, as described in [1]. The Hausdorff + drawn spiral and a reference (ideal) spiral, as described in the [1]. Implementation + is based on the original R script provided with the publication. The Hausdorff distance measures the maximum distance of a set to the nearest point in the other set. This metric and its derivatives capture various aspects of the spatial relationship between the drawn and reference spirals. Calculated features include: - - max_haus_dist: The maximum of the directed Hausdorff distances between the - data points and the reference data points. - - sum_haus_dist: The sum of the directed Hausdorff distances. - - sum_haus_dist_time: The sum of the directed Hausdorff distances divided by - the total drawing duration. - - iqr_haus_dist: The interquartile range of the directed Hausdorff distances. - - max_haus_dist_start: The maximum of the directed Hausdorff distances between - the beginning segment (0% to 25%) of data points and the beginning segment - of reference data points divided by the number of data points in the - beginning segment. - - max_haus_dist_end: The maximum of the directed Hausdorff distances in the - ending segment (75% to 100%) of data points and the ending segment of - reference data points divided by the number of data points in the ending - segment. - - max_haus_dist_mid: The maximum of the directed Hausdorff distances in the - middle segment (15% to 85%) of data points and the ending segment of - reference data points (this metric is not divided by the number of data - points in the middle segment unlike previous ones). - - max_haus_dist_mid_time: The maximum of the directed Hausdorff distances in - the middle segment divided by the total drawing duration. + - hausdorff_distance_maximum: The maximum of the directed Hausdorff distances + between the data points and the reference data points, + - hausdorff_distance_sum: The sum of the directed Hausdorff distances, + - hausdorff_distance_sum_per_second: The sum of the directed Hausdorff distances + divided by the total drawing duration, + - hausdorff_distance_interquartile_range: The interquartile range of the + directed Hausdorff distances, + - hausdorff_distance_start_segment_maximum_normalized: The maximum of the + directed Hausdorff distances between the beginning segment (0% to 25%) of + data points and the beginning segment of reference data points divided by + the number of data points in the beginning segment, + - hausdorff_distance_end_segment_maximum_normalized: The maximum of the directed + Hausdorff distances in the ending segment (75% to 100%) of data points and + the ending segment of reference data points divided by the number of data + points in the ending segment, + - hausdorff_distance_middle_segment_maximum: The maximum of the directed + Hausdorff distances in the middle segment (15% to 85%) of data points and + the ending segment of reference data points (this metric is not divided by + the number of data points in the middle segment unlike previous ones), + - hausdorff_distance_middle_segment_maximum_per_second: The maximum of the + directed Hausdorff distances in the middle segment divided by the total + drawing duration. Args: - spiral: Spiral object with drawing data - reference_spiral: Reference spiral data for comparison + spiral: Spiral object with drawing data. + reference_spiral: Reference spiral data for comparison. Returns: - Dictionary containing Hausdorff distance-based features + Dictionary containing Hausdorff distance-based features. References: [1] Messan, Komi S et al. “Assessment of Smartphone-Based Spiral Tracing in @@ -81,15 +84,6 @@ def calculate_hausdorff_metrics( end_segment_data = _segment_data(spiral_data, 0.75, 1.0) mid_segment_data = _segment_data(spiral_data, 0.15, 0.85) - if ( - len(start_segment_data) == 0 - or len(end_segment_data) == 0 - or len(mid_segment_data) == 0 - ): - raise ValueError( - "Segmented data is empty, check spiral data or segment percentages" - ) - start_segment_ref = _segment_data(reference_spiral, 0.0, 0.25) end_segment_ref = _segment_data(reference_spiral, 0.75, 1.0) mid_segment_ref = _segment_data(reference_spiral, 0.15, 0.85) @@ -112,12 +106,15 @@ def calculate_hausdorff_metrics( ] return { - "max_haus_dist": np.max(haus_dist), - "sum_haus_dist": np.sum(haus_dist), - "sum_haus_dist_time": np.sum(haus_dist) / total_duration, - "iqr_haus_dist": stats.iqr(haus_dist), - "max_haus_dist_start": np.max(haus_dist_start) / len(start_segment_data), - "max_haus_dist_end": np.max(haus_dist_end) / len(end_segment_data), - "max_haus_dist_mid": np.max(haus_dist_mid), - "max_haus_dist_mid_time": np.max(haus_dist_mid) / total_duration, + "hausdorff_distance_maximum": np.max(haus_dist), + "hausdorff_distance_sum": np.sum(haus_dist), + "hausdorff_distance_sum_per_second": np.sum(haus_dist) / total_duration, + "hausdorff_distance_interquartile_range": stats.iqr(haus_dist), + "hausdorff_distance_start_segment_maximum_normalized": np.max(haus_dist_start) + / len(start_segment_data), + "hausdorff_distance_end_segment_maximum_normalized": np.max(haus_dist_end) + / len(end_segment_data), + "hausdorff_distance_middle_segment_maximum": np.max(haus_dist_mid), + "hausdorff_distance_middle_segment_maximum_per_second": np.max(haus_dist_mid) + / total_duration, } diff --git a/src/graphomotor/utils/reference_spiral.py b/src/graphomotor/utils/reference_spiral.py new file mode 100644 index 0000000..87f9732 --- /dev/null +++ b/src/graphomotor/utils/reference_spiral.py @@ -0,0 +1,109 @@ +"""Generate a reference spiral with equidistant points along its arc length.""" + +import numpy as np +from scipy import integrate, optimize + +_SPIRAL_CENTER_X = 50 +_SPIRAL_CENTER_Y = 50 +_SPIRAL_INITIAL_RADIUS = 0 +_SPIRAL_GROWTH_RATE = 1.075 +_SPIRAL_TOTAL_ROTATION = 8 * np.pi +_SPIRAL_NUM_POINTS = 10000 + + +def _spiral_arc_length_integrand(t: float) -> float: + """Calculate the differential arc length at angle t for an Archimedean spiral. + + Args: + t: Angle parameter. + + Returns: + Differential arc length value. + """ + r_t = _SPIRAL_INITIAL_RADIUS + _SPIRAL_GROWTH_RATE * t + return np.sqrt(r_t**2 + _SPIRAL_GROWTH_RATE**2) + + +def _calculate_arc_length(theta: float) -> float: + """Calculate the arc length of the spiral from 0 to theta. + + Args: + theta: The angle in radians. + + Returns: + The arc length of the spiral from 0 to theta. + """ + return integrate.quad(lambda t: _spiral_arc_length_integrand(t), 0, theta)[0] + + +def _arc_length_difference(theta: float, target_arc_length: float) -> float: + """Function to find the root for a given arc length. + + Args: + theta: Angle to evaluate. + target_arc_length: Target arc length value. + + Returns: + Difference between calculated and target arc length. + """ + return _calculate_arc_length(theta) - target_arc_length + + +def _find_theta_for_arc_length(target_arc_length: float) -> float: + """Find the theta value for a given arc length. + + Args: + target_arc_length: Target arc length. + + Returns: + Angle theta corresponding to the arc length. + """ + solution = optimize.root_scalar( + lambda theta: _arc_length_difference(theta, target_arc_length), + bracket=[0, _SPIRAL_TOTAL_ROTATION], + ) + return solution.root + + +def generate_reference_spiral() -> np.ndarray: + """Generate a reference spiral with equidistant points along its arc length. + + This function creates an Archimedean spiral with points distributed at equal arc + length intervals. The generated spiral serves as a standardized reference template + for feature extraction algorithms that compare user-drawn spirals with an ideal + form. + + The algorithm works by: + 1. Computing the total arc length for the entire spiral (0 to 8π), + 2. Creating equidistant target arc length values, + 3. For each target arc length, finding the corresponding theta value that + produces that arc length using numerical root finding, + 4. Converting these theta values to Cartesian coordinates. + + Mathematical formulas used: + - Spiral equation: r(θ) = a + b·θ + - Arc length differential: ds = √(r(θ)² + b²) dθ + - Arc length from 0 to θ: s(θ) = ∫₀ᶿ √(r(t)² + b²) dt + - Cartesian coordinates: x = cx + r·cos(θ), y = cy + r·sin(θ) + + Parameters used: + - Center coordinates: (50, 50) + - Initial radius (a): 0 + - Growth rate (b): 1.075 + - Total rotation: 4 complete revolutions (θ from 0 to 8π) + - Number of points: 10,000 + + Returns: + Array with shape (10000, 2) containing Cartesian coordinates of the spiral. + """ + total_arc_length = _calculate_arc_length(_SPIRAL_TOTAL_ROTATION) + + arc_length_values = np.linspace(0, total_arc_length, _SPIRAL_NUM_POINTS) + + theta_values = np.array([_find_theta_for_arc_length(s) for s in arc_length_values]) + + r_values = _SPIRAL_INITIAL_RADIUS + _SPIRAL_GROWTH_RATE * theta_values + x_values = _SPIRAL_CENTER_X + r_values * np.cos(theta_values) + y_values = _SPIRAL_CENTER_Y + r_values * np.sin(theta_values) + + return np.column_stack((x_values, y_values)) diff --git a/tests/conftest.py b/tests/conftest.py index 845f847..091c0e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,8 @@ import pandas as pd import pytest -from graphomotor.core import config, models +from graphomotor.core import models +from graphomotor.utils import reference_spiral @pytest.fixture @@ -53,6 +54,6 @@ def valid_spiral( @pytest.fixture -def reference_spiral() -> np.ndarray: +def ref_spiral() -> np.ndarray: """Create a reference spiral for testing.""" - return config.generate_reference_spiral() + return reference_spiral.generate_reference_spiral() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py deleted file mode 100644 index b612998..0000000 --- a/tests/unit/test_config.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Test cases for config.py functions.""" - -import numpy as np -import pytest - -from graphomotor.core import config - - -def test_generate_reference_spiral() -> None: - """Test the generation of a reference spiral.""" - spiral = config.generate_reference_spiral() - assert isinstance(spiral, np.ndarray) - assert spiral.shape == (10000, 2) - assert spiral[0] == pytest.approx([50, 50]) - assert spiral[-1] == pytest.approx([50 + 1.075 * 8 * np.pi, 50]) diff --git a/tests/unit/test_distance.py b/tests/unit/test_distance.py index 8a5d63f..9a4d8f1 100644 --- a/tests/unit/test_distance.py +++ b/tests/unit/test_distance.py @@ -1,10 +1,7 @@ """Test cases for distance.py functions.""" import numpy as np -import pandas as pd import pytest -import scipy.spatial.distance as dist -from scipy import stats from graphomotor.core import models from graphomotor.features import distance @@ -20,115 +17,36 @@ def test_segment_data_valid() -> None: assert segment[-1][0] == 29 -@pytest.mark.parametrize( - "start_pct,end_pct", - [ - (-0.1, 0.5), - (0.1, 1.1), - (0.6, 0.5), - (0.5, 0.5), - ], -) -def test_segment_data_invalid(start_pct: float, end_pct: float) -> None: +def test_segment_data_invalid() -> None: """Test that invalid percentages raise a ValueError.""" data = np.array([[i, i] for i in range(100)]) with pytest.raises( ValueError, match=( - "Percentages must be between 0 and 1, " - "and start_pct must be less than end_pct" + "Proportions must be between 0 and 1, " + "and start_prop must be less than end_prop" ), ): - distance._segment_data(data, start_pct, end_pct) + distance._segment_data(data, 0.6, 0.5) def test_calculate_hausdorff_metrics( - valid_spiral: models.Spiral, reference_spiral: np.ndarray + valid_spiral: models.Spiral, ref_spiral: np.ndarray ) -> None: """Test that each Hausdorff metric is calculated.""" - metrics = distance.calculate_hausdorff_metrics(valid_spiral, reference_spiral) - - expected_metrics = [ - "max_haus_dist", - "sum_haus_dist", - "sum_haus_dist_time", - "iqr_haus_dist", - "max_haus_dist_start", - "max_haus_dist_end", - "max_haus_dist_mid", - "max_haus_dist_mid_time", - ] - - for metric in expected_metrics: - assert metric in metrics - assert isinstance(metrics[metric], float) - - -def test_calculate_hausdorff_metrics_empty_segments( - valid_spiral_data: pd.DataFrame, - valid_spiral_metadata: dict, - reference_spiral: np.ndarray, -) -> None: - """Test that empty segments raise a ValueError.""" - small_spiral_data = valid_spiral_data.iloc[:3] - small_spiral = models.Spiral( - data=small_spiral_data, - metadata=valid_spiral_metadata, - ) - with pytest.raises( - ValueError, - match="Segmented data is empty, check spiral data or segment percentages", - ): - distance.calculate_hausdorff_metrics(small_spiral, reference_spiral) - - -def test_hausdorff_metrics_values( - valid_spiral: models.Spiral, reference_spiral: np.ndarray -) -> None: - """Test that Hausdorff metrics are calculated correctly.""" - metrics = distance.calculate_hausdorff_metrics(valid_spiral, reference_spiral) - - data = valid_spiral.data[["x", "y"]].values - ref_data = reference_spiral - - total_duration = valid_spiral.data["seconds"].iloc[-1] - - data_start = data[: int(len(data) * 0.25)] - data_end = data[int(len(data) * 0.75) :] - data_mid = data[int(len(data) * 0.15) : int(len(data) * 0.85)] - - ref_data_start = ref_data[: int(len(ref_data) * 0.25)] - ref_data_end = ref_data[int(len(ref_data) * 0.75) :] - ref_data_mid = ref_data[int(len(ref_data) * 0.15) : int(len(ref_data) * 0.85)] - - dist_matrix = dist.cdist(data, ref_data, "euclidean") - dist_matrix_start = dist.cdist(data_start, ref_data_start, "euclidean") - dist_matrix_end = dist.cdist(data_end, ref_data_end, "euclidean") - dist_matrix_mid = dist.cdist(data_mid, ref_data_mid, "euclidean") - - haus_dist = [ - np.max(np.min(dist_matrix, axis=0)), - np.max(np.min(dist_matrix, axis=1)), - ] - haus_dist_start = [ - np.max(np.min(dist_matrix_start, axis=0)), - np.max(np.min(dist_matrix_start, axis=1)), - ] - haus_dist_end = [ - np.max(np.min(dist_matrix_end, axis=0)), - np.max(np.min(dist_matrix_end, axis=1)), - ] - haus_dist_mid = [ - np.max(np.min(dist_matrix_mid, axis=0)), - np.max(np.min(dist_matrix_mid, axis=1)), - ] - - assert metrics["max_haus_dist"] == np.max(haus_dist) - assert metrics["sum_haus_dist"] == np.sum(haus_dist) - assert metrics["sum_haus_dist_time"] == np.sum(haus_dist) / total_duration - assert metrics["iqr_haus_dist"] == stats.iqr(haus_dist) - assert metrics["max_haus_dist_start"] == np.max(haus_dist_start) / len(data_start) - assert metrics["max_haus_dist_end"] == np.max(haus_dist_end) / len(data_end) - assert metrics["max_haus_dist_mid"] == np.max(haus_dist_mid) - assert metrics["max_haus_dist_mid_time"] == np.max(haus_dist_mid) / total_duration + metrics = distance.calculate_hausdorff_metrics(valid_spiral, ref_spiral) + + expected_metrics = { + "hausdorff_distance_maximum", + "hausdorff_distance_sum", + "hausdorff_distance_sum_per_second", + "hausdorff_distance_interquartile_range", + "hausdorff_distance_start_segment_maximum_normalized", + "hausdorff_distance_end_segment_maximum_normalized", + "hausdorff_distance_middle_segment_maximum", + "hausdorff_distance_middle_segment_maximum_per_second", + } + + assert set(metrics.keys()) == expected_metrics + assert all(isinstance(value, float) for value in metrics.values()) diff --git a/tests/unit/test_reference_spiral.py b/tests/unit/test_reference_spiral.py new file mode 100644 index 0000000..1b39b3e --- /dev/null +++ b/tests/unit/test_reference_spiral.py @@ -0,0 +1,17 @@ +"""Test cases for reference_spiral.py functions.""" + +import numpy as np + +from graphomotor.utils import reference_spiral + + +def test_generate_reference_spiral() -> None: + """Test the generation of a reference spiral.""" + spiral = reference_spiral.generate_reference_spiral() + assert isinstance(spiral, np.ndarray) + assert spiral.shape == (10000, 2) + assert np.array_equal(spiral[0], [50, 50]) + assert np.allclose(spiral[-1], [50 + 1.075 * 8 * np.pi, 50], atol=1e-8) + + distances = np.linalg.norm(np.diff(spiral, axis=0), axis=1) + assert np.allclose(distances, distances[0], atol=1e-4) From 4e7a02d21dbf48f785519dacf38451abaab8a720 Mon Sep 17 00:00:00 2001 From: Alp Erkent Date: Thu, 17 Apr 2025 16:01:35 -0400 Subject: [PATCH 3/5] Refactor reference spiral calculations: rename constants for clarity, edit docstrings to reflect changes, simplify arc length calculations, and modify tests for checking points are distributed equally along the spiral. --- src/graphomotor/utils/reference_spiral.py | 54 +++++++++-------------- tests/unit/test_reference_spiral.py | 29 +++++++++--- 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/src/graphomotor/utils/reference_spiral.py b/src/graphomotor/utils/reference_spiral.py index 87f9732..413cb4b 100644 --- a/src/graphomotor/utils/reference_spiral.py +++ b/src/graphomotor/utils/reference_spiral.py @@ -5,13 +5,14 @@ _SPIRAL_CENTER_X = 50 _SPIRAL_CENTER_Y = 50 -_SPIRAL_INITIAL_RADIUS = 0 +_SPIRAL_START_RADIUS = 0 _SPIRAL_GROWTH_RATE = 1.075 -_SPIRAL_TOTAL_ROTATION = 8 * np.pi +_SPIRAL_START_ANGLE = 0 +_SPIRAL_END_ANGLE = 8 * np.pi _SPIRAL_NUM_POINTS = 10000 -def _spiral_arc_length_integrand(t: float) -> float: +def _arc_length_integrand(t: float) -> float: """Calculate the differential arc length at angle t for an Archimedean spiral. Args: @@ -20,37 +21,26 @@ def _spiral_arc_length_integrand(t: float) -> float: Returns: Differential arc length value. """ - r_t = _SPIRAL_INITIAL_RADIUS + _SPIRAL_GROWTH_RATE * t + r_t = _SPIRAL_START_RADIUS + _SPIRAL_GROWTH_RATE * t return np.sqrt(r_t**2 + _SPIRAL_GROWTH_RATE**2) def _calculate_arc_length(theta: float) -> float: - """Calculate the arc length of the spiral from 0 to theta. + """Calculate the arc length of the spiral from _SPIRAL_START_ANGLE to theta. Args: theta: The angle in radians. Returns: - The arc length of the spiral from 0 to theta. + The arc length of the spiral from _SPIRAL_START_ANGLE to theta. """ - return integrate.quad(lambda t: _spiral_arc_length_integrand(t), 0, theta)[0] - - -def _arc_length_difference(theta: float, target_arc_length: float) -> float: - """Function to find the root for a given arc length. - - Args: - theta: Angle to evaluate. - target_arc_length: Target arc length value. - - Returns: - Difference between calculated and target arc length. - """ - return _calculate_arc_length(theta) - target_arc_length + return integrate.quad( + lambda t: _arc_length_integrand(t), _SPIRAL_START_ANGLE, theta + )[0] def _find_theta_for_arc_length(target_arc_length: float) -> float: - """Find the theta value for a given arc length. + """Find the theta value for a given arc length using numerical root finding. Args: target_arc_length: Target arc length. @@ -59,8 +49,8 @@ def _find_theta_for_arc_length(target_arc_length: float) -> float: Angle theta corresponding to the arc length. """ solution = optimize.root_scalar( - lambda theta: _arc_length_difference(theta, target_arc_length), - bracket=[0, _SPIRAL_TOTAL_ROTATION], + lambda theta: _calculate_arc_length(theta) - target_arc_length, + bracket=[_SPIRAL_START_ANGLE, _SPIRAL_END_ANGLE], ) return solution.root @@ -74,7 +64,7 @@ def generate_reference_spiral() -> np.ndarray: form. The algorithm works by: - 1. Computing the total arc length for the entire spiral (0 to 8π), + 1. Computing the total arc length for the entire spiral, 2. Creating equidistant target arc length values, 3. For each target arc length, finding the corresponding theta value that produces that arc length using numerical root finding, @@ -87,22 +77,22 @@ def generate_reference_spiral() -> np.ndarray: - Cartesian coordinates: x = cx + r·cos(θ), y = cy + r·sin(θ) Parameters used: - - Center coordinates: (50, 50) - - Initial radius (a): 0 - - Growth rate (b): 1.075 - - Total rotation: 4 complete revolutions (θ from 0 to 8π) - - Number of points: 10,000 + - Center coordinates: (cx, cy) = (_SPIRAL_CENTER_X, _SPIRAL_CENTER_Y) + - Start radius: a = _SPIRAL_START_RADIUS + - Growth rate: b = _SPIRAL_GROWTH_RATE + - Total rotation: θ = _SPIRAL_END_ANGLE - _SPIRAL_START_ANGLE + - Number of points: N = _SPIRAL_NUM_POINTS Returns: - Array with shape (10000, 2) containing Cartesian coordinates of the spiral. + Array with shape (N, 2) containing Cartesian coordinates of the spiral points. """ - total_arc_length = _calculate_arc_length(_SPIRAL_TOTAL_ROTATION) + total_arc_length = _calculate_arc_length(_SPIRAL_END_ANGLE) arc_length_values = np.linspace(0, total_arc_length, _SPIRAL_NUM_POINTS) theta_values = np.array([_find_theta_for_arc_length(s) for s in arc_length_values]) - r_values = _SPIRAL_INITIAL_RADIUS + _SPIRAL_GROWTH_RATE * theta_values + r_values = _SPIRAL_START_RADIUS + _SPIRAL_GROWTH_RATE * theta_values x_values = _SPIRAL_CENTER_X + r_values * np.cos(theta_values) y_values = _SPIRAL_CENTER_Y + r_values * np.sin(theta_values) diff --git a/tests/unit/test_reference_spiral.py b/tests/unit/test_reference_spiral.py index 1b39b3e..694c48f 100644 --- a/tests/unit/test_reference_spiral.py +++ b/tests/unit/test_reference_spiral.py @@ -8,10 +8,27 @@ def test_generate_reference_spiral() -> None: """Test the generation of a reference spiral.""" spiral = reference_spiral.generate_reference_spiral() - assert isinstance(spiral, np.ndarray) - assert spiral.shape == (10000, 2) - assert np.array_equal(spiral[0], [50, 50]) - assert np.allclose(spiral[-1], [50 + 1.075 * 8 * np.pi, 50], atol=1e-8) + arc_lengths = np.linalg.norm(spiral[1:] - spiral[:-1], axis=1) + mean_arc_length = np.mean(arc_lengths) + + expected_mean_arc_length = reference_spiral._calculate_arc_length( + reference_spiral._SPIRAL_END_ANGLE + ) / (reference_spiral._SPIRAL_NUM_POINTS - 1) - distances = np.linalg.norm(np.diff(spiral, axis=0), axis=1) - assert np.allclose(distances, distances[0], atol=1e-4) + assert isinstance(spiral, np.ndarray) + assert spiral.shape == (reference_spiral._SPIRAL_NUM_POINTS, 2) + assert np.array_equal( + spiral[0], + [reference_spiral._SPIRAL_CENTER_X, reference_spiral._SPIRAL_CENTER_Y], + ) + assert np.allclose( + spiral[-1], + [ + reference_spiral._SPIRAL_CENTER_X + + reference_spiral._SPIRAL_GROWTH_RATE * reference_spiral._SPIRAL_END_ANGLE, + reference_spiral._SPIRAL_CENTER_Y, + ], + atol=1e-8, + ) + assert np.allclose(arc_lengths, mean_arc_length, rtol=1e-3) + assert np.isclose(mean_arc_length, expected_mean_arc_length, rtol=1e-6) From c2ef39f11f7d090d5352150edbcc68b07bc94969 Mon Sep 17 00:00:00 2001 From: Alp Erkent Date: Thu, 17 Apr 2025 16:03:07 -0400 Subject: [PATCH 4/5] Fix typo in docstring for calculate_hausdorff_metrics function --- src/graphomotor/features/distance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphomotor/features/distance.py b/src/graphomotor/features/distance.py index 186e427..03d7dfa 100644 --- a/src/graphomotor/features/distance.py +++ b/src/graphomotor/features/distance.py @@ -35,7 +35,7 @@ def calculate_hausdorff_metrics( """Calculate Hausdorff distance metrics for a spiral object. This function computes multiple features based on the Hausdorff distance between a - drawn spiral and a reference (ideal) spiral, as described in the [1]. Implementation + drawn spiral and a reference (ideal) spiral, as described in [1]. Implementation is based on the original R script provided with the publication. The Hausdorff distance measures the maximum distance of a set to the nearest point in the other set. This metric and its derivatives capture various aspects of the spatial From a86ce3a3e8016c0a6f6958bed4af5726b9bb3b0c Mon Sep 17 00:00:00 2001 From: Alp Erkent Date: Fri, 18 Apr 2025 13:27:45 -0400 Subject: [PATCH 5/5] Reorder calculations in test_generate_reference_spiral so that Arrange, Act, Assert pattern holds --- tests/unit/test_reference_spiral.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_reference_spiral.py b/tests/unit/test_reference_spiral.py index 694c48f..f4fad61 100644 --- a/tests/unit/test_reference_spiral.py +++ b/tests/unit/test_reference_spiral.py @@ -7,14 +7,14 @@ def test_generate_reference_spiral() -> None: """Test the generation of a reference spiral.""" - spiral = reference_spiral.generate_reference_spiral() - arc_lengths = np.linalg.norm(spiral[1:] - spiral[:-1], axis=1) - mean_arc_length = np.mean(arc_lengths) - expected_mean_arc_length = reference_spiral._calculate_arc_length( reference_spiral._SPIRAL_END_ANGLE ) / (reference_spiral._SPIRAL_NUM_POINTS - 1) + spiral = reference_spiral.generate_reference_spiral() + arc_lengths = np.linalg.norm(spiral[1:] - spiral[:-1], axis=1) + mean_arc_length = np.mean(arc_lengths) + assert isinstance(spiral, np.ndarray) assert spiral.shape == (reference_spiral._SPIRAL_NUM_POINTS, 2) assert np.array_equal(