Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/graphomotor/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,65 @@ def get_extractors(
}


@dataclasses.dataclass
Comment thread
kimit0310 marked this conversation as resolved.
class GridCell:
"""Represents a single rectangular region in a grid layout.

Used to assign strokes to letter regions (Alphabet) or digit regions (DSYM).
Comment thread
kimit0310 marked this conversation as resolved.
Boundary policy: a point on the exact boundary is considered inside the cell.

Attributes:
x_min: Left boundary of the cell.
x_max: Right boundary of the cell.
y_min: Bottom boundary of the cell.
y_max: Top boundary of the cell.
index: Position of the cell in the grid (0-based).
label: Display label for the cell (e.g., 'A', 'B', '1').
"""

x_min: float
x_max: float
y_min: float
y_max: float
index: int = 0
label: str = ""

def __post_init__(self) -> None:
"""Validate that min bounds are strictly less than max bounds.

Raises:
ValueError: If x_min >= x_max or y_min >= y_max.
"""
if self.x_min >= self.x_max:
raise ValueError(
f"x_min ({self.x_min}) must be less than x_max ({self.x_max})"
)
if self.y_min >= self.y_max:
raise ValueError(
f"y_min ({self.y_min}) must be less than y_max ({self.y_max})"
)

def contains_points(self, points: pd.DataFrame) -> bool:
"""Check if a stroke belongs to this cell based on its centroid.

Computes the centroid (mean x, mean y) of the provided points and checks
whether it falls within the cell boundaries. Uses half-open intervals
[min, max) to prevent double-assignment on shared grid edges.

Args:
points: DataFrame with 'x' and 'y' columns representing a stroke.

Returns:
True if the stroke centroid is within the cell, False otherwise.
"""
centroid_x = points["x"].mean()
centroid_y = points["y"].mean()
return (
Comment thread
kimit0310 marked this conversation as resolved.
self.x_min <= centroid_x < self.x_max
and self.y_min <= centroid_y < self.y_max
)


@dataclasses.dataclass
class CircleTarget:
"""Represents a target circle in the drawing task.
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/test_grid_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Test cases for the GridCell model."""

import pandas as pd
import pytest

from graphomotor.core import models


@pytest.fixture
def cell() -> models.GridCell:
"""Create a grid cell representing one letter region."""
return models.GridCell(
x_min=10.0, x_max=25.0, y_min=80.0, y_max=97.0, index=0, label="A"
)


@pytest.mark.parametrize(
"x_min,x_max,y_min,y_max,expected_error",
[
(10.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"),
(5.0, 5.0, 0.0, 1.0, "x_min .* must be less than x_max"),
(0.0, 1.0, 10.0, 5.0, "y_min .* must be less than y_max"),
(0.0, 1.0, 5.0, 5.0, "y_min .* must be less than y_max"),
],
)
def test_grid_cell_invalid_bounds(
x_min: float,
x_max: float,
y_min: float,
y_max: float,
expected_error: str,
) -> None:
"""Test that invalid or equal bounds raise ValueError."""
with pytest.raises(ValueError, match=expected_error):
models.GridCell(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)


def test_stroke_centroid_inside_cell(cell: models.GridCell) -> None:
"""Stroke whose centroid falls inside the cell should be contained."""
stroke_points = pd.DataFrame({"x": [15.0, 20.0, 17.0], "y": [85.0, 90.0, 95.0]})
assert cell.contains_points(stroke_points)


def test_stroke_centroid_outside_cell(cell: models.GridCell) -> None:
"""Stroke whose centroid falls outside the cell should not be contained."""
stroke_points = pd.DataFrame({"x": [33.8, 37.8, 35.3], "y": [85.8, 95.4, 90.1]})
assert not cell.contains_points(stroke_points)


def test_stroke_centroid_on_lower_boundary(cell: models.GridCell) -> None:
"""Stroke whose centroid lands on the lower/left boundary (min) is included."""
Comment thread
kimit0310 marked this conversation as resolved.
stroke_points = pd.DataFrame({"x": [10.0, 10.0], "y": [88.0, 89.0]})
assert cell.contains_points(stroke_points)


def test_stroke_centroid_on_upper_boundary(cell: models.GridCell) -> None:
"""Stroke whose centroid lands on the upper/right boundary (max) is excluded."""
Comment thread
kimit0310 marked this conversation as resolved.
stroke_points = pd.DataFrame({"x": [17.0, 18.0], "y": [97.0, 97.0]})
assert not cell.contains_points(stroke_points)


def test_stroke_points_span_outside_but_centroid_inside(
cell: models.GridCell,
) -> None:
"""Stroke with points outside the cell but centroid inside should be contained."""
stroke_points = pd.DataFrame({"x": [8.0, 22.0], "y": [78.0, 98.0]})
assert cell.contains_points(stroke_points)


def test_single_point_stroke(cell: models.GridCell) -> None:
"""Single-point stroke should use that point as its centroid."""
stroke_points = pd.DataFrame({"x": [17.5], "y": [90.0]})
assert cell.contains_points(stroke_points)