-
Notifications
You must be signed in to change notification settings - Fork 21
ITEP-32416 Add FP16 inference with feature flag #233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
itallix
wants to merge
3
commits into
main
Choose a base branch
from
vitalii/use-fp16-inference
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -24,11 +24,14 @@ | |||||||
from iai_core.repos.mappers.mongodb_mappers.id_mapper import IDToMongo | ||||||||
from iai_core.repos.mappers.mongodb_mappers.model_mapper import ModelPurgeInfoToMongo, ModelToMongo | ||||||||
from iai_core.repos.storage.binary_repos import ModelBinaryRepo | ||||||||
from iai_core.utils.feature_flags import FeatureFlagProvider | ||||||||
|
||||||||
from geti_types import ID, Session | ||||||||
|
||||||||
logger = logging.getLogger(__name__) | ||||||||
|
||||||||
FEATURE_FLAG_FP16_INFERENCE = "FEATURE_FLAG_FP16_INFERENCE" | ||||||||
|
||||||||
|
||||||||
class ModelStatusFilter(Enum): | ||||||||
"""enum used to filter models by a list of status' in the model repo""" | ||||||||
|
@@ -392,7 +395,7 @@ def __get_latest_model_for_inference_query( | |||||||
"previous_trained_revision_id": IDToMongo.forward(base_model_id), | ||||||||
"optimization_type": ModelOptimizationType.MO.name, | ||||||||
"has_xai_head": True, | ||||||||
"precision": [ModelPrecision.FP32.name], | ||||||||
"precision": {"$in": [ModelPrecision.FP16.name, ModelPrecision.FP32.name]}, | ||||||||
"model_status": {"$in": model_status_filter.value}, | ||||||||
} | ||||||||
|
||||||||
|
@@ -405,8 +408,8 @@ def get_latest_model_for_inference( | |||||||
Get the MO FP32 with XAI head version of the latest base framework model. | ||||||||
This model is used for inference. | ||||||||
|
||||||||
:base_model_id: Optional ID for which to get the latest inference model | ||||||||
:model_status_filter: Optional ModelStatusFilter to apply in query | ||||||||
:param base_model_id: Optional ID for which to get the latest inference model | ||||||||
:param model_status_filter: Optional ModelStatusFilter to apply in query | ||||||||
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found | ||||||||
""" | ||||||||
# Get the ID of the latest base framework model | ||||||||
|
@@ -420,15 +423,34 @@ def get_latest_model_for_inference( | |||||||
base_model_id=base_model_id, model_status_filter=model_status_filter | ||||||||
) | ||||||||
|
||||||||
# Use ascending order sorting to retrieve the oldest matching document | ||||||||
return self.get_one(extra_filter=query, earliest=True) | ||||||||
models = self.get_all(extra_filter=query) | ||||||||
# Determine which precision to prioritize | ||||||||
use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE) | ||||||||
primary_precision = ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32 | ||||||||
fallback_precision = ModelPrecision.FP32 if use_fp16 else ModelPrecision.FP16 | ||||||||
|
||||||||
# Try to find model with primary precision | ||||||||
primary_model = next((model for model in models if primary_precision in model.precision), None) | ||||||||
if primary_model: | ||||||||
return primary_model | ||||||||
|
||||||||
# Try to find model with fallback precision | ||||||||
fallback_model = next((model for model in models if fallback_precision in model.precision), None) | ||||||||
if fallback_model: | ||||||||
logger.warning( | ||||||||
"%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision | ||||||||
) | ||||||||
return fallback_model | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision) | ||||||||
return NullModel() | ||||||||
|
||||||||
def get_latest_model_id_for_inference( | ||||||||
self, | ||||||||
model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED, | ||||||||
) -> ID: | ||||||||
""" | ||||||||
Get the ID of the MO FP32 with XAI head version of the latest base framework model. | ||||||||
Get the ID of the MO FP16 or FP32 with XAI head version of the latest base framework model. | ||||||||
This model is used for inference. | ||||||||
|
||||||||
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found | ||||||||
|
@@ -445,12 +467,33 @@ def get_latest_model_id_for_inference( | |||||||
base_model_id=base_model_id, model_status_filter=model_status_filter | ||||||||
), | ||||||||
}, | ||||||||
{"$project": {"_id": 1}}, | ||||||||
{"$project": {"_id": 1, "precision": 1}}, | ||||||||
] | ||||||||
matched_docs = list(self.aggregate_read(aggr_pipeline)) | ||||||||
if not matched_docs: | ||||||||
return ID() | ||||||||
return IDToMongo.backward(matched_docs[0]["_id"]) | ||||||||
|
||||||||
# Determine which precision to prioritize | ||||||||
use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE) | ||||||||
primary_precision = ModelPrecision.FP16.name if use_fp16 else ModelPrecision.FP32.name | ||||||||
fallback_precision = ModelPrecision.FP32.name if use_fp16 else ModelPrecision.FP16.name | ||||||||
|
||||||||
# Try to find model with primary precision | ||||||||
primary_model = next((doc for doc in matched_docs if primary_precision in doc["precision"]), None) | ||||||||
if primary_model: | ||||||||
return IDToMongo.backward(primary_model["_id"]) | ||||||||
|
||||||||
# Try to find model with fallback precision | ||||||||
fallback_model = next((doc for doc in matched_docs if fallback_precision in doc["precision"]), None) | ||||||||
if fallback_model: | ||||||||
logger.warning( | ||||||||
"%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision | ||||||||
) | ||||||||
return IDToMongo.backward(fallback_model["_id"]) | ||||||||
|
||||||||
# If we get here, we have matched_docs but none with the expected precisions | ||||||||
logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision) | ||||||||
return ID() | ||||||||
|
||||||||
def update_model_status(self, model: Model, model_status: ModelStatus) -> None: | ||||||||
""" | ||||||||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_all returns all matching models in unspecified order, so
next(...)
may pick an older FP16 model instead of the latest; recommend sorting by version or `_id`
descending (or using a limited aggregate with sort) before selecting primary or fallback precision. Copilot uses AI. Check for mistakes.