11import asyncio
22import contextlib
3+ import logging
34import os
45import shutil
56import subprocess
67import sys
7- from dataclasses import asdict
8+ from dataclasses import asdict , dataclass
89from datetime import datetime
910from typing import Dict , List , Literal , Optional
1011
3233from ..helpers import background_tasks , environment_reader , run
3334from ..helpers .misc import numpy_image_from_dict
3435from ..node import Node
35- from .detector_logic import DetectorLogic
36+ from .detector_logic import DetectorLogic , DetectorLogicFactory
3637from .exceptions import NodeNeedsRestartError
3738from .inbox_filter .relevance_filter import RelevanceFilter
3839from .outbox import Outbox
4546from .rest import upload as rest_upload
4647
4748
49+ @dataclass
50+ class _ActiveDetector :
51+ logic : DetectorLogic
52+ model_info : ModelInformation
53+
54+
4855class DetectorNode (Node ):
4956
50- def __init__ (self , name : str , detector : DetectorLogic , uuid : Optional [str ] = None , use_backdoor_controls : bool = False ) -> None :
57+ def __init__ (self , name : str , detector_factory : DetectorLogicFactory ,
58+ uuid : Optional [str ] = None , use_backdoor_controls : bool = False ) -> None :
5159 super ().__init__ (name , uuid = uuid , node_type = 'detector' , needs_login = False , needs_sio = False )
52- self .detector_logic = detector
60+ self ._detector_factory = detector_factory
61+ self ._detector : Optional [_ActiveDetector ] = None
62+ self ._remaining_init_attempts : int = 2
5363 self .organization = environment_reader .organization ()
5464 self .project = environment_reader .project ()
5565 assert self .organization and self .project , 'Detector node needs an organization and an project'
@@ -95,13 +105,13 @@ def get_about_response(self) -> AboutResponse:
95105 return AboutResponse (
96106 operation_mode = self .operation_mode .value ,
97107 state = self .status .state ,
98- model_info = self .detector_logic .model_info , # pylint: disable=protected-access
108+ model_info = self ._detector .model_info if self . _detector else None ,
99109 target_model = self .target_model .version if self .target_model else None ,
100110 version_control = self .version_control .value
101111 )
102112
103113 def get_model_version_response (self ) -> ModelVersionResponse :
104- current_version = self .detector_logic .model_info .version if self .detector_logic . model_info is not None else 'None' # pylint: disable=protected-access
114+ current_version = self ._detector .model_info .version if self ._detector else 'None'
105115 target_version = self .target_model .version if self .target_model is not None else 'None'
106116 loop_version = self .loop_deployment_target .version if self .loop_deployment_target is not None else 'None'
107117
@@ -183,15 +193,10 @@ async def soft_reload(self) -> None:
183193 self .socket_connection_broken = True
184194 await self .on_startup ()
185195
186- # simulate startup
187- await self .detector_logic .soft_reload ()
188- await self ._load_model_info_and_init_model ()
189- self .operation_mode = OperationMode .Idle
190-
191196 async def on_startup (self ) -> None :
192197 try :
193198 self .outbox .ensure_continuous_upload ()
194- await self ._load_model_info_and_init_model ()
199+ await self ._build_and_swap_detector ()
195200 except Exception :
196201 self .log .exception ("error during 'startup'" )
197202 self .operation_mode = OperationMode .Idle
@@ -297,8 +302,8 @@ async def batch_detect(sid, data: Dict) -> Dict:
297302
298303 @self .sio .event
299304 async def info (sid ) -> Dict :
300- if self .detector_logic . model_info is not None :
301- return asdict (self .detector_logic .model_info )
305+ if self ._detector is not None :
306+ return asdict (self ._detector .model_info )
302307 return {"status" : "No model loaded" }
303308
304309 @self .sio .event
@@ -350,8 +355,8 @@ async def upload(sid, data: Dict) -> Dict:
350355 except Exception as e :
351356 self .log .exception ('could not parse detections' )
352357 return {'error' : str (e )}
353- if self .detector_logic . model_info is not None :
354- image_metadata = self .add_category_id_to_detections (self .detector_logic .model_info , image_metadata )
358+ if self ._detector is not None :
359+ image_metadata = self .add_category_id_to_detections (self ._detector .model_info , image_metadata )
355360 else :
356361 image_metadata = ImageMetadata ()
357362
@@ -393,11 +398,7 @@ async def on_repeat(self) -> None:
393398 async def _sync_status_with_loop (self ) -> None :
394399 """Sync status of the detector with the Learning Loop."""
395400
396- if self .detector_logic .model_info is not None :
397- current_model = self .detector_logic .model_info .version
398- else :
399- current_model = None
400-
401+ current_model = self ._detector .model_info .version if self ._detector else None
401402 target_model_version = self .target_model .version if self .target_model else None
402403
403404 status = DetectorStatus (
@@ -409,7 +410,7 @@ async def _sync_status_with_loop(self) -> None:
409410 operation_mode = self .operation_mode ,
410411 current_model = current_model ,
411412 target_model = target_model_version ,
412- model_format = self .detector_logic .model_format ,
413+ model_format = self ._detector_factory .model_format ,
413414 )
414415
415416 self .log_status_on_change (status .state , status )
@@ -441,8 +442,7 @@ async def _update_model_if_required(self) -> None:
441442 self .log .debug ('not running any updates; target model is None' )
442443 return
443444
444- current_version = self .detector_logic .model_info .version \
445- if self .detector_logic .model_info is not None else None
445+ current_version = self ._detector .model_info .version if self ._detector else None
446446
447447 if current_version != self .target_model .version :
448448 self .log .info ('Updating model from %s to %s' ,
@@ -499,7 +499,7 @@ async def _update_model(self, target_model: ModelInformation) -> None:
499499 await self .data_exchanger .download_model (target_model_folder ,
500500 Context (organization = self .organization ,
501501 project = self .project ),
502- target_model .id , self .detector_logic .model_format )
502+ target_model .id , self ._detector_factory .model_format )
503503 self .log .info ('Downloaded model %s' , target_model .version )
504504 except Exception :
505505 self .log .exception ('Could not download model %s' , target_model .version )
@@ -517,22 +517,40 @@ async def _update_model(self, target_model: ModelInformation) -> None:
517517 self .log .info ('Updated symlink for model to %s' , os .readlink (model_symlink ))
518518
519519 try :
520- await self ._load_model_info_and_init_model ()
520+ await self ._build_and_swap_detector ()
521521 except NodeNeedsRestartError :
522522 self .log .error ('Node needs restart' )
523523 sys .exit (0 )
524524 except Exception :
525- self .log .exception ('Could not load model , will retry download on next check' )
525+ self .log .exception ('Could not build detector , will retry download on next check' )
526526 shutil .rmtree (target_model_folder , ignore_errors = True )
527527 self .target_model = None
528528 return
529529
530530 await self ._sync_status_with_loop ()
531531 # self.reload(reason='new model installed')
532532
533- async def _load_model_info_and_init_model (self ) -> None :
534- async with self .detection_lock :
535- self .detector_logic .load_model_info_and_init_model ()
533+ async def _build_and_swap_detector (self ) -> None :
534+ """Load ModelInformation from disk, build a new DetectorLogic via the factory,
535+ then atomically swap self._detector when ready.
536+ The old detector continues to serve requests until the swap."""
537+ logging .info ('Loading model from %s' , GLOBALS .data_folder )
538+ model_info = ModelInformation .load_from_disk (f'{ GLOBALS .data_folder } /model' )
539+ if model_info is None :
540+ logging .info ('No model found at %s/model' , GLOBALS .data_folder )
541+ return
542+ try :
543+ new_detector = await self ._detector_factory .build (model_info )
544+ logging .info ('Successfully built detector for model %s' , model_info )
545+ self ._remaining_init_attempts = 2
546+ except Exception as e :
547+ self ._remaining_init_attempts -= 1
548+ logging .error ('Could not build detector for model %s. Retries left: %s. Error: %s' ,
549+ model_info , self ._remaining_init_attempts , e )
550+ if self ._remaining_init_attempts == 0 :
551+ raise NodeNeedsRestartError ('Could not build detector' ) from None
552+ raise
553+ self ._detector = _ActiveDetector (new_detector , model_info )
536554
537555# ================================== API Implementations ==================================
538556
@@ -560,17 +578,21 @@ async def get_detections(self,
560578 camera_id : Optional [str ] = None ,
561579 source : Optional [str ] = None ,
562580 autoupload : Literal ['filtered' , 'all' , 'disabled' ],
563- creation_date : Optional [str ] = None ) -> ImageMetadata :
581+ creation_date : Optional [str ] = None ) -> Optional [ ImageMetadata ] :
564582 """
565583 Main processing function for the detector node.
566584
567585 Used when an image is received via REST or SocketIO.
568- This function infers the detections from the image,
586+ This function infers the detections from the image,
569587 cares about uploading to the loop and returns the detections as ImageMetadata object.
588+ Returns None if no model is loaded.
570589 """
590+ detector = self ._detector
591+ if detector is None :
592+ return None
571593
572594 async with self .detection_lock :
573- metadata = await run .io_bound (self . detector_logic .evaluate , image )
595+ metadata = await run .io_bound (detector . logic .evaluate , image )
574596
575597 metadata .tags .extend (tags )
576598 metadata .source = source
@@ -598,16 +620,20 @@ async def get_batch_detections(self,
598620 camera_id : Optional [str ] = None ,
599621 source : Optional [str ] = None ,
600622 autoupload : str = 'filtered' ,
601- creation_date : Optional [str ] = None ) -> ImagesMetadata :
623+ creation_date : Optional [str ] = None ) -> Optional [ ImagesMetadata ] :
602624 """
603- Processing function for the detector node when a a batch inference is requested via SocketIO.
625+ Processing function for the detector node when a batch inference is requested via SocketIO.
604626
605- This function infers the detections from all images,
627+ This function infers the detections from all images,
606628 cares about uploading to the loop and returns the detections as a list of ImageMetadata.
629+ Returns None if no model is loaded.
607630 """
631+ detector = self ._detector
632+ if detector is None :
633+ return None
608634
609635 async with self .detection_lock :
610- all_detections = await run .io_bound (self . detector_logic .batch_evaluate , images )
636+ all_detections = await run .io_bound (detector . logic .batch_evaluate , images )
611637
612638 for metadata in all_detections .items :
613639 metadata .tags .extend (tags )
0 commit comments