Skip to content

Commit 7022724

Browse files
Bordaclaude[bot]
andcommitted
refine(key_points): address PR #2343 review feedback
- Generalize TRACK error msg in _resolve_color_idx to not reference "Detections object" (function now shared with keypoint annotators) - Consolidate multi-line error strings to full line width - Add Raises section to _resolve_color_idx docstring - Add TRACK-not-supported note to VertexAnnotator, EdgeAnnotator, VertexLabelAnnotator color_lookup docs - Add ColorPalette+ColorLookup tests: INDEX/CLASS/KEYPOINT for all three keypoint annotators + CLASS-without-class_id error case - Fix _resolve_color_idx signature: NDArray[np.generic] (linter) [resolve items 1,2,3,5,6] PR #2343 review by @Copilot and @Borda Challenge: evidence=VALID suggestion=VALID resolution=as-suggested --- Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
1 parent db9aa09 commit 7022724

3 files changed

Lines changed: 183 additions & 13 deletions

File tree

src/supervision/annotators/utils.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import textwrap
55
from enum import Enum
6-
from typing import Any
6+
from typing import Any, cast
77

88
import numpy as np
99
import numpy.typing as npt
@@ -44,8 +44,8 @@ def _resolve_color_idx(
4444
instance_idx: int,
4545
color_lookup: ColorLookup,
4646
count: int,
47-
class_id: npt.NDArray[np.int_] | None = None,
48-
tracker_id: npt.NDArray[np.int_] | None = None,
47+
class_id: npt.NDArray[np.generic] | None = None,
48+
tracker_id: npt.NDArray[np.generic] | None = None,
4949
keypoint_idx: int | None = None,
5050
) -> int:
5151
"""Resolve a palette index from raw field arrays.
@@ -63,6 +63,14 @@ def _resolve_color_idx(
6363
6464
Returns:
6565
An integer index suitable for ``ColorPalette.by_idx()``.
66+
67+
Raises:
68+
ValueError: If ``instance_idx`` is out of bounds for the given ``count``.
69+
ValueError: If ``color_lookup`` is ``CLASS`` and ``class_id`` is ``None``.
70+
ValueError: If ``color_lookup`` is ``TRACK`` and ``tracker_id`` is ``None``.
71+
ValueError: If ``color_lookup`` is ``KEYPOINT`` and ``keypoint_idx`` is
72+
``None``.
73+
ValueError: If ``color_lookup`` is an unsupported strategy.
6674
"""
6775
if instance_idx >= count:
6876
raise ValueError(
@@ -74,17 +82,15 @@ def _resolve_color_idx(
7482
elif color_lookup == ColorLookup.CLASS:
7583
if class_id is None:
7684
raise ValueError(
77-
"Could not resolve color by class because "
78-
"class_id is not available. Try setting "
79-
"color_lookup to sv.ColorLookup.INDEX."
85+
"Could not resolve color by class because class_id is not available. "
86+
"Try setting color_lookup to sv.ColorLookup.INDEX."
8087
)
8188
return int(class_id[instance_idx])
8289
elif color_lookup == ColorLookup.TRACK:
8390
if tracker_id is None:
8491
raise ValueError(
85-
"Could not resolve color by track because "
86-
"tracker_id is not available. Make sure that the "
87-
"Detections object contains tracker_id data."
92+
"Could not resolve color by track because tracker_id is not available. "
93+
"Make sure tracker_id is set on the input object."
8894
)
8995
return int(tracker_id[instance_idx])
9096
elif color_lookup == ColorLookup.KEYPOINT:
@@ -434,7 +440,7 @@ def snap_boxes(
434440
bottom_shift = height - result[bottom_overflow, 3]
435441
result[bottom_overflow, 1:4:2] += bottom_shift[:, np.newaxis]
436442

437-
return result.astype(np.float32) # type: ignore
443+
return cast(np.ndarray[Any, np.dtype[np.float32]], result.astype(np.float32))
438444

439445

440446
class Trace:

src/supervision/key_points/annotators.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def __init__(
7878
color_lookup: Strategy for mapping colors to annotations.
7979
Options are `INDEX` (per-skeleton index), `CLASS`
8080
(per class_id), and `KEYPOINT` (per keypoint index within
81-
each skeleton).
81+
each skeleton). Note: ``TRACK`` is not supported for
82+
keypoint annotators.
8283
"""
8384
self.color = color
8485
self.radius = radius
@@ -181,7 +182,8 @@ def __init__(
181182
color_lookup: Strategy for mapping colors to annotations.
182183
Options are `INDEX` (per-skeleton index), `CLASS`
183184
(per class_id), and `KEYPOINT` (per keypoint index —
184-
edge inherits the color of its first endpoint).
185+
edge inherits the color of its first endpoint). Note:
186+
``TRACK`` is not supported for keypoint annotators.
185187
"""
186188
self.color = color
187189
self.thickness = thickness
@@ -797,7 +799,8 @@ def __init__(
797799
color_lookup: Strategy for mapping colors to annotations.
798800
Options are `INDEX` (per-skeleton index), `CLASS`
799801
(per class_id), and `KEYPOINT` (per keypoint index within
800-
each skeleton).
802+
each skeleton). Note: ``TRACK`` is not supported for
803+
keypoint annotators.
801804
"""
802805
if isinstance(color, list):
803806
warn_deprecated(

tests/key_points/test_annotators.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,164 @@ def test_resolve_labels_returns_expected(
467467
def test_resolve_labels_raises(self, labels, points_count, class_id, match):
468468
with pytest.raises(ValueError, match=match):
469469
sv.VertexLabelAnnotator._resolve_labels(labels, points_count, class_id)
470+
471+
472+
class TestVertexAnnotatorColorLookup:
473+
"""Verify VertexAnnotator respects each ColorLookup strategy with a ColorPalette."""
474+
475+
@pytest.fixture
476+
def key_points_with_class(self) -> sv.KeyPoints:
477+
"""Two-instance, three-keypoint set with class_id set."""
478+
return sv.KeyPoints(
479+
xy=np.array(
480+
[
481+
[[20.0, 20.0], [40.0, 40.0], [60.0, 60.0]],
482+
[[25.0, 25.0], [45.0, 45.0], [65.0, 65.0]],
483+
],
484+
dtype=np.float32,
485+
),
486+
class_id=np.array([0, 1], dtype=int),
487+
)
488+
489+
@pytest.mark.parametrize(
490+
"color_lookup",
491+
[
492+
pytest.param(sv.ColorLookup.INDEX, id="index"),
493+
pytest.param(sv.ColorLookup.CLASS, id="class"),
494+
pytest.param(sv.ColorLookup.KEYPOINT, id="keypoint"),
495+
],
496+
)
497+
def test_annotate_with_color_palette_returns_ndarray(
498+
self, scene, key_points_with_class, color_lookup
499+
):
500+
"""ColorPalette + each ColorLookup produces a modified ndarray output."""
501+
annotator = sv.VertexAnnotator(
502+
color=sv.ColorPalette.DEFAULT,
503+
radius=5,
504+
color_lookup=color_lookup,
505+
)
506+
result = annotator.annotate(
507+
scene=scene.copy(), key_points=key_points_with_class
508+
)
509+
510+
assert isinstance(result, np.ndarray)
511+
assert result.shape == scene.shape
512+
assert not np.array_equal(result, scene)
513+
514+
def test_annotate_class_lookup_raises_when_class_id_none(self, scene):
515+
"""CLASS strategy raises ValueError when key_points.class_id is None."""
516+
key_points = sv.KeyPoints(
517+
xy=np.array([[[30.0, 30.0], [50.0, 50.0]]], dtype=np.float32),
518+
)
519+
annotator = sv.VertexAnnotator(
520+
color=sv.ColorPalette.DEFAULT,
521+
color_lookup=sv.ColorLookup.CLASS,
522+
)
523+
524+
with pytest.raises(ValueError, match="class_id"):
525+
annotator.annotate(scene=scene.copy(), key_points=key_points)
526+
527+
528+
class TestEdgeAnnotatorColorLookup:
529+
"""Verify EdgeAnnotator respects each ColorLookup strategy with a ColorPalette."""
530+
531+
@pytest.fixture
532+
def key_points_triangle(self) -> sv.KeyPoints:
533+
"""Single-instance, three-vertex triangle useful with explicit edges."""
534+
return sv.KeyPoints(
535+
xy=np.array(
536+
[[[10.0, 10.0], [80.0, 10.0], [45.0, 80.0]]],
537+
dtype=np.float32,
538+
),
539+
class_id=np.array([0], dtype=int),
540+
)
541+
542+
@pytest.mark.parametrize(
543+
"color_lookup",
544+
[
545+
pytest.param(sv.ColorLookup.INDEX, id="index"),
546+
pytest.param(sv.ColorLookup.CLASS, id="class"),
547+
pytest.param(sv.ColorLookup.KEYPOINT, id="keypoint"),
548+
],
549+
)
550+
def test_annotate_with_color_palette_returns_ndarray(
551+
self, scene, key_points_triangle, color_lookup
552+
):
553+
"""ColorPalette + each ColorLookup produces a modified ndarray output."""
554+
annotator = sv.EdgeAnnotator(
555+
color=sv.ColorPalette.DEFAULT,
556+
thickness=2,
557+
edges=[(1, 2), (2, 3), (1, 3)],
558+
color_lookup=color_lookup,
559+
)
560+
result = annotator.annotate(scene=scene.copy(), key_points=key_points_triangle)
561+
562+
assert isinstance(result, np.ndarray)
563+
assert result.shape == scene.shape
564+
assert not np.array_equal(result, scene)
565+
566+
def test_annotate_class_lookup_raises_when_class_id_none(self, scene):
567+
"""CLASS strategy raises ValueError when key_points.class_id is None."""
568+
key_points = sv.KeyPoints(
569+
xy=np.array([[[10.0, 10.0], [80.0, 10.0]]], dtype=np.float32),
570+
)
571+
annotator = sv.EdgeAnnotator(
572+
color=sv.ColorPalette.DEFAULT,
573+
edges=[(1, 2)],
574+
color_lookup=sv.ColorLookup.CLASS,
575+
)
576+
577+
with pytest.raises(ValueError, match="class_id"):
578+
annotator.annotate(scene=scene.copy(), key_points=key_points)
579+
580+
581+
class TestVertexLabelAnnotatorColorLookup:
582+
"""Verify VertexLabelAnnotator respects each ColorLookup strategy."""
583+
584+
@pytest.fixture
585+
def key_points_with_class(self) -> sv.KeyPoints:
586+
"""Two-instance, two-keypoint set with class_id set."""
587+
return sv.KeyPoints(
588+
xy=np.array(
589+
[[[20.0, 20.0], [60.0, 60.0]], [[25.0, 25.0], [65.0, 65.0]]],
590+
dtype=np.float32,
591+
),
592+
class_id=np.array([0, 1], dtype=int),
593+
)
594+
595+
@pytest.mark.parametrize(
596+
"color_lookup",
597+
[
598+
pytest.param(sv.ColorLookup.INDEX, id="index"),
599+
pytest.param(sv.ColorLookup.CLASS, id="class"),
600+
pytest.param(sv.ColorLookup.KEYPOINT, id="keypoint"),
601+
],
602+
)
603+
def test_annotate_with_color_palette_returns_ndarray(
604+
self, scene, key_points_with_class, color_lookup
605+
):
606+
"""ColorPalette + each ColorLookup produces a modified ndarray output."""
607+
annotator = sv.VertexLabelAnnotator(
608+
color=sv.ColorPalette.DEFAULT,
609+
color_lookup=color_lookup,
610+
)
611+
result = annotator.annotate(
612+
scene=scene.copy(), key_points=key_points_with_class
613+
)
614+
615+
assert isinstance(result, np.ndarray)
616+
assert result.shape == scene.shape
617+
assert not np.array_equal(result, scene)
618+
619+
def test_annotate_class_lookup_raises_when_class_id_none(self, scene):
620+
"""CLASS strategy raises ValueError when key_points.class_id is None."""
621+
key_points = sv.KeyPoints(
622+
xy=np.array([[[30.0, 30.0], [50.0, 50.0]]], dtype=np.float32),
623+
)
624+
annotator = sv.VertexLabelAnnotator(
625+
color=sv.ColorPalette.DEFAULT,
626+
color_lookup=sv.ColorLookup.CLASS,
627+
)
628+
629+
with pytest.raises(ValueError, match="class_id"):
630+
annotator.annotate(scene=scene.copy(), key_points=key_points)

0 commit comments

Comments
 (0)