Skip to content

Commit 06690a7

Browse files
committed
feat: add Grid model and segment_strokes utility function with tests
1 parent d4a7571 commit 06690a7

4 files changed

Lines changed: 496 additions & 0 deletions

File tree

src/graphomotor/core/models.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,117 @@ def contains_points(self, points: pd.DataFrame) -> bool:
191191
)
192192

193193

194+
@dataclasses.dataclass
195+
class Grid:
196+
"""Represents a rectangular grid composed of multiple GridCell objects.
197+
198+
The grid divides a bounding box into rows and columns, where each cell can
199+
hold strokes assigned by centroid location. Cells are ordered left-to-right,
200+
top-to-bottom (row-major order).
201+
202+
Attributes:
203+
cells: List of GridCell objects that compose the grid.
204+
"""
205+
206+
cells: List[GridCell] = dataclasses.field(default_factory=list)
207+
208+
@classmethod
209+
def from_bbox(
210+
cls,
211+
x_min: float,
212+
x_max: float,
213+
y_min: float,
214+
y_max: float,
215+
n_rows: int,
216+
n_cols: int,
217+
labels: Optional[List[str]] = None,
218+
padding: float = 0.1,
219+
) -> "Grid":
220+
"""Create a Grid by subdividing a bounding box into rows and columns.
221+
222+
Cells are generated in row-major order (left-to-right, top-to-bottom).
223+
A small padding is added to the outer boundaries so that centroids
224+
falling exactly on the outermost edge are still captured by a cell.
225+
226+
Args:
227+
x_min: Left boundary of the bounding box.
228+
x_max: Right boundary of the bounding box.
229+
y_min: Bottom boundary of the bounding box.
230+
y_max: Top boundary of the bounding box.
231+
n_rows: Number of rows in the grid.
232+
n_cols: Number of columns in the grid.
233+
labels: Optional list of labels for each cell, assigned in
234+
row-major order. Must have length n_rows * n_cols if provided.
235+
padding: Amount to extend the outer boundaries to capture edge
236+
centroids (default 0.1).
237+
238+
Returns:
239+
A Grid instance populated with GridCell objects.
240+
241+
Raises:
242+
ValueError: If n_rows or n_cols is less than 1, or if the length
243+
of labels does not match n_rows * n_cols.
244+
"""
245+
if n_rows < 1 or n_cols < 1:
246+
raise ValueError("n_rows and n_cols must be at least 1.")
247+
n_cells = n_rows * n_cols
248+
if labels is not None and len(labels) != n_cells:
249+
raise ValueError(
250+
f"labels length ({len(labels)}) must match "
251+
f"n_rows * n_cols ({n_cells})."
252+
)
253+
254+
padded_x_min = x_min - padding
255+
padded_x_max = x_max + padding
256+
padded_y_min = y_min - padding
257+
padded_y_max = y_max + padding
258+
259+
col_width = (padded_x_max - padded_x_min) / n_cols
260+
row_height = (padded_y_max - padded_y_min) / n_rows
261+
262+
cells: List[GridCell] = []
263+
index = 0
264+
for row in range(n_rows):
265+
for col in range(n_cols):
266+
cell_x_min = padded_x_min + col * col_width
267+
cell_x_max = padded_x_min + (col + 1) * col_width
268+
cell_y_min = padded_y_max - (row + 1) * row_height
269+
cell_y_max = padded_y_max - row * row_height
270+
label = labels[index] if labels is not None else ""
271+
cells.append(
272+
GridCell(
273+
x_min=cell_x_min,
274+
x_max=cell_x_max,
275+
y_min=cell_y_min,
276+
y_max=cell_y_max,
277+
index=index,
278+
label=label,
279+
)
280+
)
281+
index += 1
282+
283+
return cls(cells=cells)
284+
285+
def get_cell_for_point(self, x: float, y: float) -> int:
286+
"""Return the index of the cell containing the given point.
287+
288+
Iterates through cells and returns the index of the first cell whose
289+
half-open interval [min, max) contains the point. Returns -1 if no
290+
cell contains the point.
291+
292+
Args:
293+
x: X coordinate of the point.
294+
y: Y coordinate of the point.
295+
296+
Returns:
297+
The index of the matching cell, or -1 if no cell contains the point.
298+
"""
299+
for cell in self.cells:
300+
if cell.x_min <= x < cell.x_max and cell.y_min <= y < cell.y_max:
301+
return cell.index
302+
return -1
303+
304+
194305
@dataclasses.dataclass
195306
class Stroke:
196307
"""Represents a single stroke in an Alphabet or DSYM task.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Utility functions for Alphabet and DSYM stroke segmentation."""
2+
3+
from typing import List, Optional
4+
5+
import pandas as pd
6+
7+
from graphomotor.core import models
8+
9+
10+
def segment_strokes(
11+
data: pd.DataFrame,
12+
x_min: float,
13+
x_max: float,
14+
y_min: float,
15+
y_max: float,
16+
n_rows: int,
17+
n_cols: int,
18+
labels: Optional[List[str]] = None,
19+
) -> models.Grid:
20+
"""Segment drawing data into strokes and assign them to grid cells.
21+
22+
Groups the data by ``line_number`` to create individual
23+
:class:`~graphomotor.core.models.Stroke` objects. Each stroke is assigned
24+
to a :class:`~graphomotor.core.models.GridCell` based on its centroid
25+
(mean x, mean y). The grid is constructed via
26+
:meth:`~graphomotor.core.models.Grid.from_bbox` with default padding so
27+
that centroids on the outermost edge are captured.
28+
29+
Args:
30+
data: DataFrame containing drawing data with at least ``line_number``,
31+
``x``, ``y``, and ``seconds`` columns.
32+
x_min: Left boundary of the grid bounding box.
33+
x_max: Right boundary of the grid bounding box.
34+
y_min: Bottom boundary of the grid bounding box.
35+
y_max: Top boundary of the grid bounding box.
36+
n_rows: Number of rows in the grid.
37+
n_cols: Number of columns in the grid.
38+
labels: Optional list of labels for each cell in row-major order.
39+
Must have length ``n_rows * n_cols`` if provided.
40+
41+
Returns:
42+
A :class:`~graphomotor.core.models.Grid` populated with strokes
43+
assigned to their matching cells.
44+
"""
45+
grid = models.Grid.from_bbox(
46+
x_min=x_min,
47+
x_max=x_max,
48+
y_min=y_min,
49+
y_max=y_max,
50+
n_rows=n_rows,
51+
n_cols=n_cols,
52+
labels=labels,
53+
)
54+
55+
for line_number, group in data.groupby("line_number"):
56+
stroke = models.Stroke(
57+
points=group.reset_index(drop=True),
58+
line_number=int(line_number),
59+
)
60+
61+
centroid_x = group["x"].mean()
62+
centroid_y = group["y"].mean()
63+
cell_index = grid.get_cell_for_point(centroid_x, centroid_y)
64+
65+
if cell_index != -1:
66+
grid.cells[cell_index].strokes.append(stroke)
67+
68+
return grid

