diff --git a/src/graphomotor/core/models.py b/src/graphomotor/core/models.py index 60b416d..57a348a 100644 --- a/src/graphomotor/core/models.py +++ b/src/graphomotor/core/models.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import pydantic +import scipy.spatial.distance as dist class Drawing(pydantic.BaseModel): @@ -119,6 +120,39 @@ def get_extractors( } +@dataclasses.dataclass +class CircleTarget: + """Represents a target circle in the drawing task. + + Attributes: + order: The order of the circle in the sequence. + label: The label of the circle. + center_x: The x-coordinate of the circle's center. + center_y: The y-coordinate of the circle's center. + radius: The radius of the circle. + """ + + order: int + label: str + center_x: float + center_y: float + radius: float + + def contains_point(self, x: float, y: float, tolerance: float = 1.5) -> bool: + """Check if a point is within the circle (with tolerance multiplier). + + Args: + x: X coordinate of the point. + y: Y coordinate of the point. + tolerance: Multiplier for the radius to define tolerance boundary. + + Returns: + True if the point is within the circle (with tolerance), False otherwise. + """ + distance = np.sqrt((x - self.center_x) ** 2 + (y - self.center_y) ** 2) + return distance <= (self.radius * tolerance) + + @dataclasses.dataclass class LineSegment: """Represents a line drawn between two circles. @@ -164,35 +198,36 @@ class LineSegment: velocities: typing.List[float] = dataclasses.field(default_factory=list) accelerations: typing.List[float] = dataclasses.field(default_factory=list) + def calculate_path_optimality( + self, + start_circle: CircleTarget, + end_circle: CircleTarget, + ) -> None: + """Calculate path optimality ratio. -@dataclasses.dataclass -class CircleTarget: - """Represents a target circle in the drawing task. - - Attributes: - order: The order of the circle in the sequence. - label: The label of the circle. - center_x: The x-coordinate of the circle's center. - center_y: The y-coordinate of the circle's center. - radius: The radius of the circle. - """ - - order: int - label: str - center_x: float - center_y: float - radius: float - - def contains_point(self, x: float, y: float, tolerance: float = 1.5) -> bool: - """Check if a point is within the circle (with tolerance multiplier). + The default value for path optimality in the LineSegment object is 0.0. This + function updates the path_optimality attribute of the LineSegment object based + on the optimal distance between the start and end circles, adjusted for their + radii. If the optimal distance is less than or equal to zero, the path + optimality remains 0.0. Args: - x: X coordinate of the point. - y: Y coordinate of the point. - tolerance: Multiplier for the radius to define tolerance boundary. + segment: LineSegment object for which to calculate path optimality. + start_circle: CircleTarget representing the start circle. + end_circle: CircleTarget representing the end circle. Returns: - True if the point is within the circle (with tolerance), False otherwise. + Path optimality ratio. """ - distance = np.sqrt((x - self.center_x) ** 2 + (y - self.center_y) ** 2) - return distance <= (self.radius * tolerance) + optimal_distance = ( + dist.euclidean( + [start_circle.center_x, start_circle.center_y], + [end_circle.center_x, end_circle.center_y], + ) + - start_circle.radius + - end_circle.radius + ) + + if optimal_distance > 0: + self.path_optimality = optimal_distance / self.distance + return diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 9a2302e..e158f47 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -95,3 +95,43 @@ def test_point_outside_with_default_tolerance(circle: models.CircleTarget) -> No """Point outside default tolerance boundary should not be contained.""" assert not circle.contains_point(16.0, 0.0) assert not circle.contains_point(0.0, 16.0) + + +def test_path_optimality_positive() -> None: + """Test case for path optimality with positive optimal distance.""" + start = models.CircleTarget(order=1, label="1", center_x=0, center_y=0, radius=1) + end = models.CircleTarget(order=2, label="2", center_x=10, center_y=0, radius=1) + segment = models.LineSegment( + start_label="1", + end_label="2", + points=pd.DataFrame(), + is_error=False, + line_number=1, + distance=8, + ) + expected_optimal_distance = ( + end.center_x - start.center_x - start.radius - end.radius + ) + expected_path_optimality = expected_optimal_distance / segment.distance + + segment.calculate_path_optimality(start, end) + + assert segment.path_optimality == expected_path_optimality + + +def test_path_optimality_non_positive_distance() -> None: + """Test case where optimal distance is zero or negative, so no assignment occurs.""" + start = models.CircleTarget(order=1, label="1", center_x=0, center_y=0, radius=5) + end = models.CircleTarget(order=2, label="2", center_x=8, center_y=0, radius=5) + segment = models.LineSegment( + start_label="1", + end_label="2", + points=pd.DataFrame(), + is_error=False, + line_number=1, + distance=5, + ) + + segment.calculate_path_optimality(start, end) + + assert segment.path_optimality == 0.0