diff --git a/pyproject.toml b/pyproject.toml index fb5be2c..7bbec5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ requires-python = ">=3.12" dependencies = [ "pandas>=2.2.3", "pydantic>=2.11.1", - "scipy>=1.15.2" + "scipy>=1.15.2", + "shapely>=2.1.0" ] [dependency-groups] diff --git a/src/graphomotor/features/distance.py b/src/graphomotor/features/distance.py index 03d7dfa..9491dc7 100644 --- a/src/graphomotor/features/distance.py +++ b/src/graphomotor/features/distance.py @@ -31,7 +31,7 @@ def _segment_data(data: np.ndarray, start_prop: float, end_prop: float) -> np.nd def calculate_hausdorff_metrics( spiral: models.Spiral, reference_spiral: np.ndarray -) -> dict: +) -> dict[str, float]: """Calculate Hausdorff distance metrics for a spiral object. This function computes multiple features based on the Hausdorff distance between a diff --git a/src/graphomotor/features/drawing_error.py b/src/graphomotor/features/drawing_error.py new file mode 100644 index 0000000..efe314a --- /dev/null +++ b/src/graphomotor/features/drawing_error.py @@ -0,0 +1,36 @@ +"""Feature extraction module for drawing error-based metrics in spiral drawing data.""" + +import numpy as np +from shapely import geometry, ops + +from graphomotor.core import models + + +def calculate_area_under_curve( + drawn_spiral: models.Spiral, reference_spiral: np.ndarray +) -> dict[str, float]: + """Calculate the area between drawn and reference spirals. + + This function measures the deviation between drawn and reference spirals by + computing the enclosed area between them using the shapely library. Lower values + indicate better adherence to the template. The algorithm works by creating polygons + that connect spiral endpoints, finding intersections between lines, and calculating + the total area of the resulting polygons. + + Args: + drawn_spiral: The spiral drawn by the subject. + reference_spiral: The reference spiral. + + Returns: + Dictionary containing the area under curve metric + """ + spiral = drawn_spiral.data[["x", "y"]].values + line_drawn = geometry.LineString(spiral) + line_reference = geometry.LineString(reference_spiral) + first_segment = geometry.LineString([spiral[0], reference_spiral[0]]) + last_segment = geometry.LineString([spiral[-1], reference_spiral[-1]]) + merged_line = ops.unary_union( + [line_drawn, line_reference, first_segment, last_segment] + ) + polygons = list(ops.polygonize(merged_line)) + return {"area_under_curve": sum(p.area for p in polygons)} diff --git a/tests/unit/test_drawing_error.py b/tests/unit/test_drawing_error.py new file mode 100644 index 0000000..1d5b01f --- /dev/null +++ b/tests/unit/test_drawing_error.py @@ -0,0 +1,23 @@ +"""Test cases for drawing_error.py functions.""" + +import numpy as np +import pandas as pd + +from graphomotor.core import models +from graphomotor.features import drawing_error + + +def test_calculate_area_under_curve(valid_spiral: models.Spiral) -> None: + """Test that the area under the curve is calculated correctly.""" + x = np.linspace(-np.pi / 2, 3 * np.pi / 2, 100) + y1 = np.sin(x) + y2 = np.sin(x + np.pi) + + expected_area = 8.0 + + valid_spiral.data = pd.DataFrame({"x": x, "y": y1}) + calculated_area = drawing_error.calculate_area_under_curve( + valid_spiral, np.column_stack((x, y2)) + )["area_under_curve"] + + assert np.isclose(calculated_area, expected_area, rtol=1e-3) diff --git a/uv.lock b/uv.lock index bf0402a..2b4a7f0 100644 --- a/uv.lock +++ b/uv.lock @@ -93,6 +93,7 @@ dependencies = [ { name = "pandas" }, { name = "pydantic" }, { name = "scipy" }, + { name = "shapely" }, ] [package.dev-dependencies] @@ -112,6 +113,7 @@ requires-dist = [ { name = "pandas", specifier = ">=2.2.3" }, { name = "pydantic", specifier = ">=2.11.1" }, { name = "scipy", specifier = ">=1.15.2" }, + { name = "shapely", specifier = ">=2.1.0" }, ] [package.metadata.requires-dev] @@ -568,6 +570,41 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/c8/b3f566db71461cabd4b2d5b39bcc24a7e1c119535c8361f81426be39bb47/scipy-1.15.2-cp313-cp313t-win_amd64.whl", hash = "sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db", size = 40477705 }, ] +[[package]] +name = "shapely" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fb/fe/3b0d2f828ffaceadcdcb51b75b9c62d98e62dd95ce575278de35f24a1c20/shapely-2.1.0.tar.gz", hash = "sha256:2cbe90e86fa8fc3ca8af6ffb00a77b246b918c7cf28677b7c21489b678f6b02e", size = 313617 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/d1/6a9371ec39d3ef08e13225594e6c55b045209629afd9e6d403204507c2a8/shapely-2.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:53e7ee8bd8609cf12ee6dce01ea5affe676976cf7049315751d53d8db6d2b4b2", size = 1830732 }, + { url = "https://files.pythonhosted.org/packages/32/87/799e3e48be7ce848c08509b94d2180f4ddb02e846e3c62d0af33da4d78d3/shapely-2.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3cab20b665d26dbec0b380e15749bea720885a481fa7b1eedc88195d4a98cfa4", size = 1638404 }, + { url = "https://files.pythonhosted.org/packages/85/00/6665d77f9dd09478ab0993b8bc31668aec4fd3e5f1ddd1b28dd5830e47be/shapely-2.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4a38b39a09340273c3c92b3b9a374272a12cc7e468aeeea22c1c46217a03e5c", size = 2945316 }, + { url = "https://files.pythonhosted.org/packages/34/49/738e07d10bbc67cae0dcfe5a484c6e518a517f4f90550dda2adf3a78b9f2/shapely-2.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:edaec656bdd9b71278b98e6f77c464b1c3b2daa9eace78012ff0f0b4b5b15b04", size = 3063099 }, + { url = "https://files.pythonhosted.org/packages/88/b8/138098674559362ab29f152bff3b6630de423378fbb0324812742433a4ef/shapely-2.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c8a732ddd9b25e7a54aa748e7df8fd704e23e5d5d35b7d376d80bffbfc376d04", size = 3887873 }, + { url = "https://files.pythonhosted.org/packages/67/a8/fdae7c2db009244991d86f4d2ca09d2f5ccc9d41c312c3b1ee1404dc55da/shapely-2.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9c93693ad8adfdc9138a5a2d42da02da94f728dd2e82d2f0f442f10e25027f5f", size = 4067004 }, + { url = "https://files.pythonhosted.org/packages/ed/78/17e17d91b489019379df3ee1afc4bd39787b232aaa1d540f7d376f0280b7/shapely-2.1.0-cp312-cp312-win32.whl", hash = "sha256:d8ac6604eefe807e71a908524de23a37920133a1729fe3a4dfe0ed82c044cbf4", size = 1527366 }, + { url = "https://files.pythonhosted.org/packages/b8/bd/9249bd6dda948441e25e4fb14cbbb5205146b0fff12c66b19331f1ff2141/shapely-2.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:f4f47e631aa4f9ec5576eac546eb3f38802e2f82aeb0552f9612cb9a14ece1db", size = 1708265 }, + { url = "https://files.pythonhosted.org/packages/8d/77/4e368704b2193e74498473db4461d697cc6083c96f8039367e59009d78bd/shapely-2.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b64423295b563f43a043eb786e7a03200ebe68698e36d2b4b1c39f31dfb50dfb", size = 1830029 }, + { url = "https://files.pythonhosted.org/packages/71/3c/d888597bda680e4de987316b05ca9db07416fa29523beff64f846503302f/shapely-2.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1b5578f45adc25b235b22d1ccb9a0348c8dc36f31983e57ea129a88f96f7b870", size = 1637999 }, + { url = "https://files.pythonhosted.org/packages/03/8d/ee0e23b7ef88fba353c63a81f1f329c77f5703835db7b165e7c0b8b7f839/shapely-2.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1a7e83d383b27f02b684e50ab7f34e511c92e33b6ca164a6a9065705dd64bcb", size = 2929348 }, + { url = "https://files.pythonhosted.org/packages/d1/a7/5c9cb413e4e2ce52c16be717e94abd40ce91b1f8974624d5d56154c5d40b/shapely-2.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:942031eb4d8f7b3b22f43ba42c09c7aa3d843aa10d5cc1619fe816e923b66e55", size = 3048973 }, + { url = "https://files.pythonhosted.org/packages/84/23/45b90c0bd2157b238490ca56ef2eedf959d3514c7d05475f497a2c88b6d9/shapely-2.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d2843c456a2e5627ee6271800f07277c0d2652fb287bf66464571a057dbc00b3", size = 3873148 }, + { url = "https://files.pythonhosted.org/packages/c0/bc/ed7d5d37f5395166042576f0c55a12d7e56102799464ba7ea3a72a38c769/shapely-2.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8c4b17469b7f39a5e6a7cfea79f38ae08a275427f41fe8b48c372e1449147908", size = 4052655 }, + { url = "https://files.pythonhosted.org/packages/c0/8f/a1dafbb10d20d1c569f2db3fb1235488f624dafe8469e8ce65356800ba31/shapely-2.1.0-cp313-cp313-win32.whl", hash = "sha256:30e967abd08fce49513d4187c01b19f139084019f33bec0673e8dbeb557c45e4", size = 1526600 }, + { url = "https://files.pythonhosted.org/packages/e3/f0/9f8cdf2258d7aed742459cea51c70d184de92f5d2d6f5f7f1ded90a18c31/shapely-2.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:1dc8d4364483a14aba4c844b7bd16a6fa3728887e2c33dfa1afa34a3cf4d08a5", size = 1707115 }, + { url = "https://files.pythonhosted.org/packages/75/ed/32952df461753a65b3e5d24c8efb361d3a80aafaef0b70d419063f6f2c11/shapely-2.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:673e073fea099d1c82f666fb7ab0a00a77eff2999130a69357ce11941260d855", size = 1824847 }, + { url = "https://files.pythonhosted.org/packages/ff/b9/2284de512af30b02f93ddcdd2e5c79834a3cf47fa3ca11b0f74396feb046/shapely-2.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6d1513f915a56de67659fe2047c1ad5ff0f8cbff3519d1e74fced69c9cb0e7da", size = 1631035 }, + { url = "https://files.pythonhosted.org/packages/35/16/a59f252a7e736b73008f10d0950ffeeb0d5953be7c0bdffd39a02a6ba310/shapely-2.1.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d6a7043178890b9e028d80496ff4c79dc7629bff4d78a2f25323b661756bab8", size = 2968639 }, + { url = "https://files.pythonhosted.org/packages/a5/0a/6a20eca7b0092cfa243117e8e145a58631a4833a0a519ec9b445172e83a0/shapely-2.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb638378dc3d76f7e85b67d7e2bb1366811912430ac9247ac00c127c2b444cdc", size = 3055713 }, + { url = "https://files.pythonhosted.org/packages/fb/44/eeb0c7583b1453d1cf7a319a1d738e08f98a5dc993fa1ef3c372983e4cb5/shapely-2.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:737124e87d91d616acf9a911f74ac55e05db02a43a6a7245b3d663817b876055", size = 3890478 }, + { url = "https://files.pythonhosted.org/packages/5d/6e/37ff3c6af1d408cacb0a7d7bfea7b8ab163a5486e35acb08997eae9d8756/shapely-2.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8e6c229e7bb87aae5df82fa00b6718987a43ec168cc5affe095cca59d233f314", size = 4036148 }, + { url = "https://files.pythonhosted.org/packages/c8/6a/8c0b7de3aeb5014a23f06c5e9d3c7852ebcf0d6b00fe660b93261e310e24/shapely-2.1.0-cp313-cp313t-win32.whl", hash = "sha256:a9580bda119b1f42f955aa8e52382d5c73f7957e0203bc0c0c60084846f3db94", size = 1535993 }, + { url = "https://files.pythonhosted.org/packages/a8/91/ae80359a58409d52e4d62c7eacc7eb3ddee4b9135f1db884b6a43cf2e174/shapely-2.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e8ff4e5cfd799ba5b6f37b5d5527dbd85b4a47c65b6d459a03d0962d2a9d4d10", size = 1717777 }, +] + [[package]] name = "six" version = "1.17.0"