Skip to content

Commit 449ef58

Browse files
committed
Redesign DetectorLogic to allow parallel detections and model updates
DetectorLogic instances are now short(er)-lived and each represents only a single model version. If a new model version is published, a new instance is built in the background and replaces the active detector once it is ready.
1 parent 5aa49f9 commit 449ef58

File tree

11 files changed

+151
-127
lines changed

11 files changed

+151
-127
lines changed

learning_loop_node/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
22

33
# from . import log_conf
4-
from .detector.detector_logic import DetectorLogic
4+
from .detector.detector_logic import DetectorLogic, DetectorLogicFactory
55
from .detector.detector_node import DetectorNode
66
from .globals import GLOBALS
77
from .trainer.trainer_node import TrainerNode
88

9-
__all__ = ['TrainerNode', 'DetectorNode', 'DetectorLogic', 'GLOBALS']
9+
__all__ = ['TrainerNode', 'DetectorNode', 'DetectorLogic', 'DetectorLogicFactory', 'GLOBALS']
1010

1111
logging.info('>>>>>>>>>>>>>>>>>> LOOP INITIALIZED <<<<<<<<<<<<<<<<<<<<<<<')
Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,42 @@
1-
import logging
21
from abc import abstractmethod
3-
from typing import List, Optional
2+
from typing import List, Protocol
43

54
import numpy as np
65

76
from ..data_classes import ImageMetadata, ImagesMetadata, ModelInformation
8-
from ..globals import GLOBALS
9-
from .exceptions import NodeNeedsRestartError
107

118

12-
class DetectorLogic():
13-
14-
def __init__(self, model_format: str) -> None:
15-
self.model_format: str = model_format
16-
self.model_info: Optional[ModelInformation] = None
17-
18-
self._remaining_init_attempts: int = 2
9+
class DetectorLogicFactory(Protocol):
10+
"""Protocol for building DetectorLogic instances.
1911
20-
async def soft_reload(self):
21-
self.model_info = None
12+
The factory controls how the detector is constructed — implementations
13+
can build synchronously or offload heavy work to a thread pool.
14+
"""
15+
model_format: str
2216

23-
def load_model_info_and_init_model(self):
24-
"""
25-
Load model information from disk and initialize the model.
26-
27-
The detector node uses a lock to make sure that this is not called
28-
concurrently with evaluate() or batch_evaluate().
29-
"""
30-
logging.info('Loading model from %s', GLOBALS.data_folder)
31-
self.model_info = ModelInformation.load_from_disk(f'{GLOBALS.data_folder}/model')
32-
if self.model_info is None:
33-
logging.error('No model found')
34-
self.model_info = None
35-
return
17+
async def build(self, model_info: ModelInformation) -> 'DetectorLogic': ...
3618

37-
try:
38-
self.init()
39-
logging.info('Successfully loaded model %s', self.model_info)
40-
self._remaining_init_attempts = 2
41-
except Exception as e:
42-
self._remaining_init_attempts -= 1
43-
logging.error('Could not init model %s. Retries left: %s. Error: %s',
44-
self.model_info, self._remaining_init_attempts, e)
45-
self.model_info = None
46-
if self._remaining_init_attempts == 0:
47-
raise NodeNeedsRestartError('Could not init model') from None
48-
raise
4919

50-
@abstractmethod
51-
def init(self):
52-
"""
53-
Initialize the model.
20+
class DetectorLogic():
21+
"""Pure interface for detector implementations.
5422
55-
Called when a (new) model was loaded.
56-
Model information available via `self.model_info`
57-
The detector node uses a lock to make sure that this is not called
58-
concurrently with evaluate() or batch_evaluate().
59-
"""
23+
Subclasses receive the ModelInformation via their constructor (called by the
24+
DetectorLogic factory in DetectorNode) and are free to store or ignore it.
25+
"""
6026

6127
@abstractmethod
6228
def evaluate(self, image: np.ndarray) -> ImageMetadata:
63-
"""
64-
Evaluate the image and return the detections.
29+
"""Evaluate the image and return the detections.
6530
6631
Called by the detector node when an image should be evaluated (REST or SocketIO).
6732
The resulting detections should be stored in the ImageMetadata.
6833
Tags stored in the ImageMetadata will be uploaded to the learning loop.
69-
The function should return empty metadata if the detector is not initialized.
7034
"""
7135

7236
@abstractmethod
7337
def batch_evaluate(self, images: List[np.ndarray]) -> ImagesMetadata:
74-
"""
75-
Evaluate a batch of images and return the detections.
38+
"""Evaluate a batch of images and return the detections.
7639
7740
The resulting detections per image should be stored in the ImagesMetadata.
7841
Tags stored in the ImagesMetadata will be uploaded to the learning loop.
79-
The function should return empty metadata if the detector is not initialized.
8042
"""

learning_loop_node/detector/detector_node.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import asyncio
22
import contextlib
3+
import logging
34
import os
45
import shutil
56
import subprocess
67
import sys
7-
from dataclasses import asdict
8+
from dataclasses import asdict, dataclass
89
from datetime import datetime
910
from typing import Dict, List, Literal, Optional
1011

