ITEP-32416 Add FP16 inference with feature flag #233


Status: Draft (wants to merge 3 commits into base: main)
59 changes: 51 additions & 8 deletions interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py
@@ -24,11 +24,14 @@
from iai_core.repos.mappers.mongodb_mappers.id_mapper import IDToMongo
from iai_core.repos.mappers.mongodb_mappers.model_mapper import ModelPurgeInfoToMongo, ModelToMongo
from iai_core.repos.storage.binary_repos import ModelBinaryRepo
from iai_core.utils.feature_flags import FeatureFlagProvider

from geti_types import ID, Session

logger = logging.getLogger(__name__)

FEATURE_FLAG_FP16_INFERENCE = "FEATURE_FLAG_FP16_INFERENCE"


class ModelStatusFilter(Enum):
"""enum used to filter models by a list of status' in the model repo"""
@@ -392,7 +395,7 @@ def __get_latest_model_for_inference_query(
"previous_trained_revision_id": IDToMongo.forward(base_model_id),
"optimization_type": ModelOptimizationType.MO.name,
"has_xai_head": True,
"precision": [ModelPrecision.FP32.name],
"precision": {"$in": [ModelPrecision.FP16.name, ModelPrecision.FP32.name]},
"model_status": {"$in": model_status_filter.value},
}

@@ -405,8 +408,8 @@ def get_latest_model_for_inference(
Get the MO FP16 or FP32 with XAI head version of the latest base framework model.
This model is used for inference.

:base_model_id: Optional ID for which to get the latest inference model
:model_status_filter: Optional ModelStatusFilter to apply in query
:param base_model_id: Optional ID for which to get the latest inference model
:param model_status_filter: Optional ModelStatusFilter to apply in query
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
"""
# Get the ID of the latest base framework model
@@ -420,15 +423,34 @@ def get_latest_model_for_inference(
base_model_id=base_model_id, model_status_filter=model_status_filter
)

# Use ascending order sorting to retrieve the oldest matching document
return self.get_one(extra_filter=query, earliest=True)
models = self.get_all(extra_filter=query)
Copilot AI commented on May 16, 2025:

get_all returns all matching models in unspecified order, so next(...) may pick an older FP16 model instead of the latest; recommend sorting by version or _id descending (or using a limited aggregate with sort) before selecting primary or fallback precision.

Suggested change
models = self.get_all(extra_filter=query)
# Ensure models are sorted by version (or _id) in descending order
models = self.get_all(extra_filter=query).sort([("version", DESCENDING), ("_id", DESCENDING)])
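
Note: the suggested .sort(...) call assumes a MongoDB-cursor-style API and a pymongo DESCENDING constant that is not imported in this module. If get_all returns a plain Python list, a sorted(...) call is the safer form; a minimal sketch, assuming each Model exposes version and id_ attributes (neither is shown in this diff):

models = sorted(
    self.get_all(extra_filter=query),
    # Newest version first; tie-break on id so the order is deterministic
    key=lambda model: (model.version, str(model.id_)),
    reverse=True,
)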


# Determine which precision to prioritize
use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE)
primary_precision = ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32
fallback_precision = ModelPrecision.FP32 if use_fp16 else ModelPrecision.FP16

# Try to find model with primary precision
primary_model = next((model for model in models if primary_precision in model.precision), None)
if primary_model:
return primary_model

# Try to find model with fallback precision
fallback_model = next((model for model in models if fallback_precision in model.precision), None)
if fallback_model:
logger.warning(
"%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision
)
return fallback_model
A contributor commented:

Suggested change
return fallback_model
return fallback_model
logger.warning(f"Fallback {fallback_precision} model also not found.")


logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision)
return NullModel()

def get_latest_model_id_for_inference(
self,
model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED,
) -> ID:
"""
Get the ID of the MO FP32 with XAI head version of the latest base framework model.
Get the ID of the MO FP16 or FP32 with XAI head version of the latest base framework model.
This model is used for inference.

:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found
@@ -445,12 +467,33 @@ def get_latest_model_id_for_inference(
base_model_id=base_model_id, model_status_filter=model_status_filter
),
},
{"$project": {"_id": 1}},
{"$project": {"_id": 1, "precision": 1}},
]
matched_docs = list(self.aggregate_read(aggr_pipeline))
if not matched_docs:
return ID()
return IDToMongo.backward(matched_docs[0]["_id"])

# Determine which precision to prioritize
use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE)
primary_precision = ModelPrecision.FP16.name if use_fp16 else ModelPrecision.FP32.name
fallback_precision = ModelPrecision.FP32.name if use_fp16 else ModelPrecision.FP16.name

# Try to find model with primary precision
primary_model = next((doc for doc in matched_docs if primary_precision in doc["precision"]), None)
if primary_model:
return IDToMongo.backward(primary_model["_id"])

# Try to find model with fallback precision
fallback_model = next((doc for doc in matched_docs if fallback_precision in doc["precision"]), None)
if fallback_model:
logger.warning(
"%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision
)
return IDToMongo.backward(fallback_model["_id"])

# If we get here, we have matched_docs but none with the expected precisions
logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision)
return ID()

def update_model_status(self, model: Model, model_status: ModelStatus) -> None:
"""
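
FeatureFlagProvider.is_enabled is imported from iai_core.utils.feature_flags, but its implementation is not part of this diff. For context, a minimal environment-variable-backed sketch that would be consistent with the tests toggling the flag through monkeypatch.setenv (everything here beyond the is_enabled name is an assumption):

import os


class FeatureFlagProvider:
    """Sketch: boolean feature flags read from environment variables."""

    @staticmethod
    def is_enabled(feature_flag: str) -> bool:
        # The tests write "true"/"false" strings via monkeypatch.setenv
        return os.environ.get(feature_flag, "false").strip().lower() == "true"

Under this reading, is_enabled(FEATURE_FLAG_FP16_INFERENCE) returns True exactly when the environment variable holds the string "true", which is what the parametrized test below arranges.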
228 changes: 116 additions & 112 deletions interactive_ai/libs/iai_core_py/tests/repos/test_model_repo.py
@@ -8,13 +8,33 @@
import pytest

from iai_core.entities.datasets import NullDataset
from iai_core.entities.model import Model, ModelFormat, ModelOptimizationType, ModelStatus, NullModel
from iai_core.entities.model import Model, ModelFormat, ModelOptimizationType, ModelPrecision, ModelStatus, NullModel
from iai_core.repos import ModelRepo, ProjectRepo
from iai_core.repos.model_repo import ModelStatusFilter
from iai_core.repos.model_repo import FEATURE_FLAG_FP16_INFERENCE, ModelStatusFilter
from iai_core.repos.storage.binary_repos import ModelBinaryRepo
from tests.test_helpers import empty_model_configuration


def create_model(project, storage, framework, model_format, opt_type, version, previous_model=None, precision=None):
"""Helper to create model instances with consistent settings"""
return Model(
project=project,
model_storage=storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
previous_trained_revision=previous_model,
data_source_dict={"test_data": b"weights_data"},
training_framework=framework,
model_status=ModelStatus.SUCCESS,
model_format=model_format,
has_xai_head=True,
optimization_type=opt_type,
precision=[ModelPrecision.FP32] if not precision else [precision],
version=version,
)


class TestModelRepo:
def test_indexes(self, fxt_model_storage) -> None:
model_repo = ModelRepo(fxt_model_storage.identifier)
@@ -397,120 +417,115 @@ def test_get_latest(
)
assert latest_model_successful_optimized == fxt_model_optimized

@pytest.mark.parametrize(
"feature_flag_setting", [pytest.param(True, id="fp16-enabled"), pytest.param(False, id="fp16-disabled")]
)
def test_get_latest_model_for_inference(
self, request, fxt_empty_project, fxt_model_storage, fxt_training_framework
self, request, feature_flag_setting, fxt_empty_project, fxt_model_storage, fxt_training_framework, monkeypatch
) -> None:
"""
Models:
Models version 1: M1 (base) -> M2 (MO)
Models version 2: M3 (base) -> M4 (MO) -> M5 (MO) -> M6 (ONNX)
Models version 1: M1 (base) -> M2 (MO, FP32) -> M3 (MO, FP16)
Models version 2: M4 (base) -> M5 (MO, FP32) -> M6 (MO, FP16) -> M7 (ONNX)

Expected:
With the FP16 feature flag enabled, the latest model for inference is M6 (FP16 MO); with it disabled, it is M5 (FP32 MO).
"""
monkeypatch.setenv(FEATURE_FLAG_FP16_INFERENCE, str(feature_flag_setting).lower())
project = fxt_empty_project
model_storage = fxt_model_storage
model_repo = ModelRepo(model_storage.identifier)
request.addfinalizer(lambda: model_repo.delete_all())
m1 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.BASE_FRAMEWORK,
has_xai_head=True,
optimization_type=ModelOptimizationType.NONE,

# Create model hierarchy with both FP16 and FP32 models
# Version 1 models
m1_base = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.BASE_FRAMEWORK,
ModelOptimizationType.NONE,
version=1,
previous_model=None,
)
m2 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
previous_trained_revision=m1,
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.OPENVINO,
has_xai_head=True,
optimization_type=ModelOptimizationType.MO,
m2_fp32 = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.OPENVINO,
ModelOptimizationType.MO,
version=1,
previous_model=m1_base,
precision=ModelPrecision.FP32,
)
m3 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.BASE_FRAMEWORK,
has_xai_head=True,
optimization_type=ModelOptimizationType.NONE,
m3_fp16 = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.OPENVINO,
ModelOptimizationType.MO,
version=1,
previous_model=m1_base,
precision=ModelPrecision.FP16,
)

# Version 2 models
m4_base = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.BASE_FRAMEWORK,
ModelOptimizationType.NONE,
version=2,
previous_model=None,
)
m4 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
previous_trained_revision=m3,
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.OPENVINO,
has_xai_head=True,
optimization_type=ModelOptimizationType.MO,
m5_fp32 = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.OPENVINO,
ModelOptimizationType.MO,
version=2,
previous_model=m4_base,
precision=ModelPrecision.FP32,
)
m5 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
previous_trained_revision=m3,
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.OPENVINO,
has_xai_head=True,
optimization_type=ModelOptimizationType.MO,
m6_fp16 = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.OPENVINO,
ModelOptimizationType.MO,
version=2,
previous_model=m4_base,
precision=ModelPrecision.FP16,
)
m6 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
previous_trained_revision=m3,
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.ONNX,
has_xai_head=True,
optimization_type=ModelOptimizationType.ONNX,
m7_onnx = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.ONNX,
ModelOptimizationType.ONNX,
version=2,
previous_model=m4_base,
precision=ModelPrecision.FP16,
)
model_repo.save_many([m1, m2, m3, m4, m5, m6])

model_repo.save_many([m1_base, m2_fp32, m3_fp16, m4_base, m5_fp32, m6_fp16, m7_onnx])
with (
patch.object(ProjectRepo, "get_by_id", return_value=fxt_empty_project),
):
inference_model = model_repo.get_latest_model_for_inference()
inference_model_id = model_repo.get_latest_model_id_for_inference()
inference_model_for_m1 = model_repo.get_latest_model_for_inference(base_model_id=m1.id_)
inference_model_for_m1 = model_repo.get_latest_model_for_inference(base_model_id=m1_base.id_)

assert inference_model == m4
assert inference_model_id == inference_model.id_
assert inference_model_for_m1 == m2
if feature_flag_setting:
assert inference_model == m6_fp16
assert inference_model_id == inference_model.id_
assert inference_model_for_m1 == m3_fp16
else:
assert inference_model == m5_fp32
assert inference_model_id == inference_model.id_
assert inference_model_for_m1 == m2_fp32

def test_get_latest_with_latest_version(
self, request, fxt_empty_project, fxt_model_storage, fxt_training_framework
@@ -520,38 +535,27 @@ def test_get_latest_with_latest_version(
model_storage = fxt_model_storage
model_repo = ModelRepo(model_storage.identifier)
request.addfinalizer(lambda: model_repo.delete_all())
m1 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.BASE_FRAMEWORK,
has_xai_head=True,
optimization_type=ModelOptimizationType.NONE,
m1 = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.BASE_FRAMEWORK,
ModelOptimizationType.NONE,
version=1,
previous_model=None,
)
m2 = Model(
project=project,
model_storage=model_storage,
train_dataset=NullDataset(),
configuration=empty_model_configuration(),
id_=ModelRepo.generate_id(),
data_source_dict={"test_data": b"weights_data"},
training_framework=fxt_training_framework,
model_status=ModelStatus.SUCCESS,
model_format=ModelFormat.BASE_FRAMEWORK,
has_xai_head=True,
optimization_type=ModelOptimizationType.NONE,
m2 = create_model(
project,
model_storage,
fxt_training_framework,
ModelFormat.BASE_FRAMEWORK,
ModelOptimizationType.NONE,
version=2,
previous_model=None,
)
# m1.id > m2.id but m2.version > m1.version
m1.id_ = ModelRepo.generate_id()
model_repo.save(m2)
model_repo.save(m1)
model_repo.save_many([m2, m1])
inference_model = model_repo.get_latest()
assert inference_model == m2

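
For reference, both repository methods implement the same precision-precedence rule. A standalone sketch of that rule with illustrative names (not part of the PR), assuming models is already ordered newest-first as the review comment recommends:

from iai_core.entities.model import Model, ModelPrecision


def select_inference_model(models: list[Model], use_fp16: bool) -> Model | None:
    """Pick the newest model in the preferred precision, else fall back to the other."""
    primary = ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32
    fallback = ModelPrecision.FP32 if use_fp16 else ModelPrecision.FP16
    for wanted in (primary, fallback):
        # model.precision is a list, so membership also covers multi-precision models
        match = next((m for m in models if wanted in m.precision), None)
        if match is not None:
            return match
    return None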