Skip to content

Commit 89e0427

Browse files
Borda and codex committed
feat(keypoints): add RF-DETR keypoint conversion bridge
Co-authored-by: Codex <codex@openai.com>
1 parent fb2dec9 commit 89e0427

2 files changed

Lines changed: 127 additions & 0 deletions

File tree

src/supervision/key_points/core.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,66 @@ def from_transformers(cls, transformers_results: Any) -> KeyPoints:
670670
else:
671671
return cls.empty()
672672

673+
@classmethod
def from_rfdetr(cls, detections: Detections) -> KeyPoints:
    """
    Create a `sv.KeyPoints` object from RF-DETR keypoint predictions.

    RF-DETR stores keypoints under `sv.Detections.data["keypoints"]` with shape
    `(N, K, 3)`, where each keypoint is `(x, y, confidence)`.

    Args:
        detections: RF-DETR detections containing keypoints in
            `detections.data["keypoints"]`.

    Returns:
        A `sv.KeyPoints` object containing xy coordinates, confidence, and
            class IDs. `class_id` is `None` when `detections.class_id` is
            `None`.

    Raises:
        ValueError: If keypoints are missing or have invalid shape.

    Examples:
        ```pycon
        >>> import numpy as np
        >>> import supervision as sv
        >>> keypoints = np.array(
        ...     [[[1, 2, 0.9], [3, 4, 0.8]]], dtype=np.float32
        ... )
        >>> detections = sv.Detections(
        ...     xyxy=np.array([[0, 0, 10, 10]], dtype=np.float32),
        ...     class_id=np.array([0], dtype=int),
        ...     data={"keypoints": keypoints},
        ... )
        >>> key_points = sv.KeyPoints.from_rfdetr(detections)
        >>> key_points.xy.shape
        (1, 2, 2)
        >>> key_points.confidence.shape
        (1, 2)

        ```
    """
    if "keypoints" not in detections.data:
        raise ValueError(
            "RF-DETR keypoints are missing. Expected "
            "detections.data['keypoints'] with shape (N, K, 3)."
        )

    # `np.asarray` already yields float32, so no further casts are needed.
    # It does not copy an array that is already float32, and the slices below
    # are views, so the result may share memory with `detections.data`.
    keypoints = np.asarray(detections.data["keypoints"], dtype=np.float32)
    if keypoints.ndim != 3 or keypoints.shape[-1] != 3:
        # Report the offending shape so the caller can see what was passed.
        raise ValueError(
            "RF-DETR keypoints must have shape (N, K, 3) with "
            f"(x, y, confidence) values, got shape {keypoints.shape}."
        )

    class_id = None
    if detections.class_id is not None:
        class_id = detections.class_id.astype(int, copy=False)

    return cls(
        xy=keypoints[..., :2],
        confidence=keypoints[..., 2],
        class_id=class_id,
    )
732+
673733
def _get_by_2d_bool_mask(self, mask: npt.NDArray[np.bool_]) -> KeyPoints:
674734
"""Filter keypoints using a 2D boolean mask of shape `(n, m)`.
675735
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
import pytest
3+
4+
import supervision as sv
5+
6+
7+
def test_keypoints_from_rfdetr_detections() -> None:
    """Converts RF-DETR detections.data['keypoints'] into a KeyPoints object.

    Checks values as well as shapes, so a conversion that mixed up the
    x, y and confidence columns would be caught.
    """
    detections = sv.Detections(
        xyxy=np.array([[0, 0, 10, 10], [10, 10, 20, 20]], dtype=np.float32),
        class_id=np.array([1, 3], dtype=int),
        data={
            "keypoints": np.array(
                [
                    [[1.0, 2.0, 0.9], [3.0, 4.0, 0.8]],
                    [[5.0, 6.0, 0.7], [7.0, 8.0, 0.6]],
                ],
                dtype=np.float32,
            )
        },
    )
    # Expected values are spelled out literally rather than sliced from the
    # input, so the assertions do not mirror the implementation under test.
    expected_xy = np.array(
        [
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
        ],
        dtype=np.float32,
    )
    expected_confidence = np.array([[0.9, 0.8], [0.7, 0.6]], dtype=np.float32)

    key_points = sv.KeyPoints.from_rfdetr(detections)

    assert key_points.xy.shape == (2, 2, 2)
    assert np.array_equal(key_points.xy, expected_xy)
    assert key_points.confidence is not None
    assert key_points.confidence.shape == (2, 2)
    assert np.array_equal(key_points.confidence, expected_confidence)
    assert key_points.class_id is not None
    assert np.array_equal(key_points.class_id, np.array([1, 3], dtype=int))
30+
31+
32+
def test_keypoints_from_rfdetr_missing_keypoints_raises_clear_error() -> None:
    """Detections without a 'keypoints' data entry are rejected with ValueError."""
    boxes = np.array([[0, 0, 10, 10]], dtype=np.float32)
    labels = np.array([0], dtype=int)
    detections_without_keypoints = sv.Detections(xyxy=boxes, class_id=labels)

    # The message must name the missing data key.
    expected_message = r"detections.data\['keypoints'\]"
    with pytest.raises(ValueError, match=expected_message):
        sv.KeyPoints.from_rfdetr(detections_without_keypoints)
41+
42+
43+
def test_keypoints_from_rfdetr_malformed_shape_raises_clear_error() -> None:
    """Malformed keypoints shape raises a clear conversion error.

    Covers both rejection paths: a 3-D array whose last axis is not 3, and an
    array that is not 3-D at all.
    """
    malformed_keypoints = [
        # Shape (1, 1, 2): the confidence channel is missing.
        np.array([[[1.0, 2.0]]], dtype=np.float32),
        # Shape (1, 3): the per-detection keypoint axis is missing.
        np.array([[1.0, 2.0, 0.9]], dtype=np.float32),
    ]

    for keypoints in malformed_keypoints:
        detections = sv.Detections(
            xyxy=np.array([[0, 0, 10, 10]], dtype=np.float32),
            class_id=np.array([0], dtype=int),
            data={"keypoints": keypoints},
        )

        with pytest.raises(ValueError, match=r"shape \(N, K, 3\)"):
            sv.KeyPoints.from_rfdetr(detections)
53+
54+
55+
def test_keypoint_annotator_uses_vertex_and_edge_rendering() -> None:
    """Converted RF-DETR keypoints are consumable by vertex and edge annotators.

    Each annotator draws on its own blank scene, so a failure of one cannot be
    hidden by pixels drawn by the other.
    """
    detections = sv.Detections(
        xyxy=np.array([[0, 0, 10, 10]], dtype=np.float32),
        data={
            "keypoints": np.array(
                [[[10.0, 10.0, 0.9], [20.0, 20.0, 0.8]]], dtype=np.float32
            )
        },
    )
    key_points = sv.KeyPoints.from_rfdetr(detections)

    blank = np.zeros((32, 32, 3), dtype=np.uint8)
    # Pass copies so `blank` stays untouched for the second annotator.
    vertex_scene = sv.VertexAnnotator().annotate(
        scene=blank.copy(), key_points=key_points
    )
    edge_scene = sv.EdgeAnnotator(edges=[(1, 2)]).annotate(
        scene=blank.copy(), key_points=key_points
    )

    assert np.any(vertex_scene != 0)
    assert np.any(edge_scene != 0)

0 commit comments

Comments
 (0)