@@ -32,7 +33,7 @@
3233
from ..helpers import background_tasks, environment_reader, run
3334
from ..helpers.misc import numpy_image_from_dict
3435
from ..node import Node
35-
from .detector_logic import DetectorLogic
36+
from .detector_logic import DetectorLogic, DetectorLogicFactory
3637
from .exceptions import NodeNeedsRestartError
3738
from .inbox_filter.relevance_filter import RelevanceFilter
3839
from .outbox import Outbox
@@ -45,11 +46,20 @@
4546
from .rest import upload as rest_upload
4647

4748

49+
@dataclass
50+
class _ActiveDetector:
51+
logic: DetectorLogic
52+
model_info: ModelInformation
53+
54+
4855
class DetectorNode(Node):
4956

50-
def __init__(self, name: str, detector: DetectorLogic, uuid: Optional[str] = None, use_backdoor_controls: bool = False) -> None:
57+
def __init__(self, name: str, detector_factory: DetectorLogicFactory,
58+
uuid: Optional[str] = None, use_backdoor_controls: bool = False) -> None:
5159
super().__init__(name, uuid=uuid, node_type='detector', needs_login=False, needs_sio=False)
52-
self.detector_logic = detector
60+
self._detector_factory = detector_factory
61+
self._detector: Optional[_ActiveDetector] = None
62+
self._remaining_init_attempts: int = 2
5363
self.organization = environment_reader.organization()
5464
self.project = environment_reader.project()
5565
assert self.organization and self.project, 'Detector node needs an organization and an project'
@@ -95,13 +105,13 @@ def get_about_response(self) -> AboutResponse:
95105
return AboutResponse(
96106
operation_mode=self.operation_mode.value,
97107
state=self.status.state,
98-
model_info=self.detector_logic.model_info, # pylint: disable=protected-access
108+
model_info=self._detector.model_info if self._detector else None,
99109
target_model=self.target_model.version if self.target_model else None,
100110
version_control=self.version_control.value
101111
)
102112

103113
def get_model_version_response(self) -> ModelVersionResponse:
104-
current_version = self.detector_logic.model_info.version if self.detector_logic.model_info is not None else 'None' # pylint: disable=protected-access
114+
current_version = self._detector.model_info.version if self._detector else 'None'
105115
target_version = self.target_model.version if self.target_model is not None else 'None'
106116
loop_version = self.loop_deployment_target.version if self.loop_deployment_target is not None else 'None'
107117

@@ -183,15 +193,10 @@ async def soft_reload(self) -> None:
183193
self.socket_connection_broken = True
184194
await self.on_startup()
185195

186-
# simulate startup
187-
await self.detector_logic.soft_reload()
188-
await self._load_model_info_and_init_model()
189-
self.operation_mode = OperationMode.Idle
190-
191196
async def on_startup(self) -> None:
192197
try:
193198
self.outbox.ensure_continuous_upload()
194-
await self._load_model_info_and_init_model()
199+
await self._build_and_swap_detector()
195200
except Exception:
196201
self.log.exception("error during 'startup'")
197202
self.operation_mode = OperationMode.Idle
@@ -297,8 +302,8 @@ async def batch_detect(sid, data: Dict) -> Dict:
297302

298303
@self.sio.event
299304
async def info(sid) -> Dict:
300-
if self.detector_logic.model_info is not None:
301-
return asdict(self.detector_logic.model_info)
305+
if self._detector is not None:
306+
return asdict(self._detector.model_info)
302307
return {"status": "No model loaded"}
303308

304309
@self.sio.event
@@ -350,8 +355,8 @@ async def upload(sid, data: Dict) -> Dict:
350355
except Exception as e:
351356
self.log.exception('could not parse detections')
352357
return {'error': str(e)}
353-
if self.detector_logic.model_info is not None:
354-
image_metadata = self.add_category_id_to_detections(self.detector_logic.model_info, image_metadata)
358+
if self._detector is not None:
359+
image_metadata = self.add_category_id_to_detections(self._detector.model_info, image_metadata)
355360
else:
356361
image_metadata = ImageMetadata()
357362

@@ -393,11 +398,7 @@ async def on_repeat(self) -> None:
393398
async def _sync_status_with_loop(self) -> None:
394399
"""Sync status of the detector with the Learning Loop."""
395400

396-
if self.detector_logic.model_info is not None:
397-
current_model = self.detector_logic.model_info.version
398-
else:
399-
current_model = None
400-
401+
current_model = self._detector.model_info.version if self._detector else None
401402
target_model_version = self.target_model.version if self.target_model else None
402403

403404
status = DetectorStatus(
@@ -409,7 +410,7 @@ async def _sync_status_with_loop(self) -> None:
409410
operation_mode=self.operation_mode,
410411
current_model=current_model,
411412
target_model=target_model_version,
412-
model_format=self.detector_logic.model_format,
413+
model_format=self._detector_factory.model_format,
413414
)
414415

415416
self.log_status_on_change(status.state, status)
@@ -441,8 +442,7 @@ async def _update_model_if_required(self) -> None:
441442
self.log.debug('not running any updates; target model is None')
442443
return
443444

