24
24
from iai_core .repos .mappers .mongodb_mappers .id_mapper import IDToMongo
25
25
from iai_core .repos .mappers .mongodb_mappers .model_mapper import ModelPurgeInfoToMongo , ModelToMongo
26
26
from iai_core .repos .storage .binary_repos import ModelBinaryRepo
27
+ from iai_core .utils .feature_flags import FeatureFlagProvider
27
28
28
29
from geti_types import ID , Session
29
30
30
31
logger = logging .getLogger (__name__ )
31
32
33
+ FEATURE_FLAG_FP16_INFERENCE = "FEATURE_FLAG_FP16_INFERENCE"
34
+
32
35
33
36
class ModelStatusFilter (Enum ):
34
37
"""enum used to filter models by a list of status' in the model repo"""
@@ -392,7 +395,7 @@ def __get_latest_model_for_inference_query(
392
395
"previous_trained_revision_id" : IDToMongo .forward (base_model_id ),
393
396
"optimization_type" : ModelOptimizationType .MO .name ,
394
397
"has_xai_head" : True ,
395
- "precision" : [ModelPrecision .FP32 .name ],
398
+ "precision" : { "$in" : [ModelPrecision .FP16 . name , ModelPrecision . FP32 .name ]} ,
396
399
"model_status" : {"$in" : model_status_filter .value },
397
400
}
398
401
@@ -402,11 +405,11 @@ def get_latest_model_for_inference(
402
405
model_status_filter : ModelStatusFilter = ModelStatusFilter .IMPROVED ,
403
406
) -> Model :
404
407
"""
405
- Get the MO FP32 with XAI head version of the latest base framework model.
408
+ Get the MO FP16 or FP32 with XAI head version of the latest base framework model.
406
409
This model is used for inference.
407
410
408
- :base_model_id: Optional ID for which to get the latest inference model
409
- :model_status_filter: Optional ModelStatusFilter to apply in query
411
+ :param base_model_id: Optional ID for which to get the latest inference model
412
+ :param model_status_filter: Optional ModelStatusFilter to apply in query
410
413
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
411
414
"""
412
415
# Get the ID of the latest base framework model
@@ -420,15 +423,34 @@ def get_latest_model_for_inference(
420
423
base_model_id = base_model_id , model_status_filter = model_status_filter
421
424
)
422
425
423
- # Use ascending order sorting to retrieve the oldest matching document
424
- return self .get_one (extra_filter = query , earliest = True )
426
+ models = list (self .get_all (extra_filter = query , sort_info = [("_id" , 1 )]))
427
+ # Determine which precision to prioritize
428
+ use_fp16 = FeatureFlagProvider .is_enabled (FEATURE_FLAG_FP16_INFERENCE )
429
+ primary_precision = ModelPrecision .FP16 if use_fp16 else ModelPrecision .FP32
430
+ fallback_precision = ModelPrecision .FP32 if use_fp16 else ModelPrecision .FP16
431
+
432
+ # Try to find model with primary precision
433
+ primary_model = next ((model for model in models if primary_precision in model .precision ), None )
434
+ if primary_model :
435
+ return primary_model
436
+
437
+ # Try to find model with fallback precision
438
+ fallback_model = next ((model for model in models if fallback_precision in model .precision ), None )
439
+ if fallback_model :
440
+ logger .warning (
441
+ "%s model requested but not found. Falling back to %s." , primary_precision , fallback_precision
442
+ )
443
+ return fallback_model
444
+
445
+ logger .warning ("Neither %s nor %s models were found." , primary_precision , fallback_precision )
446
+ return NullModel ()
425
447
426
448
def get_latest_model_id_for_inference (
427
449
self ,
428
450
model_status_filter : ModelStatusFilter = ModelStatusFilter .IMPROVED ,
429
451
) -> ID :
430
452
"""
431
- Get the ID of the MO FP32 with XAI head version of the latest base framework model.
453
+ Get the ID of the MO FP16 or FP32 with XAI head version of the latest base framework model.
432
454
This model is used for inference.
433
455
434
456
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
@@ -445,12 +467,34 @@ def get_latest_model_id_for_inference(
445
467
base_model_id = base_model_id , model_status_filter = model_status_filter
446
468
),
447
469
},
448
- {"$project" : {"_id" : 1 }},
470
+ {"$project" : {"_id" : 1 , "precision" : 1 }},
471
+ {"$sort" : {"_id" : 1 }},
449
472
]
450
473
matched_docs = list (self .aggregate_read (aggr_pipeline ))
451
474
if not matched_docs :
452
475
return ID ()
453
- return IDToMongo .backward (matched_docs [0 ]["_id" ])
476
+
477
+ # Determine which precision to prioritize
478
+ use_fp16 = FeatureFlagProvider .is_enabled (FEATURE_FLAG_FP16_INFERENCE )
479
+ primary_precision = ModelPrecision .FP16 .name if use_fp16 else ModelPrecision .FP32 .name
480
+ fallback_precision = ModelPrecision .FP32 .name if use_fp16 else ModelPrecision .FP16 .name
481
+
482
+ # Try to find model with primary precision
483
+ primary_model = next ((doc for doc in matched_docs if primary_precision in doc ["precision" ]), None )
484
+ if primary_model :
485
+ return IDToMongo .backward (primary_model ["_id" ])
486
+
487
+ # Try to find model with fallback precision
488
+ fallback_model = next ((doc for doc in matched_docs if fallback_precision in doc ["precision" ]), None )
489
+ if fallback_model :
490
+ logger .warning (
491
+ "%s model requested but not found. Falling back to %s." , primary_precision , fallback_precision
492
+ )
493
+ return IDToMongo .backward (fallback_model ["_id" ])
494
+
495
+ # If we get here, we have matched_docs but none with the expected precisions
496
+ logger .warning ("Neither %s nor %s models were found." , primary_precision , fallback_precision )
497
+ return ID ()
454
498
455
499
def update_model_status (self , model : Model , model_status : ModelStatus ) -> None :
456
500
"""
0 commit comments