Skip to content

Commit 0f19e6b

Browse files
committed
Make parallel model building optional
1 parent c70d82d commit 0f19e6b

File tree

2 files changed

+82
-40
lines changed

2 files changed

+82
-40
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ You can configure connection to our Learning Loop by specifying the following en
2727
| LOOP_PROJECT | PROJECT | Project ID | Detector (opt.) |
2828
| MIN_UNCERTAIN_THRESHOLD | - | smallest confidence (float) at which auto-upload will happen | Detector (opt.) |
2929
| MAX_UNCERTAIN_THRESHOLD | - | largest confidence (float) at which auto-upload will happen | Detector (opt.) |
30+
| EXCLUSIVE_MODEL_BUILD | - | Reject detections during model update to save VRAM (set to 1) | Detector (opt.) |
3031
| INFERENCE_BATCH_SIZE | - | Batch size of trainer when calculating detections | Trainer (opt.) |
3132
| RESTART_AFTER_TRAINING | - | Restart the trainer after training (set to 1) | Trainer (opt.) |
3233
| KEEP_OLD_TRAININGS | - | Do not delete old trainings (set to 1) | Trainer (opt.) |

learning_loop_node/detector/detector_node.py

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
from dataclasses import asdict, dataclass
99
from datetime import datetime
10-
from typing import Dict, List, Literal, Optional
10+
from typing import Dict, List, Literal, Optional, Union
1111

1212
import numpy as np
1313

@@ -46,19 +46,14 @@
4646
from .rest import upload as rest_upload
4747

4848

49-
@dataclass
50-
class _ActiveDetector:
51-
logic: DetectorLogic
52-
model_info: ModelInformation
53-
54-
5549
class DetectorNode(Node):
5650

5751
def __init__(self, name: str, detector_factory: DetectorLogicFactory,
5852
uuid: Optional[str] = None, use_backdoor_controls: bool = False) -> None:
5953
super().__init__(name, uuid=uuid, node_type='detector', needs_login=False, needs_sio=False)
6054
self._detector_factory = detector_factory
61-
self._detector: Optional[_ActiveDetector] = None
55+
self._detector: _DetectorState = _Initializing()
56+
self._exclusive_model_build: bool = os.environ.get('EXCLUSIVE_MODEL_BUILD', '0').lower() in ('1', 'true')
6257
self._remaining_init_attempts: int = 2
6358
self.organization = environment_reader.organization()
6459
self.project = environment_reader.project()
@@ -105,13 +100,13 @@ def get_about_response(self) -> AboutResponse:
105100
return AboutResponse(
106101
operation_mode=self.operation_mode.value,
107102
state=self.status.state,
108-
model_info=self._detector.model_info if self._detector else None,
103+
model_info=self._detector.model_info if isinstance(self._detector, _ActiveDetector) else None,
109104
target_model=self.target_model.version if self.target_model else None,
110105
version_control=self.version_control.value
111106
)
112107

113108
def get_model_version_response(self) -> ModelVersionResponse:
114-
current_version = self._detector.model_info.version if self._detector else 'None'
109+
current_version = self._detector.model_info.version if isinstance(self._detector, _ActiveDetector) else 'None'
115110
target_version = self.target_model.version if self.target_model is not None else 'None'
116111
loop_version = self.loop_deployment_target.version if self.loop_deployment_target is not None else 'None'
117112

@@ -258,14 +253,11 @@ async def detect(sid, data: Dict) -> Dict:
258253
autoupload=data.get('autoupload', 'filtered'),
259254
creation_date=data.get('creation_date', None)
260255
)
261-
if det is None:
262-
return {'error': 'no model loaded'}
263-
detection_dict = jsonable_encoder(asdict(det))
264-
return detection_dict
256+
return jsonable_encoder(asdict(det))
257+
except DetectorUnavailableError as e:
258+
return {'error': str(e)}
265259
except Exception as e:
266260
self.log.exception('could not detect via socketio')
267-
# with open('/tmp/bad_img_from_socket_io.jpg', 'wb') as f:
268-
# f.write(data['image'])
269261
return {'error': str(e)}
270262

