Skip to content

Commit e6c2670

Browse files
committed
Fix model path + Add ONNX validation
1 parent 213e52c commit e6c2670

3 files changed

Lines changed: 24 additions & 3 deletions

File tree

anylabeling/configs/auto_labeling/segment_anything_vit_b_quant.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
type: segment_anything
2-
name: segment_anything_vit_b-r20230416
2+
name: segment_anything_vit_b_quant-r20230416
33
display_name: Segment Anything (ViT-B Quant)
44
encoder_model_path: https://github.com/vietanhdev/anylabeling-assets/releases/download/v0.2.0/segment_anything_vit_b_encoder_quant.onnx
55
decoder_model_path: https://github.com/vietanhdev/anylabeling-assets/releases/download/v0.2.0/segment_anything_vit_b_decoder_quant.onnx

anylabeling/services/auto_labeling/model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import abstractmethod
66

77
import yaml
8+
import onnx
89

910
from PyQt5.QtCore import QFile
1011
from PyQt5.QtGui import QImage
@@ -86,7 +87,17 @@ def get_model_abs_path(self, model_path, model_folder_name):
8687
)
8788
)
8889
if os.path.exists(model_abs_path):
89-
return model_abs_path
90+
if model_abs_path.lower().endswith(".onnx"):
91+
try:
92+
onnx.checker.check_model(model_abs_path)
93+
except onnx.checker.ValidationError as e:
94+
print("The model is invalid: %s" % e)
95+
print("Action: Delete and redownload...")
96+
os.remove(model_abs_path)
97+
else:
98+
return model_abs_path
99+
else:
100+
return model_abs_path
90101
pathlib.Path(model_abs_path).parent.mkdir(parents=True, exist_ok=True)
91102

92103
# Download model from url

anylabeling/services/auto_labeling/segment_anything.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ class Meta:
3232
"button_add_point",
3333
"button_remove_point",
3434
"button_add_rect",
35-
# "button_undo", # Dont support undo now
3635
"button_clear",
3736
"button_finish_object",
3837
]
@@ -81,6 +80,7 @@ def __init__(self, config_path, on_message) -> None:
8180

8281
# Pre-inference worker
8382
self.pre_inference_thread = None
83+
self.stop_inference = False
8484

8585
def set_auto_labeling_marks(self, marks):
8686
"""Set auto labeling marks"""
@@ -216,6 +216,7 @@ def run_decoder(self, image_embedding, resized_ratio):
216216
),
217217
}
218218
masks, _, _ = self.decoder_session.run(None, decoder_inputs)
219+
masks = masks[0, 0, :, :] # Only get 1 mask
219220
masks = masks > 0.0
220221
masks = masks.reshape(self.size_after_apply_max_width_height)
221222
return masks
@@ -279,11 +280,15 @@ def predict_shapes(self, image, filename=None) -> AutoLabelingResult:
279280
else:
280281
cv_image = qt_img_to_cv_img(image)
281282
encoder_inputs, resized_ratio = self.pre_process(cv_image)
283+
if self.stop_inference:
284+
return AutoLabelingResult([], replace=False)
282285
image_embedding = self.run_encoder(encoder_inputs)
283286
self.image_embedding_cache.put(
284287
filename,
285288
(resized_ratio, image_embedding),
286289
)
290+
if self.stop_inference:
291+
return AutoLabelingResult([], replace=False)
287292
masks = self.run_decoder(image_embedding, resized_ratio)
288293
shapes = self.post_process(masks, resized_ratio)
289294
except Exception as e: # noqa
@@ -295,6 +300,9 @@ def predict_shapes(self, image, filename=None) -> AutoLabelingResult:
295300
return result
296301

297302
def unload(self):
303+
self.stop_inference = True
304+
self.pre_inference_thread.quit()
305+
self.pre_inference_thread.wait()
298306
if self.encoder_session:
299307
self.encoder_session = None
300308
if self.decoder_session:
@@ -311,6 +319,8 @@ def preload_worker(self, files):
311319
image = self.load_image_from_filename(filename)
312320
if image is None:
313321
continue
322+
if self.stop_inference:
323+
return
314324
cv_image = qt_img_to_cv_img(image)
315325
encoder_inputs, resized_ratio = self.pre_process(cv_image)
316326
image_embedding = self.run_encoder(encoder_inputs)

0 commit comments

Comments
 (0)