@@ -32,7 +32,6 @@ class Meta:
3232 "button_add_point" ,
3333 "button_remove_point" ,
3434 "button_add_rect" ,
35- # "button_undo", # Dont support undo now
3635 "button_clear" ,
3736 "button_finish_object" ,
3837 ]
@@ -81,6 +80,7 @@ def __init__(self, config_path, on_message) -> None:
8180
8281 # Pre-inference worker
8382 self .pre_inference_thread = None
83+ self .stop_inference = False
8484
8585 def set_auto_labeling_marks (self , marks ):
8686 """Set auto labeling marks"""
@@ -216,6 +216,7 @@ def run_decoder(self, image_embedding, resized_ratio):
216216 ),
217217 }
218218 masks , _ , _ = self .decoder_session .run (None , decoder_inputs )
219+ masks = masks [0 , 0 , :, :] # Only get 1 mask
219220 masks = masks > 0.0
220221 masks = masks .reshape (self .size_after_apply_max_width_height )
221222 return masks
@@ -279,11 +280,15 @@ def predict_shapes(self, image, filename=None) -> AutoLabelingResult:
279280 else :
280281 cv_image = qt_img_to_cv_img (image )
281282 encoder_inputs , resized_ratio = self .pre_process (cv_image )
283+ if self .stop_inference :
284+ return AutoLabelingResult ([], replace = False )
282285 image_embedding = self .run_encoder (encoder_inputs )
283286 self .image_embedding_cache .put (
284287 filename ,
285288 (resized_ratio , image_embedding ),
286289 )
290+ if self .stop_inference :
291+ return AutoLabelingResult ([], replace = False )
287292 masks = self .run_decoder (image_embedding , resized_ratio )
288293 shapes = self .post_process (masks , resized_ratio )
289294 except Exception as e : # noqa
@@ -295,6 +300,9 @@ def predict_shapes(self, image, filename=None) -> AutoLabelingResult:
295300 return result
296301
297302 def unload (self ):
303+ self .stop_inference = True
304+ self .pre_inference_thread .quit ()
305+ self .pre_inference_thread .wait ()
298306 if self .encoder_session :
299307 self .encoder_session = None
300308 if self .decoder_session :
@@ -311,6 +319,8 @@ def preload_worker(self, files):
311319 image = self .load_image_from_filename (filename )
312320 if image is None :
313321 continue
322+ if self .stop_inference :
323+ return
314324 cv_image = qt_img_to_cv_img (image )
315325 encoder_inputs , resized_ratio = self .pre_process (cv_image )
316326 image_embedding = self .run_encoder (encoder_inputs )
0 commit comments