271263
@self.sio.event
@@ -292,21 +284,22 @@ async def batch_detect(sid, data: Dict) -> Dict:
292284
autoupload=data.get('autoupload', 'filtered'),
293285
creation_date=data.get('creation_date', None)
294286
)
295-
if det is None:
296-
return {'error': 'no model loaded'}
297-
detection_dict = jsonable_encoder(asdict(det))
298-
return detection_dict
287+
return jsonable_encoder(asdict(det))
288+
except DetectorUnavailableError as e:
289+
return {'error': str(e)}
299290
except Exception as e:
300291
self.log.exception('could not detect via socketio')
301-
# with open('/tmp/bad_img_from_socket_io.jpg', 'wb') as f:
302-
# f.write(data['image'])
303292
return {'error': str(e)}
304293

305294
@self.sio.event
306295
async def info(sid) -> Dict:
307-
if self._detector is not None:
308-
return asdict(self._detector.model_info)
309-
return {"status": "No model loaded"}
296+
match self._detector:
297+
case _ActiveDetector() as d:
298+
return asdict(d.model_info)
299+
case _Updating() as u:
300+
return {"status": f"Updating model to version {u.version}"}
301+
case _Initializing():
302+
return {"status": "No model loaded"}
310303

311304
@self.sio.event
312305
async def about(sid) -> Dict:
@@ -357,8 +350,11 @@ async def upload(sid, data: Dict) -> Dict:
357350
except Exception as e:
358351
self.log.exception('could not parse detections')
359352
return {'error': str(e)}
360-
if self._detector is not None:
361-
image_metadata = self.add_category_id_to_detections(self._detector.model_info, image_metadata)
353+
try:
354+
detector = unwrap_detector(self._detector)
355+
image_metadata = self.add_category_id_to_detections(detector.model_info, image_metadata)
356+
except DetectorUnavailableError as e:
357+
self.log.warning('Cannot add category IDs: %s', e)
362358
else:
363359
image_metadata = ImageMetadata()
364360

@@ -400,7 +396,7 @@ async def on_repeat(self) -> None:
400396
async def _sync_status_with_loop(self) -> None:
401397
"""Sync status of the detector with the Learning Loop."""
402398

403-
current_model = self._detector.model_info.version if self._detector else None
399+
current_model = self._detector.model_info.version if isinstance(self._detector, _ActiveDetector) else None
404400
target_model_version = self.target_model.version if self.target_model else None
405401

