Skip to content

Commit f5ee545

Browse files
committed
feat(detection): add require_all_anchors to PolygonZone
Currently a detection counts as 'in the zone' only when every anchor in triggering_anchors is inside. For boxes that straddle the zone boundary this means a detection with many anchors (e.g. the four corners) is often under-counted unless the user shrinks triggering_anchors to a single point. Add require_all_anchors: bool = True so callers can opt into 'any anchor inside is enough'. Default preserves current behaviour. Closes #1022.
1 parent fb2dec9 commit f5ee545

2 files changed

Lines changed: 40 additions & 3 deletions

File tree

src/supervision/detection/tools/polygon_zone.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class PolygonZone:
3232
which anchors of the detections bounding box to consider when deciding on
3333
whether the detection fits within the PolygonZone
3434
(default: (sv.Position.BOTTOM_CENTER,)).
35+
require_all_anchors: If `True` (default), a detection is considered inside
36+
the zone only when *every* anchor in `triggering_anchors` is inside.
37+
If `False`, the detection triggers as soon as *any* anchor is inside.
38+
Has no effect when `triggering_anchors` has a single entry.
3539
current_count: The current count of detected objects within the zone
3640
mask: The 2D bool mask for the polygon zone
3741
@@ -62,11 +66,14 @@ def __init__(
6266
self,
6367
polygon: npt.NDArray[np.int64],
6468
triggering_anchors: Iterable[Position] = (Position.BOTTOM_CENTER,),
69+
require_all_anchors: bool = True,
6570
):
6671
self.polygon = polygon.astype(int)
67-
self.triggering_anchors = triggering_anchors
68-
if not list(self.triggering_anchors):
72+
# Materialize once so we can safely accept generators without exhausting them.
73+
self.triggering_anchors = list(triggering_anchors)
74+
if not self.triggering_anchors:
6975
raise ValueError("Triggering anchors cannot be empty.")
76+
self.require_all_anchors = require_all_anchors
7077

7178
self.current_count = 0
7279

@@ -108,7 +115,9 @@ def trigger(self, detections: Detections) -> npt.NDArray[np.bool_]:
108115
in_bounds = (x >= 0) & (y >= 0) & (x < mask_w) & (y < mask_h)
109116
x_safe = np.clip(x, 0, mask_w - 1)
110117
y_safe = np.clip(y, 0, mask_h - 1)
111-
is_in_zone = np.all(in_bounds & self.mask[y_safe, x_safe], axis=0)
118+
anchor_hits = in_bounds & self.mask[y_safe, x_safe]
119+
reduce = np.all if self.require_all_anchors else np.any
120+
is_in_zone = reduce(anchor_hits, axis=0)
112121
self.current_count = int(np.sum(is_in_zone))
113122
return is_in_zone.astype(bool)
114123

tests/detection/test_polygonzone.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ def test_empty_anchors_raises(self, polygon, triggering_anchors, exception):
4444
with exception:
4545
sv.PolygonZone(polygon, triggering_anchors=triggering_anchors)
4646

47+
def test_generator_triggering_anchors_is_materialized(self):
48+
"""A generator passed for triggering_anchors must not be silently exhausted."""
49+
zone = sv.PolygonZone(
50+
POLYGON, triggering_anchors=(p for p in [sv.Position.CENTER])
51+
)
52+
detections = _create_detections(
53+
xyxy=[[140.0, 140.0, 160.0, 160.0]], class_id=[0]
54+
)
55+
assert zone.trigger(detections)[0]
56+
4757

4858
class TestPolygonZoneTrigger:
4959
@pytest.mark.parametrize(
@@ -164,3 +174,21 @@ def test_anchor_on_polygon_boundary_included(self) -> None:
164174
)
165175
result = zone.trigger(detections)
166176
assert result[0]
177+
178+
def test_require_all_anchors_false_triggers_on_any_anchor(self) -> None:
179+
"""With require_all_anchors=False, any anchor inside triggers."""
180+
# Box [85, 85, 115, 115] has only BOTTOM_RIGHT (115, 115) inside POLYGON
181+
# ([100, 100]..[200, 200]); the other three corners are outside.
182+
detections = _create_detections(xyxy=[[85.0, 85.0, 115.0, 115.0]], class_id=[0])
183+
anchors = (
184+
sv.Position.TOP_LEFT,
185+
sv.Position.TOP_RIGHT,
186+
sv.Position.BOTTOM_LEFT,
187+
sv.Position.BOTTOM_RIGHT,
188+
)
189+
all_required = sv.PolygonZone(POLYGON, triggering_anchors=anchors)
190+
any_anchor = sv.PolygonZone(
191+
POLYGON, triggering_anchors=anchors, require_all_anchors=False
192+
)
193+
assert not all_required.trigger(detections)[0]
194+
assert any_anchor.trigger(detections)[0]

0 commit comments

Comments
 (0)