Skip to content

Commit 59caf4e

Browse files
authored
56 task write trails drawing feature functions get total errors (#91)
* wrote function * write unit test * valid unit test * change documentation and edit output from detect_pen_lifts to be the same as get_total_errors * rewrite unit test with actual sample data, fix detect_pen_lift unit tests for new output format
1 parent 391a6a1 commit 59caf4e

2 files changed

Lines changed: 47 additions & 4 deletions

File tree

src/graphomotor/features/trails/drawing_metrics.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from graphomotor.core import models
44

55

6-
def detect_pen_lifts(drawing: models.Drawing) -> int:
6+
def detect_pen_lifts(drawing: models.Drawing) -> dict[str, int]:
77
"""Detect pen lifts during a spiral drawing task.
88
99
Args:
@@ -12,4 +12,20 @@ def detect_pen_lifts(drawing: models.Drawing) -> int:
1212
Returns:
1313
Integer count of pen lifts detected.
1414
"""
15-
return len(drawing.data["line_number"].unique()) - 1
15+
return {"pen_lifts": len(drawing.data["line_number"].unique()) - 1}
16+
17+
18+
def get_total_errors(drawing: models.Drawing) -> dict[str, float]:
19+
"""Extract the total number of errors of a trails drawing task.
20+
21+
Args:
22+
drawing: Drawing object containing drawing data.
23+
24+
Returns:
25+
Dictionary containing the total number of errors of the task.
26+
"""
27+
if "total_number_of_errors" not in drawing.data.columns:
28+
raise ValueError(
29+
"Drawing data does not contain 'total_number_of_errors' column."
30+
)
31+
return {"total_errors": drawing.data["total_number_of_errors"].iloc[0]}
Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,47 @@
11
"""Test cases for drawing_metrics.py functions."""
22

33
import pandas as pd
4+
import pytest
45

56
from graphomotor.core import models
67
from graphomotor.features.trails import drawing_metrics
8+
from graphomotor.io import reader
79

810

911
def test_no_pen_lifts() -> None:
1012
"""Test case with no pen lifts."""
1113
df = pd.DataFrame({"line_number": [1, 1, 1, 1]})
1214
drawing = models.Drawing(data=df, task_name="trails", metadata={"id": "5555555"})
13-
assert drawing_metrics.detect_pen_lifts(drawing) == 0
15+
result = drawing_metrics.detect_pen_lifts(drawing)
16+
assert result == {"pen_lifts": 0}
1417

1518

1619
def test_valid_pen_lifts() -> None:
1720
"""Test case with valid pen lifts."""
1821
df = pd.DataFrame({"line_number": [1, 2, 3, 4]})
1922
drawing = models.Drawing(data=df, task_name="trails", metadata={"id": "5555555"})
20-
assert drawing_metrics.detect_pen_lifts(drawing) == 3
23+
result = drawing_metrics.detect_pen_lifts(drawing)
24+
assert result == {"pen_lifts": 3}
25+
26+
27+
def test_get_total_errors() -> None:
28+
"""Test ValueError when total_number_of_errors column doesn't exist."""
29+
invalid_df = pd.DataFrame({"some_other_column": [0, 1, 2]})
30+
drawing = models.Drawing(
31+
data=invalid_df, task_name="trails", metadata={"id": "5555555"}
32+
)
33+
34+
with pytest.raises(
35+
ValueError,
36+
match="Drawing data does not contain 'total_number_of_errors' column.",
37+
):
38+
drawing_metrics.get_total_errors(drawing)
39+
40+
41+
def test_valid_total_errors() -> None:
42+
"""Test case with valid total_number_of_errors column."""
43+
filepath = "tests/sample_data/[5000000]648b6b868819c1120b4f6ce3-trail4.csv"
44+
drawing = reader.load_drawing_data(filepath)
45+
46+
result = drawing_metrics.get_total_errors(drawing)
47+
assert result == {"total_errors": 1.0}

0 commit comments

Comments
 (0)