Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 91 additions & 25 deletions openfoodfacts/ml/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,64 @@ def to_list(self) -> list[JSONType]:
return results


def apply_nms(
    bboxes: np.ndarray,
    scores: np.ndarray,
    classes: np.ndarray,
    threshold: float,
    nms_threshold: float,
    nms_eta: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Apply the non-maximum suppression algorithm to the bounding boxes.

    We use `NMSBoxes` from the openCV library to perform NMS.

    :param bboxes: The bounding boxes in format [y_min, x_min, y_max, x_max],
        in relative coordinates. Shape: (N, 4)
    :param scores: The confidence scores. Shape: (N,)
    :param classes: The class labels, as an int array. Shape: (N,)
    :param threshold: The confidence threshold (a single float) passed to
        `NMSBoxes` as `score_threshold` to filter out the bounding boxes.
    :param nms_threshold: The NMS threshold to use.
    :param nms_eta: The NMS eta to use.
    :return: bounding boxes (float32, same [y_min, x_min, y_max, x_max]
        format as the input), scores (float32) and classes (int) of the
        detections kept by NMS
    """
    # Input format: y_min, x_min, y_max, x_max
    # Format expected by NMSBoxes: x, y, width, height
    # See Rect2d constructor:
    # https://docs.opencv.org/4.x/javadoc/org/opencv/core/Rect2d.html#%3Cinit%3E(double,double,double,double)
    # and NMSBoxes documentation:
    # https://docs.opencv.org/4.x/d6/d0f/group__dnn.html#ga6e9e67e8d1d8b3a70b55ab9ea715e70d
    bboxes_nms = np.empty((bboxes.shape[0], 4), dtype=np.float32)
    bboxes_nms[:, 0] = bboxes[:, 1]  # x_min
    bboxes_nms[:, 1] = bboxes[:, 0]  # y_min
    bboxes_nms[:, 2] = bboxes[:, 3] - bboxes[:, 1]  # box_width
    bboxes_nms[:, 3] = bboxes[:, 2] - bboxes[:, 0]  # box_height

    detection_box_indices = dnn.NMSBoxes(
        bboxes=bboxes_nms,  # type: ignore
        scores=scores,  # type: ignore
        score_threshold=threshold,  # type: ignore
        nms_threshold=nms_threshold,
        eta=nms_eta,
    )

    # Normalize the result to a flat integer index array before using it for
    # indexing: the returned value may be an empty tuple when nothing is
    # kept, and (presumably, depending on the OpenCV version) an (N, 1)
    # array rather than a flat one.
    kept = np.asarray(detection_box_indices, dtype=np.intp).reshape(-1)

    # Integer-array indexing followed by `astype` yields fresh arrays with
    # the same dtypes the element-by-element copy used to produce.
    detection_boxes = bboxes[kept].astype(np.float32)
    detection_scores = scores[kept].astype(np.float32)
    detection_classes = classes[kept].astype(int)

    return detection_boxes, detection_scores, detection_classes


class ObjectDetector:
def __init__(self, model_name: str, label_names: list[str], image_size: int = 640):
"""An object detection detector based on Yolo models.
Expand All @@ -177,6 +235,7 @@ def detect_from_image(
nms_threshold: float | None = None,
nms_eta: float | None = None,
model_version: str | None = None,
nms: bool = True,
) -> ObjectDetectionRawResult:
"""Run an object detection model on an image.

Expand All @@ -193,6 +252,7 @@ def detect_from_image(
will be used).
:param model_version: the version of the model to use, defaults to
None (latest).
:param nms: whether to use NMS, defaults to True.
:return: the detection result
"""
metrics: dict[str, float] = {}
Expand Down Expand Up @@ -221,6 +281,7 @@ def detect_from_image(
original_shape=original_shape,
nms_threshold=nms_threshold,
nms_eta=nms_eta,
nms=nms,
)

metrics.update(response.metrics)
Expand Down Expand Up @@ -248,6 +309,7 @@ def postprocess(
original_shape: tuple[int, int],
nms_threshold: float | None = None,
nms_eta: float | None = None,
nms: bool = True,
) -> ObjectDetectionRawResult:
"""Postprocess the output of the object detection model.

Expand All @@ -258,6 +320,8 @@ def postprocess(
use, defaults to None (0.7 will be used).
:param nms_eta: the NMS eta parameter to use, defaults to None (1.0
will be used).
:param nms: whether to apply NMS or not, defaults to True. If False,
`nms_threshold` and `nms_eta` are ignored.
:return: the detection result
"""
if len(response.outputs) != 1:
Expand All @@ -284,15 +348,15 @@ def postprocess(
raw_detection_classes = np.zeros(rows, dtype=int)
raw_detection_scores = np.zeros(rows, dtype=np.float32)
raw_detection_boxes = np.zeros((rows, 4), dtype=np.float32)

selected = 0
for i in range(rows):
classes_scores = output[4:, i]
max_cls_idx = np.argmax(classes_scores)
max_score = classes_scores[max_cls_idx]
if max_score < threshold:
continue
raw_detection_classes[i] = max_cls_idx
raw_detection_scores[i] = max_score
raw_detection_classes[selected] = max_cls_idx
raw_detection_scores[selected] = max_score

# The bounding box is in the format (x, y, width, height) in
# relative coordinates
Expand All @@ -313,30 +377,32 @@ def postprocess(
original_shape=original_shape,
image_size=self.image_size,
)
raw_detection_boxes[i, 0] = max(0.0, min(1.0, reversed_bboxes[0]))
raw_detection_boxes[i, 1] = max(0.0, min(1.0, reversed_bboxes[1]))
raw_detection_boxes[i, 2] = max(0.0, min(1.0, reversed_bboxes[2]))
raw_detection_boxes[i, 3] = max(0.0, min(1.0, reversed_bboxes[3]))
raw_detection_boxes[selected, 0] = max(0.0, min(1.0, reversed_bboxes[0]))
raw_detection_boxes[selected, 1] = max(0.0, min(1.0, reversed_bboxes[1]))
raw_detection_boxes[selected, 2] = max(0.0, min(1.0, reversed_bboxes[2]))
raw_detection_boxes[selected, 3] = max(0.0, min(1.0, reversed_bboxes[3]))
selected += 1

raw_detection_classes = raw_detection_classes[:selected]
raw_detection_scores = raw_detection_scores[:selected]
raw_detection_boxes = raw_detection_boxes[:selected]

metrics: dict[str, float] = {}
with PerfTimer("postprocess_nms_time", metrics):
# Perform NMS (Non Maximum Suppression)
detection_box_indices = dnn.NMSBoxes(
raw_detection_boxes, # type: ignore
raw_detection_scores, # type: ignore
score_threshold=threshold,
# the following values are copied from Ultralytics settings
nms_threshold=nms_threshold,
eta=nms_eta,
)
detection_classes = np.zeros(len(detection_box_indices), dtype=int)
detection_scores = np.zeros(len(detection_box_indices), dtype=np.float32)
detection_boxes = np.zeros((len(detection_box_indices), 4), dtype=np.float32)

for i, idx in enumerate(detection_box_indices):
detection_classes[i] = raw_detection_classes[idx]
detection_scores[i] = raw_detection_scores[idx]
detection_boxes[i] = raw_detection_boxes[idx]

if nms:
with PerfTimer("postprocess_nms_time", metrics):
detection_boxes, detection_scores, detection_classes = apply_nms(
bboxes=raw_detection_boxes,
scores=raw_detection_scores,
classes=raw_detection_classes,
threshold=threshold,
nms_threshold=nms_threshold,
nms_eta=nms_eta,
)
else:
detection_classes = raw_detection_classes
detection_scores = raw_detection_scores
detection_boxes = raw_detection_boxes

result = ObjectDetectionRawResult(
num_detections=rows,
Expand Down
Loading
Loading