Skip to content

Commit 66658d6

Browse files
committed
fix: fix issue with NMS
See openfoodfacts/openfoodfacts-python#417 for more information. I also: - added a `nms` parameter to the image prediction route for easier debugging. - fixed a bug when calling the bbox visualization function: we filtered out bboxes with confidence < 0.5, and only displayed up to 20 bboxes. We now display all bounding boxes.
1 parent 5af00f1 commit 66658d6

4 files changed

Lines changed: 13 additions & 6 deletions

File tree

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ lark = "~1.1.4"
7474
h5py = "~3.13.0"
7575
opencv-python-headless = "~4.12.0.88"
7676
toml = "~0.10.2"
77-
openfoodfacts = "3.3.0"
77+
openfoodfacts = "3.4.0"
7878
imagehash = "~4.3.1"
7979
peewee-migrate = "~1.12.2"
8080
diskcache = "~5.6.3"

robotoff/app/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
809809
"nms_threshold", default=None
810810
)
811811
nms_eta: float | None = req.get_param_as_float("nms_eta", default=None)
812-
812+
nms: float = req.get_param_as_bool("nms", default=True)
813813
available_object_detection_models = list(
814814
ObjectDetectionModel.__members__.keys()
815815
)
@@ -868,6 +868,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
868868
threshold=threshold,
869869
nms_threshold=nms_threshold,
870870
nms_eta=nms_eta,
871+
nms=nms,
871872
)
872873

873874
if output_image:

robotoff/prediction/object_detection/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class ModelConfig(BaseModel):
9292
triton_model_name="price_tag_detection",
9393
image_size=960,
9494
label_names=["price-tag"],
95+
default_threshold=0.25,
9596
),
9697
}
9798

@@ -110,6 +111,8 @@ def add_boxes_and_labels(image_array: np.ndarray, result: ObjectDetectionResult)
110111
instance_masks=None,
111112
use_normalized_coordinates=True,
112113
line_thickness=5,
114+
max_boxes_to_draw=len(result.detection_boxes),
115+
min_score_thresh=0.0,
113116
)
114117
image_with_boxes = Image.fromarray(image_array)
115118
result.boxed_image = image_with_boxes
@@ -127,6 +130,7 @@ def detect_from_image(
127130
threshold: float | None = None,
128131
nms_threshold: float | None = None,
129132
nms_eta: float | None = None,
133+
nms: bool = True,
130134
) -> ObjectDetectionResult:
131135
"""Run an object detection model on an image.
132136
@@ -144,6 +148,7 @@ def detect_from_image(
144148
defaults to None (0.7 will be used).
145149
:param nms_eta: the NMS eta parameter to use, defaults to None (1.0 will be
146150
used).
151+
:param nms: whether to use NMS, defaults to True.
147152
:return: the detection result
148153
"""
149154
threshold = threshold or self.config.default_threshold
@@ -158,6 +163,7 @@ def detect_from_image(
158163
threshold=threshold,
159164
nms_threshold=nms_threshold,
160165
nms_eta=nms_eta,
166+
nms=nms,
161167
)
162168
for metric_name, duration in result.metrics.items():
163169
ml_metrics_logger.info(

0 commit comments

Comments
 (0)