-
Notifications
You must be signed in to change notification settings - Fork 0
104 Add Grid model and segment_strokes utility function #123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -191,6 +191,116 @@ def contains_points(self, points: pd.DataFrame) -> bool: | |
| ) | ||
|
|
||
|
|
||
@dataclasses.dataclass
class Grid:
    """Represents a rectangular grid composed of multiple GridCell objects.

    The grid divides a bounding box into rows and columns, where each cell can
    hold strokes assigned by centroid location. Cells are ordered left-to-right,
    top-to-bottom (row-major order).

    Attributes:
        cells: List of GridCell objects that compose the grid.
    """

    cells: List[GridCell] = dataclasses.field(default_factory=list)

    @classmethod
    def from_bbox(
        cls,
        x_min: float,
        x_max: float,
        y_min: float,
        y_max: float,
        n_rows: int,
        n_cols: int,
        labels: Optional[List[str]] = None,
        padding: float = 0.1,
    ) -> "Grid":
        """Create a Grid by subdividing a bounding box into rows and columns.

        Cells are generated in row-major order (left-to-right, top-to-bottom),
        so row 0 is the top row and its cells touch ``y_max``.

        Cell width and height are derived from the *unpadded* bounding box, so
        interior cell boundaries sit exactly on the nominal grid lines. The
        padding only pushes the outermost edges outwards so that centroids
        falling exactly on the outer boundary are still captured by a cell.
        For example, a 3x3 grid over [0, 30] with padding 0.1 has column edges
        at -0.1, 10, 20 and 30.1.

        Args:
            x_min: Left boundary of the bounding box.
            x_max: Right boundary of the bounding box.
            y_min: Bottom boundary of the bounding box.
            y_max: Top boundary of the bounding box.
            n_rows: Number of rows in the grid.
            n_cols: Number of columns in the grid.
            labels: Optional list of labels for each cell, assigned in
                row-major order. Must have length n_rows * n_cols if provided.
                When omitted, every cell gets an empty-string label.
            padding: Amount by which the outermost edges are extended to
                capture edge centroids (default 0.1). It does not change the
                size of interior cells.

        Returns:
            A Grid instance populated with GridCell objects.

        Raises:
            ValueError: If n_rows or n_cols is less than 1, or if the length
                of labels does not match n_rows * n_cols.
        """
        if n_rows < 1 or n_cols < 1:
            raise ValueError("n_rows and n_cols must be at least 1.")
        n_cells = n_rows * n_cols
        if labels is not None and len(labels) != n_cells:
            raise ValueError(
                f"labels length ({len(labels)}) must match n_rows * n_cols ({n_cells})."
            )

        # Size the cells from the unpadded box: including the padding here
        # would shift every interior boundary away from its nominal position.
        col_width = (x_max - x_min) / n_cols
        row_height = (y_max - y_min) / n_rows
        last_row = n_rows - 1
        last_col = n_cols - 1

        cells: List[GridCell] = []
        for row in range(n_rows):
            # Rows are counted from the top, so row 0 ends at y_max. Only the
            # first and last rows receive the outward padding.
            cell_y_max = y_max + padding if row == 0 else y_max - row * row_height
            cell_y_min = (
                y_min - padding
                if row == last_row
                else y_max - (row + 1) * row_height
            )
            for col in range(n_cols):
                # Only the first and last columns receive the outward padding.
                cell_x_min = (
                    x_min - padding if col == 0 else x_min + col * col_width
                )
                cell_x_max = (
                    x_max + padding
                    if col == last_col
                    else x_min + (col + 1) * col_width
                )
                index = row * n_cols + col
                cells.append(
                    GridCell(
                        x_min=cell_x_min,
                        x_max=cell_x_max,
                        y_min=cell_y_min,
                        y_max=cell_y_max,
                        index=index,
                        label=labels[index] if labels is not None else "",
                    )
                )

        return cls(cells=cells)

    def get_cell_for_point(self, x: float, y: float) -> int:
        """Return the index of the cell containing the given point.

        Iterates through cells and returns the index of the first cell whose
        half-open interval [min, max) contains the point on both axes.

        Args:
            x: X coordinate of the point.
            y: Y coordinate of the point.

        Returns:
            The index of the matching cell, or -1 if no cell contains the point.
        """
        for cell in self.cells:
            if cell.x_min <= x < cell.x_max and cell.y_min <= y < cell.y_max:
                return cell.index
        return -1
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Stroke: | ||
| """Represents a single stroke in an Alphabet or DSYM task. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| """Utility functions for Alphabet and DSYM stroke segmentation.""" | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from graphomotor.core import models | ||
|
|
||
|
|
||
def segment_strokes(
    data: pd.DataFrame,
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
    n_rows: int,
    n_cols: int,
    labels: Optional[List[str]] = None,
    padding: float = 0.1,
) -> models.Grid:
    """Segment drawing data into strokes and assign them to grid cells.

    Groups the data by line_number to create individual Stroke objects. Each
    stroke is assigned to a GridCell based on its centroid (mean x, mean y).
    The grid is constructed via Grid.from_bbox; the padding lets centroids on
    the outermost edge still be captured by a cell.

    Strokes whose centroid does not fall inside any cell are silently dropped.

    Args:
        data: DataFrame containing drawing data with at least line_number,
            x, y, and seconds columns. This function itself reads
            line_number, x and y.
        x_min: Left boundary of the grid bounding box.
        x_max: Right boundary of the grid bounding box.
        y_min: Bottom boundary of the grid bounding box.
        y_max: Top boundary of the grid bounding box.
        n_rows: Number of rows in the grid.
        n_cols: Number of columns in the grid.
        labels: Optional list of labels for each cell in row-major order.
            Must have length n_rows * n_cols if provided.
        padding: Outer-edge padding forwarded to Grid.from_bbox
            (default 0.1, the same default Grid.from_bbox uses).

    Returns:
        A Grid populated with strokes assigned to their matching cells.
    """
    grid = models.Grid.from_bbox(
        x_min=x_min,
        x_max=x_max,
        y_min=y_min,
        y_max=y_max,
        n_rows=n_rows,
        n_cols=n_cols,
        labels=labels,
        padding=padding,
    )

    for line_number, group in data.groupby("line_number"):
        # groupby keeps the row labels of the source frame; reset_index gives
        # each stroke its own 0-based index instead.
        stroke = models.Stroke(
            points=group.reset_index(drop=True),
            line_number=int(line_number),
        )

        cell_index = grid.get_cell_for_point(group["x"].mean(), group["y"].mean())

        # -1 means the centroid lies outside every cell; such strokes are
        # intentionally left unassigned.
        if cell_index != -1:
            grid.cells[cell_index].strokes.append(stroke)

    return grid
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| """Test cases for the segment_strokes utility function.""" | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from graphomotor.utils import alphabet_utils | ||
|
|
||
|
|
||
def _make_drawing_data() -> pd.DataFrame:
    """Build sample drawing data: three strokes in distinct spatial regions.

    Stroke 0 has centroid (5, 85) - top-left of a 2x2 grid over [0, 100].
    Stroke 1 has centroid (55, 85) - top-right.
    Stroke 2 has centroid (5, 15) - bottom-left.
    """
    # One record per sampled point: (line_number, x, y, seconds).
    records = [
        (0, 3.0, 83.0, 0.0),
        (0, 5.0, 85.0, 0.1),
        (0, 7.0, 87.0, 0.2),
        (1, 53.0, 83.0, 0.3),
        (1, 55.0, 85.0, 0.4),
        (1, 57.0, 87.0, 0.5),
        (2, 3.0, 13.0, 0.6),
        (2, 5.0, 15.0, 0.7),
        (2, 7.0, 17.0, 0.8),
    ]
    return pd.DataFrame(records, columns=["line_number", "x", "y", "seconds"])
|
|
||
|
|
||
class TestSegmentStrokes:
    """Tests for the segment_strokes function."""

    def test_strokes_assigned_to_correct_cells(self) -> None:
        """Each stroke should be placed in the cell containing its centroid."""
        data = _make_drawing_data()
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=100.0,
            y_min=0.0,
            y_max=100.0,
            n_rows=2,
            n_cols=2,
            labels=["TL", "TR", "BL", "BR"],
        )

        assert len(grid.cells[0].strokes) == 1
        assert grid.cells[0].strokes[0].line_number == 0
        assert len(grid.cells[1].strokes) == 1
        assert grid.cells[1].strokes[0].line_number == 1
        assert len(grid.cells[2].strokes) == 1
        assert grid.cells[2].strokes[0].line_number == 2
        assert len(grid.cells[3].strokes) == 0

    def test_total_stroke_count_matches_line_numbers(self) -> None:
        """Total strokes across all cells should equal the number of line groups."""
        data = _make_drawing_data()
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=100.0,
            y_min=0.0,
            y_max=100.0,
            n_rows=2,
            n_cols=2,
        )

        total_strokes = sum(len(c.strokes) for c in grid.cells)
        assert total_strokes == 3

    def test_stroke_points_are_correct(self) -> None:
        """Every Stroke should contain exactly the points of its line group."""
        data = _make_drawing_data()
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=100.0,
            y_min=0.0,
            y_max=100.0,
            n_rows=2,
            n_cols=2,
        )

        # Expected (x, y) samples per line_number, taken from the sample data.
        expected = {
            0: ([3.0, 5.0, 7.0], [83.0, 85.0, 87.0]),
            1: ([53.0, 55.0, 57.0], [83.0, 85.0, 87.0]),
            2: ([3.0, 5.0, 7.0], [13.0, 15.0, 17.0]),
        }
        # Look strokes up by line number so this check does not depend on
        # which cell each stroke was assigned to.
        strokes = {
            stroke.line_number: stroke
            for cell in grid.cells
            for stroke in cell.strokes
        }

        assert sorted(strokes) == sorted(expected)
        for line_number, (xs, ys) in expected.items():
            points = strokes[line_number].points
            assert len(points) == 3
            assert list(points["x"]) == xs
            assert list(points["y"]) == ys

    def test_empty_dataframe(self) -> None:
        """An empty DataFrame should produce a grid with no strokes."""
        data = pd.DataFrame(columns=["line_number", "x", "y", "seconds"])
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=10.0,
            y_min=0.0,
            y_max=10.0,
            n_rows=1,
            n_cols=1,
        )

        assert all(len(c.strokes) == 0 for c in grid.cells)

    def test_stroke_outside_grid_is_not_assigned(self) -> None:
        """Strokes whose centroids fall outside the grid should be dropped."""
        data = pd.DataFrame(
            {
                "line_number": [0, 0],
                "x": [500.0, 600.0],
                "y": [500.0, 600.0],
                "seconds": [0.0, 0.1],
            }
        )
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=10.0,
            y_min=0.0,
            y_max=10.0,
            n_rows=1,
            n_cols=1,
        )

        assert len(grid.cells[0].strokes) == 0

    def test_multiple_strokes_in_same_cell(self) -> None:
        """Multiple strokes in the same spatial region should all land in one cell."""
        data = pd.DataFrame(
            {
                "line_number": [0, 0, 1, 1, 2, 2],
                "x": [5.0, 6.0, 5.5, 6.5, 4.5, 5.5],
                "y": [5.0, 6.0, 5.5, 6.5, 4.5, 5.5],
                "seconds": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
            }
        )
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=10.0,
            y_min=0.0,
            y_max=10.0,
            n_rows=1,
            n_cols=1,
        )

        assert len(grid.cells[0].strokes) == 3

    def test_grid_structure_matches_parameters(self) -> None:
        """Returned grid should have the correct number of labeled cells."""
        data = _make_drawing_data()
        labels = ["A", "B", "C", "D", "E", "F"]
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=100.0,
            y_min=0.0,
            y_max=100.0,
            n_rows=2,
            n_cols=3,
            labels=labels,
        )

        assert len(grid.cells) == 6
        assert [c.label for c in grid.cells] == labels

    def test_stroke_index_is_reset(self) -> None:
        """Stroke points should have a reset index starting from 0."""
        # The source frame uses a non-default index (10, 11, 12) so the test
        # can tell a reset index apart from one carried over from the input.
        data = pd.DataFrame(
            {
                "line_number": [5, 5, 5],
                "x": [1.0, 2.0, 3.0],
                "y": [1.0, 2.0, 3.0],
                "seconds": [0.0, 0.1, 0.2],
            },
            index=[10, 11, 12],
        )
        grid = alphabet_utils.segment_strokes(
            data=data,
            x_min=0.0,
            x_max=10.0,
            y_min=0.0,
            y_max=10.0,
            n_rows=1,
            n_cols=1,
        )

        stroke = grid.cells[0].strokes[0]
        assert list(stroke.points.index) == [0, 1, 2]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still feel this is a candidate for a future refactor: too many of these dataclasses behave like regular classes rather than plain data containers.