ITEP-32416 Add FP16 inference with feature flag #233

Merged: 5 commits, May 22, 2025
62 changes: 53 additions & 9 deletions interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py
@@ -24,11 +24,14 @@
from iai_core.repos.mappers.mongodb_mappers.id_mapper import IDToMongo
from iai_core.repos.mappers.mongodb_mappers.model_mapper import ModelPurgeInfoToMongo, ModelToMongo
from iai_core.repos.storage.binary_repos import ModelBinaryRepo
+from iai_core.utils.feature_flags import FeatureFlagProvider

from geti_types import ID, Session

logger = logging.getLogger(__name__)

+FEATURE_FLAG_FP16_INFERENCE = "FEATURE_FLAG_FP16_INFERENCE"
+

class ModelStatusFilter(Enum):
    """Enum used to filter models by a list of statuses in the model repo"""
@@ -392,7 +395,7 @@ def __get_latest_model_for_inference_query(
"previous_trained_revision_id": IDToMongo.forward(base_model_id),
"optimization_type": ModelOptimizationType.MO.name,
"has_xai_head": True,
"precision": [ModelPrecision.FP32.name],
"precision": {"$in": [ModelPrecision.FP16.name, ModelPrecision.FP32.name]},
"model_status": {"$in": model_status_filter.value},
}
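
Note the semantics of this change: the old filter `{"precision": [ModelPrecision.FP32.name]}` required the stored `precision` array to equal `["FP32"]` exactly, whereas `$in` matches any document whose `precision` array contains at least one of the listed values. Roughly, the filter now evaluates to the following (placeholder values marked):

```python
# Illustrative filter document; "<...>" values are placeholders.
query = {
    "previous_trained_revision_id": "<base_model_id as ObjectId>",
    "optimization_type": "MO",
    "has_xai_head": True,
    # Matches precision arrays such as ["FP16"], ["FP32"], or ["FP16", "FP32"].
    "precision": {"$in": ["FP16", "FP32"]},
    "model_status": {"$in": ["<names from the chosen ModelStatusFilter>"]},
}
```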

@@ -402,11 +405,11 @@ def get_latest_model_for_inference(
        model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED,
    ) -> Model:
        """
-        Get the MO FP32 with XAI head version of the latest base framework model.
+        Get the MO FP16 or FP32 with XAI head version of the latest base framework model.
        This model is used for inference.

-        :base_model_id: Optional ID for which to get the latest inference model
-        :model_status_filter: Optional ModelStatusFilter to apply in query
+        :param base_model_id: Optional ID for which to get the latest inference model
+        :param model_status_filter: Optional ModelStatusFilter to apply in query
        :return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
        """
        # Get the ID of the latest base framework model
@@ -420,15 +423,34 @@ def get_latest_model_for_inference(
            base_model_id=base_model_id, model_status_filter=model_status_filter
        )

-        # Use ascending order sorting to retrieve the oldest matching document
-        return self.get_one(extra_filter=query, earliest=True)
+        models = list(self.get_all(extra_filter=query, sort_info=[("_id", 1)]))
+        # Determine which precision to prioritize
+        use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE)
+        primary_precision = ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32
+        fallback_precision = ModelPrecision.FP32 if use_fp16 else ModelPrecision.FP16
+
+        # Try to find model with primary precision
+        primary_model = next((model for model in models if primary_precision in model.precision), None)
+        if primary_model:
+            return primary_model
+
+        # Try to find model with fallback precision
+        fallback_model = next((model for model in models if fallback_precision in model.precision), None)
+        if fallback_model:
+            logger.warning(
+                "%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision
+            )
+            return fallback_model
+
+        logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision)
+        return NullModel()
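
The lookup now fetches all matching models oldest-first and scans the list twice: once for the flag-preferred precision, once for the other, logging a warning when it has to fall back. A self-contained sketch of that two-tier scan (`FakeModel` and `pick_model` are stand-ins; the real `Model` entity carries far more state):

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class Precision(Enum):
    FP16 = "FP16"
    FP32 = "FP32"


@dataclass
class FakeModel:
    name: str
    precision: list[Precision]


def pick_model(models: list[FakeModel], use_fp16: bool) -> FakeModel | None:
    primary = Precision.FP16 if use_fp16 else Precision.FP32
    fallback = Precision.FP32 if use_fp16 else Precision.FP16
    # Models are assumed sorted oldest-first (ascending _id), so the first hit
    # in each precision tier is the oldest matching model, as in the repo code.
    for wanted in (primary, fallback):
        hit = next((m for m in models if wanted in m.precision), None)
        if hit is not None:
            return hit
    return None


models = [FakeModel("old_fp32", [Precision.FP32]), FakeModel("new_fp16", [Precision.FP16])]
assert pick_model(models, use_fp16=True).name == "new_fp16"   # flag on: FP16 preferred
assert pick_model(models, use_fp16=False).name == "old_fp32"  # flag off: FP32 preferred
assert pick_model([FakeModel("fp16_only", [Precision.FP16])], use_fp16=False).name == "fp16_only"  # fallback
```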

    def get_latest_model_id_for_inference(
        self,
        model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED,
    ) -> ID:
        """
-        Get the ID of the MO FP32 with XAI head version of the latest base framework model.
+        Get the ID of the MO FP16 or FP32 with XAI head version of the latest base framework model.
        This model is used for inference.

        :return: The ID of the MO model, or an empty ID if not found
@@ -445,12 +467,34 @@ def get_latest_model_id_for_inference(
                    base_model_id=base_model_id, model_status_filter=model_status_filter
                ),
            },
-            {"$project": {"_id": 1}},
+            {"$project": {"_id": 1, "precision": 1}},
+            {"$sort": {"_id": 1}},
        ]
        matched_docs = list(self.aggregate_read(aggr_pipeline))
        if not matched_docs:
            return ID()
-        return IDToMongo.backward(matched_docs[0]["_id"])
+
+        # Determine which precision to prioritize
+        use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE)
+        primary_precision = ModelPrecision.FP16.name if use_fp16 else ModelPrecision.FP32.name
+        fallback_precision = ModelPrecision.FP32.name if use_fp16 else ModelPrecision.FP16.name
+
+        # Try to find model with primary precision
+        primary_model = next((doc for doc in matched_docs if primary_precision in doc["precision"]), None)
+        if primary_model:
+            return IDToMongo.backward(primary_model["_id"])
+
+        # Try to find model with fallback precision
+        fallback_model = next((doc for doc in matched_docs if fallback_precision in doc["precision"]), None)
+        if fallback_model:
+            logger.warning(
+                "%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision
+            )
+            return IDToMongo.backward(fallback_model["_id"])
+
+        # If we get here, we have matched_docs but none with the expected precisions
+        logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision)
+        return ID()
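
`get_latest_model_id_for_inference` applies the same two-tier selection to the raw aggregation output, paying only for the projected `_id` and `precision` fields instead of a full model deserialization. A hypothetical call site (the `ModelRepo` class name and constructor argument are assumed, not shown in this diff):

```python
# Hypothetical usage; ModelRepo construction details are outside this PR.
repo = ModelRepo(model_storage_identifier)

# Full model object, honoring FEATURE_FLAG_FP16_INFERENCE with fallback:
model = repo.get_latest_model_for_inference()

# Same selection rules, but only the ID is materialized -- cheaper when the
# caller just needs a reference to the model:
model_id = repo.get_latest_model_id_for_inference()
```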

    def update_model_status(self, model: Model, model_status: ModelStatus) -> None:
        """