Skip to content

Commit bd58e18

Browse files
authored
102 task add GridCell model (#119)
* Add GridCell model and unit tests for stroke containment * using half-open intervals for GridCell boundary checks * Updated the class docstring to note the half-open interval policy and that the outer grid boundaries will be padded at the Grid level. * Refactor boundary tests for GridCell to use parameterized inputs
1 parent f8a3e1d commit bd58e18

2 files changed

Lines changed: 155 additions & 0 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,68 @@ def get_extractors(
125125
}
126126

127127

128+
@dataclasses.dataclass
129+
class GridCell:
130+
"""Represents a single rectangular region in a grid layout.
131+
132+
Used to assign strokes to letter regions (Alphabet) or digit regions (DSYM).
133+
Boundary policy: uses half-open intervals [min, max) to prevent
134+
double-assignment on shared grid edges. The outer grid boundaries should be
135+
padded (e.g., by 0.1) at the Grid level so that centroids on the outermost
136+
edge are not excluded.
137+
138+
Attributes:
139+
x_min: Left boundary of the cell.
140+
x_max: Right boundary of the cell.
141+
y_min: Bottom boundary of the cell.
142+
y_max: Top boundary of the cell.
143+
index: Position of the cell in the grid (0-based).
144+
label: Display label for the cell (e.g., 'A', 'B', '1').
145+
"""
146+
147+
x_min: float
148+
x_max: float
149+
y_min: float
150+
y_max: float
151+
index: int = 0
152+
label: str = ""
153+
154+
def __post_init__(self) -> None:
155+
"""Validate that min bounds are strictly less than max bounds.
156+
157+
Raises:
158+
ValueError: If x_min >= x_max or y_min >= y_max.
159+
"""
160+
if self.x_min >= self.x_max:
161+
raise ValueError(
162+
f"x_min ({self.x_min}) must be less than x_max ({self.x_max})"
163+
)
164+
if self.y_min >= self.y_max:
165+
raise ValueError(
166+
f"y_min ({self.y_min}) must be less than y_max ({self.y_max})"
167+
)
168+
169+
def contains_points(self, points: pd.DataFrame) -> bool:
170+
"""Check if a stroke belongs to this cell based on its centroid.
171+
172+
Computes the centroid (mean x, mean y) of the provided points and checks
173+
whether it falls within the cell boundaries. Uses half-open intervals
174+
[min, max) to prevent double-assignment on shared grid edges.
175+
176+
Args:
177+
points: DataFrame with 'x' and 'y' columns representing a stroke.
178+
179+
Returns:
180+
True if the stroke centroid is within the cell, False otherwise.
181+
"""
182+
centroid_x = points["x"].mean()
183+
centroid_y = points["y"].mean()
184+
return (
185+
self.x_min <= centroid_x < self.x_max
186+
and self.y_min <= centroid_y < self.y_max
187+
)
188+
189+
128190
@dataclasses.dataclass
129191
class CircleTarget:
130192
"""Represents a target circle in the drawing task.

tests/unit/test_grid_cell.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Test cases for the GridCell model."""
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from graphomotor.core import models
7+
8+
9+
@pytest.fixture
10+
def cell() -> models.GridCell:
11+
"""Create a grid cell representing one letter region."""
12+
return models.GridCell(
13+
x_min=10.0, x_max=25.0, y_min=80.0, y_max=97.0, index=0, label="A"
14+
)
15+
16+
17+
@pytest.mark.parametrize(
18+
"x_min,x_max,y_min,y_max,expected_error",
19+
[
20+
(10.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"),
21+
(5.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"),
22+
(0.0, 1.0, 10.0, 5.0, "y_min .* must be less than y_max"),
23+
(0.0, 1.0, 5.0, 5.0, "y_min .* must be less than y_max"),
24+
],
25+
)
26+
def test_grid_cell_invalid_bounds(
27+
x_min: float,
28+
x_max: float,
29+
y_min: float,
30+
y_max: float,
31+
expected_error: str,
32+
) -> None:
33+
"""Test that invalid or equal bounds raise ValueError."""
34+
with pytest.raises(ValueError, match=expected_error):
35+
models.GridCell(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)
36+
37+
38+
def test_stroke_centroid_inside_cell(cell: models.GridCell) -> None:
39+
"""Stroke whose centroid falls inside the cell should be contained."""
40+
stroke_points = pd.DataFrame({"x": [15.0, 20.0, 17.0], "y": [85.0, 90.0, 95.0]})
41+
assert cell.contains_points(stroke_points)
42+
43+
44+
def test_stroke_centroid_outside_cell(cell: models.GridCell) -> None:
45+
"""Stroke whose centroid falls outside the cell should not be contained."""
46+
stroke_points = pd.DataFrame({"x": [33.8, 37.8, 35.3], "y": [85.8, 95.4, 90.1]})
47+
assert not cell.contains_points(stroke_points)
48+
49+
50+
@pytest.mark.parametrize(
51+
"x_vals,y_vals",
52+
[
53+
([10.0, 10.0], [88.0, 89.0]),
54+
([17.0, 18.0], [80.0, 80.0]),
55+
],
56+
ids=["left_boundary", "bottom_boundary"],
57+
)
58+
def test_stroke_centroid_on_lower_boundary(
59+
cell: models.GridCell, x_vals: list[float], y_vals: list[float]
60+
) -> None:
61+
"""Stroke whose centroid lands on the lower/left boundary (min) is included."""
62+
stroke_points = pd.DataFrame({"x": x_vals, "y": y_vals})
63+
assert cell.contains_points(stroke_points)
64+
65+
66+
@pytest.mark.parametrize(
67+
"x_vals,y_vals",
68+
[
69+
([25.0, 25.0], [88.0, 89.0]),
70+
([17.0, 18.0], [97.0, 97.0]),
71+
],
72+
ids=["right_boundary", "top_boundary"],
73+
)
74+
def test_stroke_centroid_on_upper_boundary(
75+
cell: models.GridCell, x_vals: list[float], y_vals: list[float]
76+
) -> None:
77+
"""Stroke whose centroid lands on the upper/right boundary (max) is excluded."""
78+
stroke_points = pd.DataFrame({"x": x_vals, "y": y_vals})
79+
assert not cell.contains_points(stroke_points)
80+
81+
82+
def test_stroke_points_span_outside_but_centroid_inside(
83+
cell: models.GridCell,
84+
) -> None:
85+
"""Stroke with points outside the cell but centroid inside should be contained."""
86+
stroke_points = pd.DataFrame({"x": [8.0, 22.0], "y": [78.0, 98.0]})
87+
assert cell.contains_points(stroke_points)
88+
89+
90+
def test_single_point_stroke(cell: models.GridCell) -> None:
91+
"""Single-point stroke should use that point as its centroid."""
92+
stroke_points = pd.DataFrame({"x": [17.5], "y": [90.0]})
93+
assert cell.contains_points(stroke_points)

0 commit comments

Comments
 (0)