Skip to content

Commit c104709

Browse files
authored
ITEP-32416 Add FP16 inference with feature flag (#233)
1 parent 10bbcd9 commit c104709

File tree

10 files changed

+256
-155
lines changed

10 files changed

+256
-155
lines changed

interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
from iai_core.repos.mappers.mongodb_mappers.id_mapper import IDToMongo
2525
from iai_core.repos.mappers.mongodb_mappers.model_mapper import ModelPurgeInfoToMongo, ModelToMongo
2626
from iai_core.repos.storage.binary_repos import ModelBinaryRepo
27+
from iai_core.utils.feature_flags import FeatureFlagProvider
2728

2829
from geti_types import ID, Session
2930

3031
logger = logging.getLogger(__name__)
3132

33+
FEATURE_FLAG_FP16_INFERENCE = "FEATURE_FLAG_FP16_INFERENCE"
34+
3235

3336
class ModelStatusFilter(Enum):
3437
"""enum used to filter models by a list of status' in the model repo"""
@@ -392,7 +395,7 @@ def __get_latest_model_for_inference_query(
392395
"previous_trained_revision_id": IDToMongo.forward(base_model_id),
393396
"optimization_type": ModelOptimizationType.MO.name,
394397
"has_xai_head": True,
395-
"precision": [ModelPrecision.FP32.name],
398+
"precision": {"$in": [ModelPrecision.FP16.name, ModelPrecision.FP32.name]},
396399
"model_status": {"$in": model_status_filter.value},
397400
}
398401

@@ -402,11 +405,11 @@ def get_latest_model_for_inference(
402405
model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED,
403406
) -> Model:
404407
"""
405-
Get the MO FP32 with XAI head version of the latest base framework model.
408+
Get the MO FP16 or FP32 with XAI head version of the latest base framework model.
406409
This model is used for inference.
407410
408-
:base_model_id: Optional ID for which to get the latest inference model
409-
:model_status_filter: Optional ModelStatusFilter to apply in query
411+
:param base_model_id: Optional ID for which to get the latest inference model
412+
:param model_status_filter: Optional ModelStatusFilter to apply in query
410413
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
411414
"""
412415
# Get the ID of the latest base framework model
@@ -420,15 +423,34 @@ def get_latest_model_for_inference(
420423
base_model_id=base_model_id, model_status_filter=model_status_filter
421424
)
422425

423-
# Use ascending order sorting to retrieve the oldest matching document
424-
return self.get_one(extra_filter=query, earliest=True)
426+
models = list(self.get_all(extra_filter=query, sort_info=[("_id", 1)]))
427+
# Determine which precision to prioritize
428+
use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE)
429+
primary_precision = ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32
430+
fallback_precision = ModelPrecision.FP32 if use_fp16 else ModelPrecision.FP16
431+
432+
# Try to find model with primary precision
433+
primary_model = next((model for model in models if primary_precision in model.precision), None)
434+
if primary_model:
435+
return primary_model
436+
437+
# Try to find model with fallback precision
438+
fallback_model = next((model for model in models if fallback_precision in model.precision), None)
439+
if fallback_model:
440+
logger.warning(
441+
"%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision
442+
)
443+
return fallback_model
444+
445+
logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision)
446+
return NullModel()
425447

426448
def get_latest_model_id_for_inference(
427449
self,
428450
model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED,
429451
) -> ID:
430452
"""
431-
Get the ID of the MO FP32 with XAI head version of the latest base framework model.
453+
Get the ID of the MO FP16 or FP32 with XAI head version of the latest base framework model.
432454
This model is used for inference.
433455
434456
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
@@ -445,12 +467,34 @@ def get_latest_model_id_for_inference(
445467
base_model_id=base_model_id, model_status_filter=model_status_filter
446468
),
447469
},
448-
{"$project": {"_id": 1}},
470+
{"$project": {"_id": 1, "precision": 1}},
471+
{"$sort": {"_id": 1}},
449472
]
450473
matched_docs = list(self.aggregate_read(aggr_pipeline))
451474
if not matched_docs:
452475
return ID()
453-
return IDToMongo.backward(matched_docs[0]["_id"])
476+
477+
# Determine which precision to prioritize
478+
use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE)
479+
primary_precision = ModelPrecision.FP16.name if use_fp16 else ModelPrecision.FP32.name
480+
fallback_precision = ModelPrecision.FP32.name if use_fp16 else ModelPrecision.FP16.name
481+
482+
# Try to find model with primary precision
483+
primary_model = next((doc for doc in matched_docs if primary_precision in doc["precision"]), None)
484+
if primary_model:
485+
return IDToMongo.backward(primary_model["_id"])
486+
487+
# Try to find model with fallback precision
488+
fallback_model = next((doc for doc in matched_docs if fallback_precision in doc["precision"]), None)
489+
if fallback_model:
490+
logger.warning(
491+
"%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision
492+
)
493+
return IDToMongo.backward(fallback_model["_id"])
494+
495+
# If we get here, we have matched_docs but none with the expected precisions
496+
logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision)
497+
return ID()
454498

455499
def update_model_status(self, model: Model, model_status: ModelStatus) -> None:
456500
"""

0 commit comments

Comments
 (0)