77import sys
88from dataclasses import asdict , dataclass
99from datetime import datetime
10- from typing import Dict , List , Literal , Optional
10+ from typing import Dict , List , Literal , Optional , Union
1111
1212import numpy as np
1313
4646from .rest import upload as rest_upload
4747
4848
49- @dataclass
50- class _ActiveDetector :
51- logic : DetectorLogic
52- model_info : ModelInformation
53-
54-
5549class DetectorNode (Node ):
5650
5751 def __init__ (self , name : str , detector_factory : DetectorLogicFactory ,
5852 uuid : Optional [str ] = None , use_backdoor_controls : bool = False ) -> None :
5953 super ().__init__ (name , uuid = uuid , node_type = 'detector' , needs_login = False , needs_sio = False )
6054 self ._detector_factory = detector_factory
61- self ._detector : Optional [_ActiveDetector ] = None
55+ self ._detector : _DetectorState = _Initializing ()
56+ self ._exclusive_model_build : bool = os .environ .get ('EXCLUSIVE_MODEL_BUILD' , '0' ).lower () in ('1' , 'true' )
6257 self ._remaining_init_attempts : int = 2
6358 self .organization = environment_reader .organization ()
6459 self .project = environment_reader .project ()
@@ -105,13 +100,13 @@ def get_about_response(self) -> AboutResponse:
105100 return AboutResponse (
106101 operation_mode = self .operation_mode .value ,
107102 state = self .status .state ,
108- model_info = self ._detector .model_info if self ._detector else None ,
103+ model_info = self ._detector .model_info if isinstance ( self ._detector , _ActiveDetector ) else None ,
109104 target_model = self .target_model .version if self .target_model else None ,
110105 version_control = self .version_control .value
111106 )
112107
113108 def get_model_version_response (self ) -> ModelVersionResponse :
114- current_version = self ._detector .model_info .version if self ._detector else 'None'
109+ current_version = self ._detector .model_info .version if isinstance ( self ._detector , _ActiveDetector ) else 'None'
115110 target_version = self .target_model .version if self .target_model is not None else 'None'
116111 loop_version = self .loop_deployment_target .version if self .loop_deployment_target is not None else 'None'
117112
@@ -258,14 +253,11 @@ async def detect(sid, data: Dict) -> Dict:
258253 autoupload = data .get ('autoupload' , 'filtered' ),
259254 creation_date = data .get ('creation_date' , None )
260255 )
261- if det is None :
262- return {'error' : 'no model loaded' }
263- detection_dict = jsonable_encoder (asdict (det ))
264- return detection_dict
256+ return jsonable_encoder (asdict (det ))
257+ except DetectorUnavailableError as e :
258+ return {'error' : str (e )}
265259 except Exception as e :
266260 self .log .exception ('could not detect via socketio' )
267- # with open('/tmp/bad_img_from_socket_io.jpg', 'wb') as f:
268- # f.write(data['image'])
269261 return {'error' : str (e )}
270262
271263 @self .sio .event
@@ -292,21 +284,22 @@ async def batch_detect(sid, data: Dict) -> Dict:
292284 autoupload = data .get ('autoupload' , 'filtered' ),
293285 creation_date = data .get ('creation_date' , None )
294286 )
295- if det is None :
296- return {'error' : 'no model loaded' }
297- detection_dict = jsonable_encoder (asdict (det ))
298- return detection_dict
287+ return jsonable_encoder (asdict (det ))
288+ except DetectorUnavailableError as e :
289+ return {'error' : str (e )}
299290 except Exception as e :
300291 self .log .exception ('could not detect via socketio' )
301- # with open('/tmp/bad_img_from_socket_io.jpg', 'wb') as f:
302- # f.write(data['image'])
303292 return {'error' : str (e )}
304293
305294 @self .sio .event
306295 async def info (sid ) -> Dict :
307- if self ._detector is not None :
308- return asdict (self ._detector .model_info )
309- return {"status" : "No model loaded" }
296+ match self ._detector :
297+ case _ActiveDetector () as d :
298+ return asdict (d .model_info )
299+ case _Updating () as u :
300+ return {"status" : f"Updating model to version { u .version } " }
301+ case _Initializing ():
302+ return {"status" : "No model loaded" }
310303
311304 @self .sio .event
312305 async def about (sid ) -> Dict :
@@ -357,8 +350,11 @@ async def upload(sid, data: Dict) -> Dict:
357350 except Exception as e :
358351 self .log .exception ('could not parse detections' )
359352 return {'error' : str (e )}
360- if self ._detector is not None :
361- image_metadata = self .add_category_id_to_detections (self ._detector .model_info , image_metadata )
353+ try :
354+ detector = unwrap_detector (self ._detector )
355+ image_metadata = self .add_category_id_to_detections (detector .model_info , image_metadata )
356+ except DetectorUnavailableError as e :
357+ self .log .warning ('Cannot add category IDs: %s' , e )
362358 else :
363359 image_metadata = ImageMetadata ()
364360
@@ -400,7 +396,7 @@ async def on_repeat(self) -> None:
400396 async def _sync_status_with_loop (self ) -> None :
401397 """Sync status of the detector with the Learning Loop."""
402398
403- current_model = self ._detector .model_info .version if self ._detector else None
399+ current_model = self ._detector .model_info .version if isinstance ( self ._detector , _ActiveDetector ) else None
404400 target_model_version = self .target_model .version if self .target_model else None
405401
406402 status = DetectorStatus (
@@ -444,7 +440,14 @@ async def _update_model_if_required(self) -> None:
444440 self .log .debug ('not running any updates; target model is None' )
445441 return
446442
447- current_version = self ._detector .model_info .version if self ._detector else None
443+ match self ._detector :
444+ case _ActiveDetector () as d :
445+ current_version = d .model_info .version
446+ case _Updating ():
447+ self .log .debug ('not checking for updates; model update already in progress' )
448+ return
449+ case _Initializing ():
450+ current_version = None
448451
449452 if current_version != self .target_model .version :
450453 self .log .info ('Updating model from %s to %s' ,
@@ -526,12 +529,20 @@ async def _update_model(self, target_model: ModelInformation) -> None:
526529 async def _build_and_swap_detector (self , model_dir : str ) -> None :
527530 """Load ModelInformation from model_dir, build a new DetectorLogic via the factory,
528531 then atomically swap self._detector when ready.
529- The old detector continues to serve requests until the swap."""
532+ The old detector continues to serve requests until the swap.
533+
534+ If EXCLUSIVE_MODEL_BUILD is set and a detector is active, the old detector is torn down
535+ first (freeing e.g. GPU VRAM) and detections are rejected until the new one is ready."""
530536 logging .info ('Loading model from %s' , model_dir )
531537 model_info = ModelInformation .load_from_disk (model_dir )
532538 if model_info is None :
533539 logging .warning ('No model.json found in %s' , model_dir )
534540 return
541+
542+ if self ._exclusive_model_build and isinstance (self ._detector , _ActiveDetector ):
543+ async with self .detection_lock : # wait for in-flight detections to finish
544+ self ._detector = _Updating (version = model_info .version )
545+
535546 try :
536547 new_detector = await self ._detector_factory .build (model_info )
537548 logging .info ('Successfully built detector for model %s' , model_info )
@@ -540,6 +551,8 @@ async def _build_and_swap_detector(self, model_dir: str) -> None:
540551 self ._remaining_init_attempts -= 1
541552 logging .error ('Could not build detector for model %s. Retries left: %s. Error: %s' ,
542553 model_info , self ._remaining_init_attempts , e )
554+ if not isinstance (self ._detector , _ActiveDetector ):
555+ self ._detector = _Initializing () # no model to fall back to
543556 if self ._remaining_init_attempts == 0 :
544557 raise NodeNeedsRestartError ('Could not build detector' ) from None
545558 raise
@@ -587,7 +600,7 @@ async def get_detections(self,
587600 camera_id : Optional [str ] = None ,
588601 source : Optional [str ] = None ,
589602 autoupload : Literal ['filtered' , 'all' , 'disabled' ],
590- creation_date : Optional [str ] = None ) -> Optional [ ImageMetadata ] :
603+ creation_date : Optional [str ] = None ) -> ImageMetadata :
591604 """
592605 Main processing function for the detector node.
593606
@@ -596,11 +609,8 @@ async def get_detections(self,
596609 cares about uploading to the loop and returns the detections as ImageMetadata object.
597610 Returns None if no model is loaded.
598611 """
599- detector = self ._detector
600- if detector is None :
601- return None
602-
603612 async with self .detection_lock :
613+ detector = unwrap_detector (self ._detector )
604614 metadata = await run .io_bound (detector .logic .evaluate , image )
605615
606616 metadata .tags .extend (tags )
@@ -629,19 +639,16 @@ async def get_batch_detections(self,
629639 camera_id : Optional [str ] = None ,
630640 source : Optional [str ] = None ,
631641 autoupload : str = 'filtered' ,
632- creation_date : Optional [str ] = None ) -> Optional [ ImagesMetadata ] :
642+ creation_date : Optional [str ] = None ) -> ImagesMetadata :
633643 """
634644 Processing function for the detector node when a batch inference is requested via SocketIO.
635645
636646 This function infers the detections from all images,
637647 cares about uploading to the loop and returns the detections as a list of ImageMetadata.
638648 Returns None if no model is loaded.
639649 """
640- detector = self ._detector
641- if detector is None :
642- return None
643-
644650 async with self .detection_lock :
651+ detector = unwrap_detector (self ._detector )
645652 all_detections = await run .io_bound (detector .logic .batch_evaluate , images )
646653
647654 for metadata in all_detections .items :
@@ -715,6 +722,40 @@ def register_sio_events(self, sio_client: AsyncClient):
715722 pass
716723
717724
725+ class _Initializing :
726+ """No model ready yet — first build pending or in progress."""
727+
728+
729+ @dataclass
730+ class _Updating :
731+ """Exclusive model replacement in progress — detections rejected."""
732+ version : str
733+
734+
735+ @dataclass
736+ class _ActiveDetector :
737+ logic : DetectorLogic
738+ model_info : ModelInformation
739+
740+
741+ _DetectorState = Union [_Initializing , _Updating , _ActiveDetector ]
742+
743+
744+ class DetectorUnavailableError (Exception ):
745+ pass
746+
747+
748+ def unwrap_detector (state : '_DetectorState' ) -> _ActiveDetector :
749+ """Return the active detector or raise DetectorUnavailableError."""
750+ match state :
751+ case _ActiveDetector () as d :
752+ return d
753+ case _Updating () as u :
754+ raise DetectorUnavailableError (f'updating model to version { u .version } ' )
755+ case _Initializing ():
756+ raise DetectorUnavailableError ('detector not yet initialized' )
757+
758+
718759@contextlib .contextmanager
719760def step_into (new_dir ):
720761 previous_dir = os .getcwd ()
0 commit comments