406402
status = DetectorStatus(
@@ -444,7 +440,14 @@ async def _update_model_if_required(self) -> None:
444440
self.log.debug('not running any updates; target model is None')
445441
return
446442

447-
current_version = self._detector.model_info.version if self._detector else None
443+
match self._detector:
444+
case _ActiveDetector() as d:
445+
current_version = d.model_info.version
446+
case _Updating():
447+
self.log.debug('not checking for updates; model update already in progress')
448+
return
449+
case _Initializing():
450+
current_version = None
448451

449452
if current_version != self.target_model.version:
450453
self.log.info('Updating model from %s to %s',
@@ -526,12 +529,20 @@ async def _update_model(self, target_model: ModelInformation) -> None:
526529
async def _build_and_swap_detector(self, model_dir: str) -> None:
527530
"""Load ModelInformation from model_dir, build a new DetectorLogic via the factory,
528531
then atomically swap self._detector when ready.
529-
The old detector continues to serve requests until the swap."""
532+
By default, the old detector continues to serve requests until the swap.
533+
534+
If EXCLUSIVE_MODEL_BUILD is set and a detector is active, the old detector is torn down
535+
first (freeing e.g. GPU VRAM) and detections are rejected until the new one is ready."""
530536
logging.info('Loading model from %s', model_dir)
531537
model_info = ModelInformation.load_from_disk(model_dir)
532538
if model_info is None:
533539
logging.warning('No model.json found in %s', model_dir)
534540
return
541+
542+
if self._exclusive_model_build and isinstance(self._detector, _ActiveDetector):
543+
async with self.detection_lock: # wait for in-flight detections to finish
544+
self._detector = _Updating(version=model_info.version)
545+
535546
try:
536547
new_detector = await self._detector_factory.build(model_info)
537548
logging.info('Successfully built detector for model %s', model_info)
@@ -540,6 +551,8 @@ async def _build_and_swap_detector(self, model_dir: str) -> None:
540551
self._remaining_init_attempts -= 1
541552
logging.error('Could not build detector for model %s. Retries left: %s. Error: %s',
542553
model_info, self._remaining_init_attempts, e)
554+
if not isinstance(self._detector, _ActiveDetector):
555+
self._detector = _Initializing() # no model to fall back to
543556
if self._remaining_init_attempts == 0:
544557
raise NodeNeedsRestartError('Could not build detector') from None
545558
raise
@@ -587,7 +600,7 @@ async def get_detections(self,
587600
camera_id: Optional[str] = None,
588601
source: Optional[str] = None,
589602
autoupload: Literal['filtered', 'all', 'disabled'],
590-
creation_date: Optional[str] = None) -> Optional[ImageMetadata]:
603+
creation_date: Optional[str] = None) -> ImageMetadata:
591604
"""
592605
Main processing function for the detector node.
593606
@@ -596,11 +609,8 @@ async def get_detections(self,
596609
cares about uploading to the loop and returns the detections as ImageMetadata object.
597610
Raises DetectorUnavailableError if no model is loaded or a model update is in progress.
598611
"""
599-
detector = self._detector
600-
if detector is None:
601-
return None
602-
603612
async with self.detection_lock:
613+
detector = unwrap_detector(self._detector)
604614
metadata = await run.io_bound(detector.logic.evaluate, image)
605615

606616
metadata.tags.extend(tags)
@@ -629,19 +639,16 @@ async def get_batch_detections(self,
629639
camera_id: Optional[str] = None,
630640
source: Optional[str] = None,
631641
autoupload: str = 'filtered',
632-
creation_date: Optional[str] = None) -> Optional[ImagesMetadata]:
642+
creation_date: Optional[str] = None) -> ImagesMetadata:
633643
"""
634644
Processing function for the detector node when a batch inference is requested via SocketIO.
635645
636646
This function infers the detections from all images,
637647
cares about uploading to the loop and returns the detections as a list of ImageMetadata.
638648
Raises DetectorUnavailableError if no model is loaded or a model update is in progress.
639649
"""
640-
detector = self._detector
641-
if detector is None:
642-
return None
643-
644650
async with self.detection_lock:
651+
detector = unwrap_detector(self._detector)
645652
all_detections = await run.io_bound(detector.logic.batch_evaluate, images)
646653

647654
for metadata in all_detections.items:
@@ -715,6 +722,40 @@ def register_sio_events(self, sio_client: AsyncClient):
715722
pass
716723

717724

725+
class _Initializing:
726+
"""No model ready yet — first build pending or in progress."""
727+
728+
729+
@dataclass
730+
class _Updating:
731+
"""Exclusive model replacement in progress — detections rejected."""
732+
version: str
733+
734+
735+
@dataclass
736+
class _ActiveDetector:
737+
logic: DetectorLogic
738+
model_info: ModelInformation
739+
740+
741+
_DetectorState = Union[_Initializing, _Updating, _ActiveDetector]
742+
743+
744+
class DetectorUnavailableError(Exception):
745+
pass
746+
747+
748+
def unwrap_detector(state: '_DetectorState') -> _ActiveDetector:
749+
"""Return the active detector or raise DetectorUnavailableError."""
750+
match state:
751+
case _ActiveDetector() as d:
752+
return d
753+
case _Updating() as u:
754+
raise DetectorUnavailableError(f'updating model to version {u.version}')
755+
case _Initializing():
756+
raise DetectorUnavailableError('detector not yet initialized')
757+
758+
718759
@contextlib.contextmanager
719760
def step_into(new_dir):
720761
previous_dir = os.getcwd()

0 commit comments

Comments
 (0)