Skip to content

Commit 77f5a19

Browse files
committed
Add GridCell model and unit tests for stroke containment
1 parent da2c270 commit 77f5a19

2 files changed

Lines changed: 125 additions & 0 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,64 @@ def get_extractors(
125125
}
126126

127127

128+
@dataclasses.dataclass
129+
class GridCell:
130+
"""Represents a single rectangular region in a grid layout.
131+
132+
Used to assign strokes to letter regions (Alphabet) or digit regions (DSYM).
133+
Boundary policy: a point on the exact boundary is considered inside the cell.
134+
135+
Attributes:
136+
x_min: Left boundary of the cell.
137+
x_max: Right boundary of the cell.
138+
y_min: Bottom boundary of the cell.
139+
y_max: Top boundary of the cell.
140+
index: Position of the cell in the grid (0-based).
141+
label: Display label for the cell (e.g., 'A', 'B', '1').
142+
"""
143+
144+
x_min: float
145+
x_max: float
146+
y_min: float
147+
y_max: float
148+
index: int = 0
149+
label: str = ""
150+
151+
def __post_init__(self) -> None:
152+
"""Validate that min bounds are strictly less than max bounds.
153+
154+
Raises:
155+
ValueError: If x_min >= x_max or y_min >= y_max.
156+
"""
157+
if self.x_min >= self.x_max:
158+
raise ValueError(
159+
f"x_min ({self.x_min}) must be less than x_max ({self.x_max})"
160+
)
161+
if self.y_min >= self.y_max:
162+
raise ValueError(
163+
f"y_min ({self.y_min}) must be less than y_max ({self.y_max})"
164+
)
165+
166+
def contains_points(self, points: pd.DataFrame) -> bool:
167+
"""Check if a stroke belongs to this cell based on its centroid.
168+
169+
Computes the centroid (mean x, mean y) of the provided points and checks
170+
whether it falls within the cell boundaries (inclusive).
171+
172+
Args:
173+
points: DataFrame with 'x' and 'y' columns representing a stroke.
174+
175+
Returns:
176+
True if the stroke centroid is within the cell, False otherwise.
177+
"""
178+
centroid_x = points["x"].mean()
179+
centroid_y = points["y"].mean()
180+
return (
181+
self.x_min <= centroid_x <= self.x_max
182+
and self.y_min <= centroid_y <= self.y_max
183+
)
184+
185+
128186
@dataclasses.dataclass
129187
class CircleTarget:
130188
"""Represents a target circle in the drawing task.

tests/unit/test_grid_cell.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Test cases for the GridCell model."""
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from graphomotor.core import models
7+
8+
9+
@pytest.fixture
10+
def cell() -> models.GridCell:
11+
"""Create a grid cell representing one letter region."""
12+
return models.GridCell(
13+
x_min=10.0, x_max=25.0, y_min=80.0, y_max=97.0, index=0, label="A"
14+
)
15+
16+
17+
@pytest.mark.parametrize(
18+
"x_min,x_max,y_min,y_max,expected_error",
19+
[
20+
(10.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"),
21+
(5.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"),
22+
(0.0, 1.0, 10.0, 5.0, "y_min .* must be less than y_max"),
23+
(0.0, 1.0, 5.0, 5.0, "y_min .* must be less than y_max"),
24+
],
25+
)
26+
def test_grid_cell_invalid_bounds(
27+
x_min: float,
28+
x_max: float,
29+
y_min: float,
30+
y_max: float,
31+
expected_error: str,
32+
) -> None:
33+
"""Test that invalid or equal bounds raise ValueError."""
34+
with pytest.raises(ValueError, match=expected_error):
35+
models.GridCell(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)
36+
37+
38+
def test_stroke_centroid_inside_cell(cell: models.GridCell) -> None:
39+
"""Stroke whose centroid falls inside the cell should be contained."""
40+
stroke_points = pd.DataFrame({"x": [15.0, 20.0, 17.0], "y": [85.0, 90.0, 95.0]})
41+
assert cell.contains_points(stroke_points)
42+
43+
44+
def test_stroke_centroid_outside_cell(cell: models.GridCell) -> None:
45+
"""Stroke whose centroid falls outside the cell should not be contained."""
46+
stroke_points = pd.DataFrame({"x": [33.8, 37.8, 35.3], "y": [85.8, 95.4, 90.1]})
47+
assert not cell.contains_points(stroke_points)
48+
49+
50+
def test_stroke_centroid_on_boundary(cell: models.GridCell) -> None:
51+
"""Stroke whose centroid lands exactly on the cell boundary should be contained."""
52+
stroke_points = pd.DataFrame({"x": [10.0, 10.0], "y": [88.0, 89.0]})
53+
assert cell.contains_points(stroke_points)
54+
55+
56+
def test_stroke_points_span_outside_but_centroid_inside(
57+
cell: models.GridCell,
58+
) -> None:
59+
"""Stroke with points outside the cell but centroid inside should be contained."""
60+
stroke_points = pd.DataFrame({"x": [8.0, 22.0], "y": [78.0, 98.0]})
61+
assert cell.contains_points(stroke_points)
62+
63+
64+
def test_single_point_stroke(cell: models.GridCell) -> None:
65+
"""Single-point stroke should use that point as its centroid."""
66+
stroke_points = pd.DataFrame({"x": [17.5], "y": [90.0]})
67+
assert cell.contains_points(stroke_points)

0 commit comments

Comments
 (0)