|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import logging |
4 | 3 | from collections.abc import Iterable, Iterator |
5 | 4 | from dataclasses import dataclass, field |
6 | 5 | from typing import Any, Union, cast |
|
13 | 12 | from supervision.detection.utils.internal import get_data_item, is_data_equal |
14 | 13 | from supervision.validators import validate_key_points_fields |
15 | 14 |
|
16 | | -logger = logging.getLogger(__name__) |
17 | | - |
18 | 15 | Index1D = Union[ |
19 | 16 | int, |
20 | 17 | slice, |
|
26 | 23 | Index2D = tuple[Index1D, Index1D] |
27 | 24 |
|
28 | 25 |
|
29 | | -def _rfdetr_source_shape( |
30 | | - rfdetr_detections: Detections, |
31 | | - detections_count: int, |
32 | | -) -> npt.NDArray[np.float32]: |
33 | | - source_shape = rfdetr_detections.data.get("source_shape") |
34 | | - if source_shape is None: |
35 | | - raise ValueError( |
36 | | - "RF-DETR detections with keypoint precision data must contain " |
37 | | - "data['source_shape'] with shape (N, 2) where each row is " |
38 | | - "(height, width) in pixels." |
39 | | - ) |
40 | | - |
41 | | - source_shape_array = np.asarray(source_shape, dtype=np.float32) |
42 | | - expected_shape = (detections_count, 2) |
43 | | - if source_shape_array.shape != expected_shape: |
44 | | - raise ValueError( |
45 | | - "Expected RF-DETR source_shape shape " |
46 | | - f"{expected_shape}, got {source_shape_array.shape}." |
47 | | - ) |
48 | | - return source_shape_array |
49 | | - |
50 | | - |
51 | | -def _rfdetr_precision_cholesky_to_pixel_covariance( |
52 | | - precision_cholesky: npt.NDArray[np.float32], |
53 | | - source_shape: npt.NDArray[np.float32], |
54 | | -) -> npt.NDArray[np.float32]: |
55 | | - if precision_cholesky.ndim != 3 or precision_cholesky.shape[2] != 3: |
56 | | - raise ValueError( |
57 | | - "Expected RF-DETR keypoint precision shape (N, K, 3), " |
58 | | - f"got {precision_cholesky.shape}." |
59 | | - ) |
60 | | - if precision_cholesky.shape[0] != source_shape.shape[0]: |
61 | | - raise ValueError( |
62 | | - "RF-DETR keypoint precision and source_shape must contain the same " |
63 | | - "number of detections, got " |
64 | | - f"{precision_cholesky.shape[0]} and {source_shape.shape[0]}." |
65 | | - ) |
66 | | - |
67 | | - n_total = precision_cholesky.shape[0] * precision_cholesky.shape[1] |
68 | | - n_non_finite = 0 |
69 | | - n_singular = 0 |
70 | | - n_overflow = 0 |
71 | | - |
72 | | - covariances = np.full( |
73 | | - (*precision_cholesky.shape[:2], 2, 2), np.nan, dtype=np.float32 |
74 | | - ) |
75 | | - for detection_index, detection_precision in enumerate(precision_cholesky): |
76 | | - height, width = source_shape[detection_index] |
77 | | - scale = np.diag([width, height]).astype(np.float64) |
78 | | - for keypoint_index, params in enumerate(detection_precision): |
79 | | - if not np.isfinite(params).all(): |
80 | | - n_non_finite += 1 |
81 | | - continue |
82 | | - log_l11 = float(np.clip(params[0], -20.0, 20.0)) |
83 | | - l21 = float(np.clip(params[1], -1.0e4, 1.0e4)) |
84 | | - log_l22 = float(np.clip(params[2], -20.0, 20.0)) |
85 | | - l11 = float(np.exp(log_l11)) |
86 | | - l22 = float(np.exp(log_l22)) |
87 | | - precision = np.array( |
88 | | - [[l11 * l11, l11 * l21], [l11 * l21, l21 * l21 + l22 * l22]], |
89 | | - dtype=np.float64, |
90 | | - ) |
91 | | - try: |
92 | | - covariance = np.linalg.inv(precision) |
93 | | - except np.linalg.LinAlgError: |
94 | | - n_singular += 1 |
95 | | - continue |
96 | | - |
97 | | - pixel_covariance = scale @ covariance @ scale |
98 | | - if np.isfinite(pixel_covariance).all(): |
99 | | - covariances[detection_index, keypoint_index] = pixel_covariance |
100 | | - else: |
101 | | - n_overflow += 1 |
102 | | - |
103 | | - n_failed = n_non_finite + n_singular + n_overflow |
104 | | - if n_failed > 0: |
105 | | - logger.warning( |
106 | | - "%d of %d precision matrices failed: " |
107 | | - "non_finite=%d, singular=%d, overflow=%d", |
108 | | - n_failed, |
109 | | - n_total, |
110 | | - n_non_finite, |
111 | | - n_singular, |
112 | | - n_overflow, |
113 | | - ) |
114 | | - return covariances |
115 | | - |
116 | | - |
117 | 26 | def _optional_array_equal( |
118 | 27 | first: npt.NDArray[np.generic] | None, |
119 | 28 | second: npt.NDArray[np.generic] | None, |
@@ -250,13 +159,6 @@ class simplifies data manipulation and filtering, providing a uniform API for |
250 | 159 | key_point = sv.KeyPoints.from_transformers(results[0]) |
251 | 160 | ``` |
252 | 161 |
|
253 | | - Note: |
254 | | - [`sv.KeyPoints.from_rfdetr`][supervision.key_points.core.KeyPoints.from_rfdetr] |
255 | | - accepts ``sv.Detections`` (not native RF-DETR output) because RF-DETR keypoints |
256 | | - are attached as extra fields inside a ``sv.Detections`` object returned by |
257 | | - ``model.predict()``. Run that conversion first, then pass the result to |
258 | | - ``from_rfdetr``. |
259 | | -
|
260 | 162 | Attributes: |
261 | 163 | xy: An array of shape `(n, m, 2)` containing |
262 | 164 | `n` detected objects, each composed of `m` equally-sized |
@@ -338,111 +240,6 @@ def __eq__(self, other: object) -> bool: |
338 | 240 | ] |
339 | 241 | ) |
340 | 242 |
|
341 | | - @classmethod |
342 | | - def from_rfdetr(cls, rfdetr_detections: Detections) -> KeyPoints: |
343 | | - """ |
344 | | - Create a `sv.KeyPoints` object from RF-DETR `sv.Detections` output. |
345 | | -
|
346 | | - RF-DETR attaches keypoint coordinates to ``detections.data["keypoints"]`` |
347 | | - with shape ``(N, K, 3)`` where the last dimension stores ``[x, y, |
348 | | - confidence]`` in pixel coordinates. When RF-DETR also provides |
349 | | - ``detections.data["keypoint_precision_cholesky"]``, this method converts |
350 | | - those per-keypoint precision parameters into pixel-space covariance matrices |
351 | | - and stores them in ``key_points.data["covariance"]`` for use with |
352 | | - `sv.VertexEllipseAnnotator`. |
353 | | -
|
354 | | - Note: |
355 | | - ``detections.data["source_shape"]`` must have shape ``(N, 2)`` where each |
356 | | - row is ``(height, width)`` in pixels — note this is HW order, not the WH |
357 | | - order used by ``resolution_wh`` elsewhere in supervision. |
358 | | -
|
359 | | - Keypoint confidence values are stored as-is from RF-DETR output and are |
360 | | - expected to be probabilities in the range ``[0, 1]``. If RF-DETR returns |
361 | | - logits instead, user-supplied ``confidence_threshold`` values in |
362 | | - `sv.VertexEllipseAnnotator` should be adjusted accordingly. |
363 | | -
|
364 | | - Args: |
365 | | - rfdetr_detections: RF-DETR prediction returned by ``model.predict()``. |
366 | | -
|
367 | | - Returns: |
368 | | - A `sv.KeyPoints` object containing RF-DETR keypoints and optional |
369 | | - covariance matrices. |
370 | | -
|
371 | | - Raises: |
372 | | - ValueError: If the RF-DETR detections do not contain valid keypoints, |
373 | | - or if precision parameters are present without source shape data. |
374 | | -
|
375 | | - Examples: |
376 | | - Basic usage — keypoints only: |
377 | | -
|
378 | | - >>> import numpy as np |
379 | | - >>> import supervision as sv |
380 | | - >>> kp_arr = np.array([[[50, 80, 0.9], [60, 90, 0.8]]], dtype=np.float32) |
381 | | - >>> detections = sv.Detections( |
382 | | - ... xyxy=np.array([[10, 20, 100, 200]], dtype=np.float32), |
383 | | - ... data={"keypoints": kp_arr}, |
384 | | - ... ) |
385 | | - >>> key_points = sv.KeyPoints.from_rfdetr(detections) |
386 | | - >>> key_points.xy.shape |
387 | | - (1, 2, 2) |
388 | | -
|
389 | | - With precision Cholesky parameters (produces covariance data): |
390 | | -
|
391 | | - >>> kp_arr2 = np.array([[[50, 80, 0.9], [60, 90, 0.8]]], dtype=np.float32) |
392 | | - >>> chol = np.zeros((1, 2, 3), dtype=np.float32) |
393 | | - >>> src = np.array([[480, 640]], dtype=np.float32) |
394 | | - >>> detections_with_cov = sv.Detections( |
395 | | - ... xyxy=np.array([[10, 20, 100, 200]], dtype=np.float32), |
396 | | - ... data={ |
397 | | - ... "keypoints": kp_arr2, |
398 | | - ... "keypoint_precision_cholesky": chol, |
399 | | - ... "source_shape": src, |
400 | | - ... }, |
401 | | - ... ) |
402 | | - >>> kp = sv.KeyPoints.from_rfdetr(detections_with_cov) |
403 | | - >>> "covariance" in kp.data |
404 | | - True |
405 | | - """ |
406 | | - rfdetr_keypoints = rfdetr_detections.data.get("keypoints") |
407 | | - if rfdetr_keypoints is None: |
408 | | - raise ValueError("RF-DETR detections must contain data['keypoints'].") |
409 | | - |
410 | | - keypoints = np.asarray(rfdetr_keypoints, dtype=np.float32) |
411 | | - if keypoints.ndim != 3 or keypoints.shape[2] != 3: |
412 | | - raise ValueError( |
413 | | - f"Expected RF-DETR keypoints shape (N, K, 3), got {keypoints.shape}." |
414 | | - ) |
415 | | - if keypoints.shape[0] == 0: |
416 | | - return cls.empty() |
417 | | - |
418 | | - data: dict[str, npt.NDArray[np.generic] | list[Any]] = {} |
419 | | - precision_cholesky = rfdetr_detections.data.get("keypoint_precision_cholesky") |
420 | | - if precision_cholesky is not None: |
421 | | - precision_cholesky_array = np.asarray(precision_cholesky, dtype=np.float32) |
422 | | - if precision_cholesky_array.shape[:2] != keypoints.shape[:2]: |
423 | | - raise ValueError( |
424 | | - "keypoint_precision_cholesky shape " |
425 | | - f"{precision_cholesky_array.shape[:2]} does not match " |
426 | | - f"keypoints shape {keypoints.shape[:2]}." |
427 | | - ) |
428 | | - source_shape = _rfdetr_source_shape( |
429 | | - rfdetr_detections, detections_count=keypoints.shape[0] |
430 | | - ) |
431 | | - data["covariance"] = _rfdetr_precision_cholesky_to_pixel_covariance( |
432 | | - precision_cholesky=precision_cholesky_array, |
433 | | - source_shape=source_shape, |
434 | | - ) |
435 | | - class_id: npt.NDArray[np.int_] | None = None |
436 | | - if rfdetr_detections.class_id is not None: |
437 | | - class_id = rfdetr_detections.class_id.astype(np.int_) |
438 | | - |
439 | | - return cls( |
440 | | - xy=keypoints[:, :, :2].astype(np.float32), |
441 | | - confidence=keypoints[:, :, 2].astype(np.float32), |
442 | | - class_id=class_id, |
443 | | - data=data, |
444 | | - ) |
445 | | - |
446 | 243 | @classmethod |
447 | 244 | def from_inference(cls, inference_result: Any) -> KeyPoints: |
448 | 245 | """ |
|
0 commit comments