Skip to content

Commit af9c7ae

Browse files
authored
57 task write trails drawing feature functions percent accurate paths (#92)
* percent_accurate_paths * unit tests * fix docstring, change variable name
1 parent 59caf4e commit af9c7ae

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

src/graphomotor/features/trails/drawing_metrics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,27 @@ def get_total_errors(drawing: models.Drawing) -> dict[str, float]:
2929
"Drawing data does not contain 'total_number_of_errors' column."
3030
)
3131
return {"total_errors": drawing.data["total_number_of_errors"].iloc[0]}
32+
33+
34+
def percent_accurate_paths(drawing: models.Drawing) -> dict[str, float]:
35+
"""Calculate the percentage of accurate paths in a trails drawing task.
36+
37+
Args:
38+
drawing: Drawing object containing drawing data.
39+
40+
Returns:
41+
Dictionary containing the percentage of accurate paths of the task.
42+
43+
Raises:
44+
ValueError: If required columns are missing in the drawing data.
45+
"""
46+
if not {"correct_path", "actual_path"}.issubset(drawing.data.columns):
47+
raise ValueError(
48+
"DataFrame must contain 'correct_path' and 'actual_path' columns."
49+
)
50+
51+
return {
52+
"percent_accurate_paths": (
53+
(drawing.data["correct_path"] == drawing.data["actual_path"]).mean() * 100
54+
)
55+
}

tests/unit/test_trails_drawing_metrics.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,26 @@ def test_valid_total_errors() -> None:
4545

4646
result = drawing_metrics.get_total_errors(drawing)
4747
assert result == {"total_errors": 1.0}
48+
49+
50+
def test_percent_accurate_paths_missing_columns() -> None:
51+
"""Test ValueError when required columns are missing."""
52+
invalid_df = pd.DataFrame({"some_other_column": [0, 1, 2]})
53+
drawing = models.Drawing(
54+
data=invalid_df, task_name="trails", metadata={"id": "5555555"}
55+
)
56+
57+
with pytest.raises(
58+
ValueError,
59+
match="DataFrame must contain 'correct_path' and 'actual_path' columns.",
60+
):
61+
drawing_metrics.percent_accurate_paths(drawing)
62+
63+
64+
def test_percent_accurate_paths_sample_data() -> None:
65+
"""Test percent_accurate_paths with sample drawing data."""
66+
filepath = "tests/sample_data/[5000000]648b6b868819c1120b4f6ce3-trail4.csv"
67+
drawing = reader.load_drawing_data(filepath)
68+
69+
result = drawing_metrics.percent_accurate_paths(drawing)
70+
assert result == {"percent_accurate_paths": 100.0}

0 commit comments

Comments
 (0)