tests/unit/test_alphabet_utils.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""Test cases for the segment_strokes utility function."""
2+
3+
import pandas as pd
4+
5+
from graphomotor.utils import alphabet_utils
6+
7+
8+
def _make_drawing_data() -> pd.DataFrame:
9+
"""Create drawing data with three strokes in distinct spatial regions.
10+
11+
Stroke 0 has centroid near (5, 85) - top-left of a 2x2 grid.
12+
Stroke 1 has centroid near (55, 85) - top-right.
13+
Stroke 2 has centroid near (5, 15) - bottom-left.
14+
"""
15+
return pd.DataFrame(
16+
{
17+
"line_number": [0, 0, 0, 1, 1, 1, 2, 2, 2],
18+
"x": [3.0, 5.0, 7.0, 53.0, 55.0, 57.0, 3.0, 5.0, 7.0],
19+
"y": [83.0, 85.0, 87.0, 83.0, 85.0, 87.0, 13.0, 15.0, 17.0],
20+
"seconds": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
21+
}
22+
)
23+
24+
25+
class TestSegmentStrokes:
26+
"""Tests for the segment_strokes function."""
27+
28+
def test_strokes_assigned_to_correct_cells(self) -> None:
29+
"""Each stroke should be placed in the cell containing its centroid."""
30+
data = _make_drawing_data()
31+
grid = alphabet_utils.segment_strokes(
32+
data=data,
33+
x_min=0.0, x_max=100.0, y_min=0.0, y_max=100.0,
34+
n_rows=2, n_cols=2, labels=["TL", "TR", "BL", "BR"],
35+
)
36+
37+
assert len(grid.cells[0].strokes) == 1
38+
assert grid.cells[0].strokes[0].line_number == 0
39+
assert len(grid.cells[1].strokes) == 1
40+
assert grid.cells[1].strokes[0].line_number == 1
41+
assert len(grid.cells[2].strokes) == 1
42+
assert grid.cells[2].strokes[0].line_number == 2
43+
assert len(grid.cells[3].strokes) == 0
44+
45+
def test_total_stroke_count_matches_line_numbers(self) -> None:
46+
"""Total strokes across all cells should equal the number of line groups."""
47+
data = _make_drawing_data()
48+
grid = alphabet_utils.segment_strokes(
49+
data=data,
50+
x_min=0.0, x_max=100.0, y_min=0.0, y_max=100.0,
51+
n_rows=2, n_cols=2,
52+
)
53+
54+
total_strokes = sum(len(c.strokes) for c in grid.cells)
55+
assert total_strokes == 3
56+
57+
def test_stroke_points_are_correct(self) -> None:
58+
"""Each Stroke should contain the correct subset of points."""
59+
data = _make_drawing_data()
60+
grid = alphabet_utils.segment_strokes(
61+
data=data,
62+
x_min=0.0, x_max=100.0, y_min=0.0, y_max=100.0,
63+
n_rows=2, n_cols=2,
64+
)
65+
66+
stroke_0 = grid.cells[0].strokes[0]
67+
assert len(stroke_0.points) == 3
68+
assert list(stroke_0.points["x"]) == [3.0, 5.0, 7.0]
69+
70+
def test_empty_dataframe(self) -> None:
71+
"""An empty DataFrame should produce a grid with no strokes."""
72+
data = pd.DataFrame(columns=["line_number", "x", "y", "seconds"])
73+
grid = alphabet_utils.segment_strokes(
74+
data=data,
75+
x_min=0.0, x_max=10.0, y_min=0.0, y_max=10.0,
76+
n_rows=1, n_cols=1,
77+
)
78+
79+
assert all(len(c.strokes) == 0 for c in grid.cells)
80+
81+
def test_stroke_outside_grid_is_not_assigned(self) -> None:
82+
"""Strokes whose centroids fall outside the grid should be dropped."""
83+
data = pd.DataFrame(
84+
{
85+
"line_number": [0, 0],
86+
"x": [500.0, 600.0],
87+
"y": [500.0, 600.0],
88+
"seconds": [0.0, 0.1],
89+
}
90+
)
91+
grid = alphabet_utils.segment_strokes(
92+
data=data,
93+
x_min=0.0, x_max=10.0, y_min=0.0, y_max=10.0,
94+
n_rows=1, n_cols=1,
95+
)
96+
97+
assert len(grid.cells[0].strokes) == 0
98+
99+
def test_multiple_strokes_in_same_cell(self) -> None:
100+
"""Multiple strokes in the same spatial region should all land in one cell."""
101+
data = pd.DataFrame(
102+
{
103+
"line_number": [0, 0, 1, 1, 2, 2],
104+
"x": [5.0, 6.0, 5.5, 6.5, 4.5, 5.5],
105+
"y": [5.0, 6.0, 5.5, 6.5, 4.5, 5.5],
106+
"seconds": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
107+
}
108+
)
109+
grid = alphabet_utils.segment_strokes(
110+
data=data,
111+
x_min=0.0, x_max=10.0, y_min=0.0, y_max=10.0,
112+
n_rows=1, n_cols=1,
113+
)
114+
115+
assert len(grid.cells[0].strokes) == 3
116+
117+
def test_grid_structure_matches_parameters(self) -> None:
118+
"""Returned grid should have the correct number of labeled cells."""
119+
data = _make_drawing_data()
120+
labels = ["A", "B", "C", "D", "E", "F"]
121+
grid = alphabet_utils.segment_strokes(
122+
data=data,
123+
x_min=0.0, x_max=100.0, y_min=0.0, y_max=100.0,
124+
n_rows=2, n_cols=3, labels=labels,
125+
)
126+
127+
assert len(grid.cells) == 6
128+
assert [c.label for c in grid.cells] == labels
129+
130+
def test_stroke_index_is_reset(self) -> None:
131+
"""Stroke points should have a reset index starting from 0."""
132+
data = pd.DataFrame(
133+
{
134+
"line_number": [5, 5, 5],
135+
"x": [1.0, 2.0, 3.0],
136+
"y": [1.0, 2.0, 3.0],
137+
"seconds": [0.0, 0.1, 0.2],
138+
},
139+
index=[10, 11, 12],
140+
)
141+
grid = alphabet_utils.segment_strokes(
142+
data=data,
143+
x_min=0.0, x_max=10.0, y_min=0.0, y_max=10.0,
144+
n_rows=1, n_cols=1,
145+
)
146+
147+
stroke = grid.cells[0].strokes[0]
148+
assert list(stroke.points.index) == [0, 1, 2]

0 commit comments

Comments
 (0)