444-
current_version = self.detector_logic.model_info.version \
445-
if self.detector_logic.model_info is not None else None
445+
current_version = self._detector.model_info.version if self._detector else None
446446

447447
if current_version != self.target_model.version:
448448
self.log.info('Updating model from %s to %s',
@@ -499,7 +499,7 @@ async def _update_model(self, target_model: ModelInformation) -> None:
499499
await self.data_exchanger.download_model(target_model_folder,
500500
Context(organization=self.organization,
501501
project=self.project),
502-
target_model.id, self.detector_logic.model_format)
502+
target_model.id, self._detector_factory.model_format)
503503
self.log.info('Downloaded model %s', target_model.version)
504504
except Exception:
505505
self.log.exception('Could not download model %s', target_model.version)
@@ -517,22 +517,40 @@ async def _update_model(self, target_model: ModelInformation) -> None:
517517
self.log.info('Updated symlink for model to %s', os.readlink(model_symlink))
518518

519519
try:
520-
await self._load_model_info_and_init_model()
520+
await self._build_and_swap_detector()
521521
except NodeNeedsRestartError:
522522
self.log.error('Node needs restart')
523523
sys.exit(0)
524524
except Exception:
525-
self.log.exception('Could not load model, will retry download on next check')
525+
self.log.exception('Could not build detector, will retry download on next check')
526526
shutil.rmtree(target_model_folder, ignore_errors=True)
527527
self.target_model = None
528528
return
529529

530530
await self._sync_status_with_loop()
531531
# self.reload(reason='new model installed')
532532

533-
async def _load_model_info_and_init_model(self) -> None:
534-
async with self.detection_lock:
535-
self.detector_logic.load_model_info_and_init_model()
533+
async def _build_and_swap_detector(self) -> None:
534+
"""Load ModelInformation from disk, build a new DetectorLogic via the factory,
535+
then atomically swap self._detector when ready.
536+
The old detector continues to serve requests until the swap."""
537+
logging.info('Loading model from %s', GLOBALS.data_folder)
538+
model_info = ModelInformation.load_from_disk(f'{GLOBALS.data_folder}/model')
539+
if model_info is None:
540+
logging.info('No model found at %s/model', GLOBALS.data_folder)
541+
return
542+
try:
543+
new_detector = await self._detector_factory.build(model_info)
544+
logging.info('Successfully built detector for model %s', model_info)
545+
self._remaining_init_attempts = 2
546+
except Exception as e:
547+
self._remaining_init_attempts -= 1
548+
logging.error('Could not build detector for model %s. Retries left: %s. Error: %s',
549+
model_info, self._remaining_init_attempts, e)
550+
if self._remaining_init_attempts == 0:
551+
raise NodeNeedsRestartError('Could not build detector') from None
552+
raise
553+
self._detector = _ActiveDetector(new_detector, model_info)
536554

537555
# ================================== API Implementations ==================================
538556

@@ -560,17 +578,21 @@ async def get_detections(self,
560578
camera_id: Optional[str] = None,
561579
source: Optional[str] = None,
562580
autoupload: Literal['filtered', 'all', 'disabled'],
563-
creation_date: Optional[str] = None) -> ImageMetadata:
581+
creation_date: Optional[str] = None) -> Optional[ImageMetadata]:
564582
"""
565583
Main processing function for the detector node.
566584
567585
Used when an image is received via REST or SocketIO.
568-
This function infers the detections from the image,
586+
This function infers the detections from the image,
569587
cares about uploading to the loop and returns the detections as ImageMetadata object.
588+
Returns None if no model is loaded.
570589
"""
590+
detector = self._detector
591+
if detector is None:
592+
return None
571593

572594
async with self.detection_lock:
573-
metadata = await run.io_bound(self.detector_logic.evaluate, image)
595+
metadata = await run.io_bound(detector.logic.evaluate, image)
574596

575597
metadata.tags.extend(tags)
576598
metadata.source = source
@@ -598,16 +620,20 @@ async def get_batch_detections(self,
598620
camera_id: Optional[str] = None,
599621
source: Optional[str] = None,
600622
autoupload: str = 'filtered',
601-
creation_date: Optional[str] = None) -> ImagesMetadata:
623+
creation_date: Optional[str] = None) -> Optional[ImagesMetadata]:
602624
"""
603-
Processing function for the detector node when a a batch inference is requested via SocketIO.
625+
Processing function for the detector node when a batch inference is requested via SocketIO.
604626
605-
This function infers the detections from all images,
627+
This function infers the detections from all images,
606628
cares about uploading to the loop and returns the detections as a list of ImageMetadata.
629+
Returns None if no model is loaded.
607630
"""
631+
detector = self._detector
632+
if detector is None:
633+
return None
608634

609635
async with self.detection_lock:
610-
all_detections = await run.io_bound(self.detector_logic.batch_evaluate, images)
636+
all_detections = await run.io_bound(detector.logic.batch_evaluate, images)
611637

612638
for metadata in all_detections.items:
613639
metadata.tags.extend(tags)

0 commit comments

Comments
 (0)