From 77f5a19bbbf96108d2363076f166170fda26e4e3 Mon Sep 17 00:00:00 2001 From: Iktae Kim Date: Thu, 12 Feb 2026 16:35:45 -0500 Subject: [PATCH 1/4] Add GridCell model and unit tests for stroke containment --- src/graphomotor/core/models.py | 58 +++++++++++++++++++++++++++++ tests/unit/test_grid_cell.py | 67 ++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 tests/unit/test_grid_cell.py diff --git a/src/graphomotor/core/models.py b/src/graphomotor/core/models.py index 0fd9150..02c6605 100644 --- a/src/graphomotor/core/models.py +++ b/src/graphomotor/core/models.py @@ -125,6 +125,64 @@ def get_extractors( } +@dataclasses.dataclass +class GridCell: + """Represents a single rectangular region in a grid layout. + + Used to assign strokes to letter regions (Alphabet) or digit regions (DSYM). + Boundary policy: a point on the exact boundary is considered inside the cell. + + Attributes: + x_min: Left boundary of the cell. + x_max: Right boundary of the cell. + y_min: Bottom boundary of the cell. + y_max: Top boundary of the cell. + index: Position of the cell in the grid (0-based). + label: Display label for the cell (e.g., 'A', 'B', '1'). + """ + + x_min: float + x_max: float + y_min: float + y_max: float + index: int = 0 + label: str = "" + + def __post_init__(self) -> None: + """Validate that min bounds are strictly less than max bounds. + + Raises: + ValueError: If x_min >= x_max or y_min >= y_max. + """ + if self.x_min >= self.x_max: + raise ValueError( + f"x_min ({self.x_min}) must be less than x_max ({self.x_max})" + ) + if self.y_min >= self.y_max: + raise ValueError( + f"y_min ({self.y_min}) must be less than y_max ({self.y_max})" + ) + + def contains_points(self, points: pd.DataFrame) -> bool: + """Check if a stroke belongs to this cell based on its centroid. + + Computes the centroid (mean x, mean y) of the provided points and checks + whether it falls within the cell boundaries (inclusive). + + Args: + points: DataFrame with 'x' and 'y' columns representing a stroke. + + Returns: + True if the stroke centroid is within the cell, False otherwise. + """ + centroid_x = points["x"].mean() + centroid_y = points["y"].mean() + return ( + self.x_min <= centroid_x <= self.x_max + and self.y_min <= centroid_y <= self.y_max + ) + + @dataclasses.dataclass class CircleTarget: """Represents a target circle in the drawing task. diff --git a/tests/unit/test_grid_cell.py b/tests/unit/test_grid_cell.py new file mode 100644 index 0000000..609de4f --- /dev/null +++ b/tests/unit/test_grid_cell.py @@ -0,0 +1,67 @@ +"""Test cases for the GridCell model.""" + +import pandas as pd +import pytest + +from graphomotor.core import models + + +@pytest.fixture +def cell() -> models.GridCell: + """Create a grid cell representing one letter region.""" + return models.GridCell( + x_min=10.0, x_max=25.0, y_min=80.0, y_max=97.0, index=0, label="A" + ) + + +@pytest.mark.parametrize( + "x_min,x_max,y_min,y_max,expected_error", + [ + (10.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"), + (5.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"), + (0.0, 1.0, 10.0, 5.0, "y_min .* must be less than y_max"), + (0.0, 1.0, 5.0, 5.0, "y_min .* must be less than y_max"), + ], +) +def test_grid_cell_invalid_bounds( + x_min: float, + x_max: float, + y_min: float, + y_max: float, + expected_error: str, +) -> None: + """Test that invalid or equal bounds raise ValueError.""" + with pytest.raises(ValueError, match=expected_error): + models.GridCell(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max) + + +def test_stroke_centroid_inside_cell(cell: models.GridCell) -> None: + """Stroke whose centroid falls inside the cell should be contained.""" + stroke_points = pd.DataFrame({"x": [15.0, 20.0, 17.0], "y": [85.0, 90.0, 95.0]}) + assert cell.contains_points(stroke_points) + + +def test_stroke_centroid_outside_cell(cell: models.GridCell) -> None: + """Stroke whose centroid falls outside the cell should not be contained.""" + stroke_points = pd.DataFrame({"x": [33.8, 37.8, 35.3], "y": [85.8, 95.4, 90.1]}) + assert not cell.contains_points(stroke_points) + + +def test_stroke_centroid_on_boundary(cell: models.GridCell) -> None: + """Stroke whose centroid lands exactly on the cell boundary should be contained.""" + stroke_points = pd.DataFrame({"x": [10.0, 10.0], "y": [88.0, 89.0]}) + assert cell.contains_points(stroke_points) + + +def test_stroke_points_span_outside_but_centroid_inside( + cell: models.GridCell, +) -> None: + """Stroke with points outside the cell but centroid inside should be contained.""" + stroke_points = pd.DataFrame({"x": [8.0, 22.0], "y": [78.0, 98.0]}) + assert cell.contains_points(stroke_points) + + +def test_single_point_stroke(cell: models.GridCell) -> None: + """Single-point stroke should use that point as its centroid.""" + stroke_points = pd.DataFrame({"x": [17.5], "y": [90.0]}) + assert cell.contains_points(stroke_points) From 4d6ac336bad88f324d2ab08d93c328dfd2daed4e Mon Sep 17 00:00:00 2001 From: Iktae Kim Date: Thu, 12 Feb 2026 16:47:27 -0500 Subject: [PATCH 2/4] using half-open intervals for GridCell boundary checks --- src/graphomotor/core/models.py | 7 ++++--- tests/unit/test_grid_cell.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/graphomotor/core/models.py b/src/graphomotor/core/models.py index 02c6605..6496e64 100644 --- a/src/graphomotor/core/models.py +++ b/src/graphomotor/core/models.py @@ -167,7 +167,8 @@ def contains_points(self, points: pd.DataFrame) -> bool: """Check if a stroke belongs to this cell based on its centroid. Computes the centroid (mean x, mean y) of the provided points and checks - whether it falls within the cell boundaries (inclusive). + whether it falls within the cell boundaries. Uses half-open intervals + [min, max) to prevent double-assignment on shared grid edges. Args: points: DataFrame with 'x' and 'y' columns representing a stroke. @@ -178,8 +179,8 @@ def contains_points(self, points: pd.DataFrame) -> bool: centroid_x = points["x"].mean() centroid_y = points["y"].mean() return ( - self.x_min <= centroid_x <= self.x_max - and self.y_min <= centroid_y <= self.y_max + self.x_min <= centroid_x < self.x_max + and self.y_min <= centroid_y < self.y_max ) diff --git a/tests/unit/test_grid_cell.py b/tests/unit/test_grid_cell.py index 609de4f..2103767 100644 --- a/tests/unit/test_grid_cell.py +++ b/tests/unit/test_grid_cell.py @@ -47,12 +47,18 @@ def test_stroke_centroid_outside_cell(cell: models.GridCell) -> None: assert not cell.contains_points(stroke_points) -def test_stroke_centroid_on_boundary(cell: models.GridCell) -> None: - """Stroke whose centroid lands exactly on the cell boundary should be contained.""" +def test_stroke_centroid_on_lower_boundary(cell: models.GridCell) -> None: + """Stroke whose centroid lands on the lower/left boundary (min) is included.""" stroke_points = pd.DataFrame({"x": [10.0, 10.0], "y": [88.0, 89.0]}) assert cell.contains_points(stroke_points) +def test_stroke_centroid_on_upper_boundary(cell: models.GridCell) -> None: + """Stroke whose centroid lands on the upper/right boundary (max) is excluded.""" + stroke_points = pd.DataFrame({"x": [17.0, 18.0], "y": [97.0, 97.0]}) + assert not cell.contains_points(stroke_points) + + def test_stroke_points_span_outside_but_centroid_inside( cell: models.GridCell, ) -> None: From 1d8d55c185e115e7d2736118d769688b5bc6c2e6 Mon Sep 17 00:00:00 2001 From: Iktae Kim Date: Tue, 24 Feb 2026 11:09:16 -0500 Subject: [PATCH 3/4] Updated the class docstring to note the half-open interval policy and that the outer grid boundaries will be padded at the Grid level. --- src/graphomotor/core/models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/graphomotor/core/models.py b/src/graphomotor/core/models.py index 6496e64..29233be 100644 --- a/src/graphomotor/core/models.py +++ b/src/graphomotor/core/models.py @@ -130,7 +130,10 @@ class GridCell: """Represents a single rectangular region in a grid layout. Used to assign strokes to letter regions (Alphabet) or digit regions (DSYM). - Boundary policy: a point on the exact boundary is considered inside the cell. + Boundary policy: uses half-open intervals [min, max) to prevent + double-assignment on shared grid edges. The outer grid boundaries should be + padded (e.g., by 0.1) at the Grid level so that centroids on the outermost + edge are not excluded. Attributes: x_min: Left boundary of the cell. From f2c46d7754d0da9e707c7772b6aee536c5579f2b Mon Sep 17 00:00:00 2001 From: Iktae Kim Date: Mon, 2 Mar 2026 11:18:22 -0500 Subject: [PATCH 4/4] Refactor boundary tests for GridCell to use parameterized inputs --- tests/unit/test_grid_cell.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_grid_cell.py b/tests/unit/test_grid_cell.py index 2103767..83d5a8a 100644 --- a/tests/unit/test_grid_cell.py +++ b/tests/unit/test_grid_cell.py @@ -47,15 +47,35 @@ def test_stroke_centroid_outside_cell(cell: models.GridCell) -> None: assert not cell.contains_points(stroke_points) -def test_stroke_centroid_on_lower_boundary(cell: models.GridCell) -> None: +@pytest.mark.parametrize( + "x_vals,y_vals", + [ + ([10.0, 10.0], [88.0, 89.0]), + ([17.0, 18.0], [80.0, 80.0]), + ], + ids=["left_boundary", "bottom_boundary"], +) +def test_stroke_centroid_on_lower_boundary( + cell: models.GridCell, x_vals: list[float], y_vals: list[float] +) -> None: """Stroke whose centroid lands on the lower/left boundary (min) is included.""" - stroke_points = pd.DataFrame({"x": [10.0, 10.0], "y": [88.0, 89.0]}) + stroke_points = pd.DataFrame({"x": x_vals, "y": y_vals}) assert cell.contains_points(stroke_points) -def test_stroke_centroid_on_upper_boundary(cell: models.GridCell) -> None: +@pytest.mark.parametrize( + "x_vals,y_vals", + [ + ([25.0, 25.0], [88.0, 89.0]), + ([17.0, 18.0], [97.0, 97.0]), + ], + ids=["right_boundary", "top_boundary"], +) +def test_stroke_centroid_on_upper_boundary( + cell: models.GridCell, x_vals: list[float], y_vals: list[float] +) -> None: """Stroke whose centroid lands on the upper/right boundary (max) is excluded.""" - stroke_points = pd.DataFrame({"x": [17.0, 18.0], "y": [97.0, 97.0]}) + stroke_points = pd.DataFrame({"x": x_vals, "y": y_vals}) assert not cell.contains_points(stroke_points)