diff --git a/interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py b/interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py index 1ed202042..0fcdf3912 100644 --- a/interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py +++ b/interactive_ai/libs/iai_core_py/iai_core/repos/model_repo.py @@ -24,11 +24,14 @@ from iai_core.repos.mappers.mongodb_mappers.id_mapper import IDToMongo from iai_core.repos.mappers.mongodb_mappers.model_mapper import ModelPurgeInfoToMongo, ModelToMongo from iai_core.repos.storage.binary_repos import ModelBinaryRepo +from iai_core.utils.feature_flags import FeatureFlagProvider from geti_types import ID, Session logger = logging.getLogger(__name__) +FEATURE_FLAG_FP16_INFERENCE = "FEATURE_FLAG_FP16_INFERENCE" + class ModelStatusFilter(Enum): """enum used to filter models by a list of status' in the model repo""" @@ -392,7 +395,7 @@ def __get_latest_model_for_inference_query( "previous_trained_revision_id": IDToMongo.forward(base_model_id), "optimization_type": ModelOptimizationType.MO.name, "has_xai_head": True, - "precision": [ModelPrecision.FP32.name], + "precision": {"$in": [ModelPrecision.FP16.name, ModelPrecision.FP32.name]}, "model_status": {"$in": model_status_filter.value}, } @@ -405,8 +408,8 @@ def get_latest_model_for_inference( - Get the MO FP32 with XAI head version of the latest base framework model. + Get the MO FP16 or FP32 with XAI head version of the latest base framework model. This model is used for inference.
- :base_model_id: Optional ID for which to get the latest inference model - :model_status_filter: Optional ModelStatusFilter to apply in query + :param base_model_id: Optional ID for which to get the latest inference model + :param model_status_filter: Optional ModelStatusFilter to apply in query :return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found """ # Get the ID of the latest base framework model @@ -420,15 +423,34 @@ def get_latest_model_for_inference( base_model_id=base_model_id, model_status_filter=model_status_filter ) - # Use ascending order sorting to retrieve the oldest matching document - return self.get_one(extra_filter=query, earliest=True) + models = self.get_all(extra_filter=query) + # Determine which precision to prioritize + use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE) + primary_precision = ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32 + fallback_precision = ModelPrecision.FP32 if use_fp16 else ModelPrecision.FP16 + + # Try to find model with primary precision + primary_model = next((model for model in models if primary_precision in model.precision), None) + if primary_model: + return primary_model + + # Try to find model with fallback precision + fallback_model = next((model for model in models if fallback_precision in model.precision), None) + if fallback_model: + logger.warning( + "%s model requested but not found. Falling back to %s.", primary_precision, fallback_precision + ) + return fallback_model + + logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision) + return NullModel() def get_latest_model_id_for_inference( self, model_status_filter: ModelStatusFilter = ModelStatusFilter.IMPROVED, ) -> ID: """ - Get the ID of the MO FP32 with XAI head version of the latest base framework model. + Get the ID of the MO FP16 or FP32 with XAI head version of the latest base framework model. This model is used for inference. 
:return: The MO model or :class:`~iai_core.entities.model.NullModel` if not found @@ -445,12 +467,33 @@ def get_latest_model_id_for_inference( base_model_id=base_model_id, model_status_filter=model_status_filter ), }, - {"$project": {"_id": 1}}, + {"$project": {"_id": 1, "precision": 1}}, ] matched_docs = list(self.aggregate_read(aggr_pipeline)) if not matched_docs: return ID() - return IDToMongo.backward(matched_docs[0]["_id"]) + + # Determine which precision to prioritize + use_fp16 = FeatureFlagProvider.is_enabled(FEATURE_FLAG_FP16_INFERENCE) + primary_precision = ModelPrecision.FP16.name if use_fp16 else ModelPrecision.FP32.name + fallback_precision = ModelPrecision.FP32.name if use_fp16 else ModelPrecision.FP16.name + + # Try to find model with primary precision + primary_model = next((doc for doc in matched_docs if primary_precision in doc["precision"]), None) + if primary_model: + return IDToMongo.backward(primary_model["_id"]) + + # Try to find model with fallback precision + fallback_model = next((doc for doc in matched_docs if fallback_precision in doc["precision"]), None) + if fallback_model: + logger.warning( + "%s model requested but not found. 
Falling back to %s.", primary_precision, fallback_precision + ) + return IDToMongo.backward(fallback_model["_id"]) + + # If we get here, we have matched_docs but none with the expected precisions + logger.warning("Neither %s nor %s models were found.", primary_precision, fallback_precision) + return ID() def update_model_status(self, model: Model, model_status: ModelStatus) -> None: """ diff --git a/interactive_ai/libs/iai_core_py/tests/repos/test_model_repo.py b/interactive_ai/libs/iai_core_py/tests/repos/test_model_repo.py index d91176875..7689c2f08 100644 --- a/interactive_ai/libs/iai_core_py/tests/repos/test_model_repo.py +++ b/interactive_ai/libs/iai_core_py/tests/repos/test_model_repo.py @@ -8,13 +8,33 @@ import pytest from iai_core.entities.datasets import NullDataset -from iai_core.entities.model import Model, ModelFormat, ModelOptimizationType, ModelStatus, NullModel +from iai_core.entities.model import Model, ModelFormat, ModelOptimizationType, ModelPrecision, ModelStatus, NullModel from iai_core.repos import ModelRepo, ProjectRepo -from iai_core.repos.model_repo import ModelStatusFilter +from iai_core.repos.model_repo import FEATURE_FLAG_FP16_INFERENCE, ModelStatusFilter from iai_core.repos.storage.binary_repos import ModelBinaryRepo from tests.test_helpers import empty_model_configuration +def create_model(project, storage, framework, model_format, opt_type, version, previous_model=None, precision=None): + """Helper to create model instances with consistent settings""" + return Model( + project=project, + model_storage=storage, + train_dataset=NullDataset(), + configuration=empty_model_configuration(), + id_=ModelRepo.generate_id(), + previous_trained_revision=previous_model, + data_source_dict={"test_data": b"weights_data"}, + training_framework=framework, + model_status=ModelStatus.SUCCESS, + model_format=model_format, + has_xai_head=True, + optimization_type=opt_type, + precision=[ModelPrecision.FP32] if not precision else [precision], + 
version=version, + ) + + class TestModelRepo: def test_indexes(self, fxt_model_storage) -> None: model_repo = ModelRepo(fxt_model_storage.identifier) @@ -397,120 +417,115 @@ def test_get_latest( ) assert latest_model_successful_optimized == fxt_model_optimized + @pytest.mark.parametrize( + "feature_flag_setting", [pytest.param(True, id="fp16-enabled"), pytest.param(False, id="fp16-disabled")] + ) def test_get_latest_model_for_inference( - self, request, fxt_empty_project, fxt_model_storage, fxt_training_framework + self, request, feature_flag_setting, fxt_empty_project, fxt_model_storage, fxt_training_framework, monkeypatch ) -> None: """ Models: - Models version 1: M1 (base) -> M2 (MO) - Models version 2: M3 (base) -> M4 (MO) -> M5 (MO) -> M6 (ONNX) + Models version 1: M1 (base) -> M2 (MO, FP32) -> M3 (MO, FP16) + Models version 2: M4 (base) -> M5 (MO, FP32) -> M6 (MO, FP16) -> M7 (ONNX) - Expected: The latest model for inference is M4 (the first one generated after the base model). + Expected: The latest model for inference is M6 (FP16) when FEATURE_FLAG_FP16_INFERENCE is enabled, otherwise M5 (FP32). """ + monkeypatch.setenv(FEATURE_FLAG_FP16_INFERENCE, str(feature_flag_setting).lower()) project = fxt_empty_project model_storage = fxt_model_storage model_repo = ModelRepo(model_storage.identifier) request.addfinalizer(lambda: model_repo.delete_all()) - m1 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.BASE_FRAMEWORK, - has_xai_head=True, - optimization_type=ModelOptimizationType.NONE, + + # Create model hierarchy with both FP16 and FP32 models + # Version 1 models + m1_base = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.BASE_FRAMEWORK, + ModelOptimizationType.NONE, version=1, + previous_model=None, ) - m2 = Model( - project=project, - model_storage=model_storage, -
train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - previous_trained_revision=m1, - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.OPENVINO, - has_xai_head=True, - optimization_type=ModelOptimizationType.MO, + m2_fp32 = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.OPENVINO, + ModelOptimizationType.MO, version=1, + previous_model=m1_base, + precision=ModelPrecision.FP32, ) - m3 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.BASE_FRAMEWORK, - has_xai_head=True, - optimization_type=ModelOptimizationType.NONE, + m3_fp16 = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.OPENVINO, + ModelOptimizationType.MO, + version=1, + previous_model=m1_base, + precision=ModelPrecision.FP16, + ) + + # Version 2 models + m4_base = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.BASE_FRAMEWORK, + ModelOptimizationType.NONE, version=2, + previous_model=None, ) - m4 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - previous_trained_revision=m3, - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.OPENVINO, - has_xai_head=True, - optimization_type=ModelOptimizationType.MO, + m5_fp32 = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.OPENVINO, + ModelOptimizationType.MO, version=2, + 
previous_model=m4_base, + precision=ModelPrecision.FP32, ) - m5 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - previous_trained_revision=m3, - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.OPENVINO, - has_xai_head=True, - optimization_type=ModelOptimizationType.MO, + m6_fp16 = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.OPENVINO, + ModelOptimizationType.MO, version=2, + previous_model=m4_base, + precision=ModelPrecision.FP16, ) - m6 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - previous_trained_revision=m3, - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.ONNX, - has_xai_head=True, - optimization_type=ModelOptimizationType.ONNX, + m7_onnx = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.ONNX, + ModelOptimizationType.ONNX, version=2, + previous_model=m4_base, + precision=ModelPrecision.FP16, ) - model_repo.save_many([m1, m2, m3, m4, m5, m6]) + + model_repo.save_many([m1_base, m2_fp32, m3_fp16, m4_base, m5_fp32, m6_fp16, m7_onnx]) with ( patch.object(ProjectRepo, "get_by_id", return_value=fxt_empty_project), ): inference_model = model_repo.get_latest_model_for_inference() inference_model_id = model_repo.get_latest_model_id_for_inference() - inference_model_for_m1 = model_repo.get_latest_model_for_inference(base_model_id=m1.id_) + inference_model_for_m1 = model_repo.get_latest_model_for_inference(base_model_id=m1_base.id_) - assert inference_model == m4 - assert inference_model_id == inference_model.id_ - assert inference_model_for_m1 == m2 
+ if feature_flag_setting: + assert inference_model == m6_fp16 + assert inference_model_id == inference_model.id_ + assert inference_model_for_m1 == m3_fp16 + else: + assert inference_model == m5_fp32 + assert inference_model_id == inference_model.id_ + assert inference_model_for_m1 == m2_fp32 def test_get_latest_with_latest_version( self, request, fxt_empty_project, fxt_model_storage, fxt_training_framework @@ -520,38 +535,27 @@ def test_get_latest_with_latest_version( model_storage = fxt_model_storage model_repo = ModelRepo(model_storage.identifier) request.addfinalizer(lambda: model_repo.delete_all()) - m1 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.BASE_FRAMEWORK, - has_xai_head=True, - optimization_type=ModelOptimizationType.NONE, + m1 = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.BASE_FRAMEWORK, + ModelOptimizationType.NONE, version=1, + previous_model=None, ) - m2 = Model( - project=project, - model_storage=model_storage, - train_dataset=NullDataset(), - configuration=empty_model_configuration(), - id_=ModelRepo.generate_id(), - data_source_dict={"test_data": b"weights_data"}, - training_framework=fxt_training_framework, - model_status=ModelStatus.SUCCESS, - model_format=ModelFormat.BASE_FRAMEWORK, - has_xai_head=True, - optimization_type=ModelOptimizationType.NONE, + m2 = create_model( + project, + model_storage, + fxt_training_framework, + ModelFormat.BASE_FRAMEWORK, + ModelOptimizationType.NONE, version=2, + previous_model=None, ) # m1.id > m2.id but m2.version > m1.version m1.id_ = ModelRepo.generate_id() - model_repo.save(m2) - model_repo.save(m1) + model_repo.save_many([m2, m1]) inference_model = model_repo.get_latest() assert 
inference_model == m2 diff --git a/interactive_ai/workflows/geti_domain/common/jobs_common/features/feature_flag_provider.py b/interactive_ai/workflows/geti_domain/common/jobs_common/features/feature_flag_provider.py index 2e617d05c..b5ff90c60 100644 --- a/interactive_ai/workflows/geti_domain/common/jobs_common/features/feature_flag_provider.py +++ b/interactive_ai/workflows/geti_domain/common/jobs_common/features/feature_flag_provider.py @@ -23,6 +23,7 @@ class FeatureFlag(Enum): FEATURE_FLAG_ANOMALY_REDUCTION = auto() FEATURE_FLAG_OTX_VERSION_SELECTION = auto() FEATURE_FLAG_KEYPOINT_DETECTION = auto() + FEATURE_FLAG_FP16_INFERENCE = auto() class FeatureFlagProvider: diff --git a/interactive_ai/workflows/geti_domain/common/jobs_common_extras/mlflow/utils/train_output_models.py b/interactive_ai/workflows/geti_domain/common/jobs_common_extras/mlflow/utils/train_output_models.py index 41145c768..dc5dceb70 100644 --- a/interactive_ai/workflows/geti_domain/common/jobs_common_extras/mlflow/utils/train_output_models.py +++ b/interactive_ai/workflows/geti_domain/common/jobs_common_extras/mlflow/utils/train_output_models.py @@ -28,7 +28,7 @@ class TrainOutputModelIds: """Class used to store different training output models' id only.""" base: str - mo_fp32_with_xai: str + mo_with_xai: str mo_fp32_without_xai: typing.Optional[str] # noqa: UP007 mo_fp16_without_xai: typing.Optional[str] # noqa: UP007 onnx: typing.Optional[str] # noqa: UP007 @@ -41,7 +41,7 @@ class TrainOutputModels: """ base: Model # The trained model in the base framework format - mo_fp32_with_xai: Model # The OpenVino model with eXplainable AI (saliency maps) + mo_with_xai: Model # The OpenVino model with eXplainable AI (saliency maps) mo_fp32_without_xai: typing.Optional[Model] = None # noqa: UP007 # The OpenVino model without eXplainable AI mo_fp16_without_xai: typing.Optional[Model] = None # noqa: UP007 # The OpenVino model with lower precision FP16 onnx: typing.Optional[Model] = None # noqa: UP007 # The 
ONNX model @@ -119,7 +119,7 @@ def _parse(model: Model | None) -> str | None: return TrainOutputModelIds( base=str(self.base.id_), - mo_fp32_with_xai=str(self.mo_fp32_with_xai.id_), + mo_with_xai=str(self.mo_with_xai.id_), mo_fp32_without_xai=_parse(self.mo_fp32_without_xai), mo_fp16_without_xai=_parse(self.mo_fp16_without_xai), onnx=_parse(self.onnx), @@ -146,7 +146,7 @@ def _parse(id_: str | None) -> Model | None: return cls( base=_parse(train_output_model_ids.base), - mo_fp32_with_xai=_parse(train_output_model_ids.mo_fp32_with_xai), + mo_with_xai=_parse(train_output_model_ids.mo_with_xai), mo_fp32_without_xai=_parse(train_output_model_ids.mo_fp32_without_xai), mo_fp16_without_xai=_parse(train_output_model_ids.mo_fp16_without_xai), onnx=_parse(train_output_model_ids.onnx), diff --git a/interactive_ai/workflows/geti_domain/train/job/tasks/evaluate_and_infer/evaluate_and_infer.py b/interactive_ai/workflows/geti_domain/train/job/tasks/evaluate_and_infer/evaluate_and_infer.py index 072a1e90d..443c03476 100644 --- a/interactive_ai/workflows/geti_domain/train/job/tasks/evaluate_and_infer/evaluate_and_infer.py +++ b/interactive_ai/workflows/geti_domain/train/job/tasks/evaluate_and_infer/evaluate_and_infer.py @@ -280,7 +280,7 @@ def evaluate_and_infer( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=report_evaluate_progress, ) @@ -302,7 +302,7 @@ def evaluate_and_infer( register_models( project_id=ID(train_data.train_data.project_id), model_id=ID(train_data.train_output_model_ids.base), - optimized_model_id=ID(train_data.train_output_model_ids.mo_fp32_with_xai), + optimized_model_id=ID(train_data.train_output_model_ids.mo_with_xai), model_storage_id=ID(train_data.train_data.model_storage_id), task_id=ID(train_data.train_data.task_id), ) @@ -315,7 +315,7 @@ 
def evaluate_and_infer( post_model_acceptance( train_data=train_data.train_data, base_model_id=train_data.train_output_model_ids.base, - inference_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + inference_model_id=train_data.train_output_model_ids.mo_with_xai, ) except Exception as err: logger.error("Error occurred during model activation", exc_info=err) @@ -330,7 +330,7 @@ def evaluate_and_infer( train_data=train_data.train_data, training_dataset_id=train_data.dataset_id, train_inference_subset_id=train_inference_subset_id, - model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=report_task_infer_progress, ) except Exception: diff --git a/interactive_ai/workflows/geti_domain/train/job/tasks/prepare_and_train/train_helpers.py b/interactive_ai/workflows/geti_domain/train/job/tasks/prepare_and_train/train_helpers.py index 675e8c1f4..4b6b55a41 100644 --- a/interactive_ai/workflows/geti_domain/train/job/tasks/prepare_and_train/train_helpers.py +++ b/interactive_ai/workflows/geti_domain/train/job/tasks/prepare_and_train/train_helpers.py @@ -25,6 +25,7 @@ from iai_core.entities.task_node import TaskNode from iai_core.repos import ModelRepo from jobs_common.exceptions import CommandInitializationFailedException, TrainingPodFailedException +from jobs_common.features.feature_flag_provider import FeatureFlag, FeatureFlagProvider from jobs_common.tasks.utils.secrets import JobMetadata from jobs_common_extras.mlflow.adapters.geti_otx_interface import GetiOTXInterfaceAdapter from jobs_common_extras.mlflow.utils.train_output_models import TrainOutputModelIds, TrainOutputModels @@ -203,13 +204,12 @@ def _get_export_parameters( def prepare_train(train_data: TrainWorkflowData, dataset: Dataset) -> TrainOutputModels: """Function should be called in prior to model training Flyte task. 
- It creates iai-core model entities and prepares MLFlow experiement buckets for + It creates iai-core model entities and prepares MLFlow experiment buckets for the subsequent model training Flyte task. :param train_data: Data class defining data used for training and providing helpers to get frequently used objects :param dataset: dataset to train on - :param hyper_parameters: ConfigurableParameters to use for training """ project, task_node = train_data.get_common_entities() @@ -238,12 +238,13 @@ def prepare_train(train_data: TrainWorkflowData, dataset: Dataset) -> TrainOutpu previous_revision=input_model, previous_trained_revision=input_model, ) + use_fp16 = FeatureFlagProvider.is_enabled(FeatureFlag.FEATURE_FLAG_FP16_INFERENCE) output_models = TrainOutputModels( base=output_base_model, - mo_fp32_with_xai=model_builder.create_model( + mo_with_xai=model_builder.create_model( model_format=ModelFormat.OPENVINO, has_xai_head=True, - precision=[ModelPrecision.FP32], + precision=[ModelPrecision.FP16 if use_fp16 else ModelPrecision.FP32], model_optimization_type=ModelOptimizationType.MO, previous_revision=output_base_model, previous_trained_revision=output_base_model, diff --git a/interactive_ai/workflows/geti_domain/train/tests/fixtures/train_workflow_data.py b/interactive_ai/workflows/geti_domain/train/tests/fixtures/train_workflow_data.py index e8b2e056d..685257d3e 100644 --- a/interactive_ai/workflows/geti_domain/train/tests/fixtures/train_workflow_data.py +++ b/interactive_ai/workflows/geti_domain/train/tests/fixtures/train_workflow_data.py @@ -29,13 +29,13 @@ def _create_model(id_): return model model_base = _create_model(fxt_train_output_model_ids.base) - model_mo_fp32_with_xai = _create_model(fxt_train_output_model_ids.mo_fp32_with_xai) + model_mo_with_xai = _create_model(fxt_train_output_model_ids.mo_with_xai) model_mo_fp32_without_xai = _create_model(fxt_train_output_model_ids.mo_fp32_without_xai) model_mo_fp16_without_xai = 
_create_model(fxt_train_output_model_ids.mo_fp16_without_xai) model_onnx = _create_model(fxt_train_output_model_ids.onnx) models = [ model_base, - model_mo_fp32_with_xai, + model_mo_with_xai, model_mo_fp32_without_xai, model_mo_fp16_without_xai, model_onnx, diff --git a/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/evaluate_and_infer/test_evaluate_and_infer.py b/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/evaluate_and_infer/test_evaluate_and_infer.py index 001edc0e9..9d03647b9 100644 --- a/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/evaluate_and_infer/test_evaluate_and_infer.py +++ b/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/evaluate_and_infer/test_evaluate_and_infer.py @@ -97,7 +97,7 @@ def test_evaluate_and_infer_model_not_accepted( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) @@ -185,7 +185,7 @@ def test_evaluate_and_infer_evaluate_failed( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) @@ -272,7 +272,7 @@ def test_evaluate_and_infer_post_model_acceptance_failed( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) @@ -282,7 +282,7 @@ def test_evaluate_and_infer_post_model_acceptance_failed( mocked_post_model_acceptance.assert_called_once_with( train_data=train_data.train_data, base_model_id=train_data.train_output_model_ids.base, - 
inference_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + inference_model_id=train_data.train_output_model_ids.mo_with_xai, ) mocked_report_progress.assert_has_calls( @@ -357,20 +357,20 @@ def test_evaluate_and_infer_task_inference_failed( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) mocked_post_model_acceptance.assert_called_once_with( train_data=train_data.train_data, base_model_id=train_data.train_output_model_ids.base, - inference_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + inference_model_id=train_data.train_output_model_ids.mo_with_xai, ) mocked_register_models.assert_called_once_with( project_id=ID(train_data.train_data.project_id), model_id=ID(train_data.train_output_model_ids.base), - optimized_model_id=ID(train_data.train_output_model_ids.mo_fp32_with_xai), + optimized_model_id=ID(train_data.train_output_model_ids.mo_with_xai), model_storage_id=ID(train_data.train_data.model_storage_id), task_id=ID(train_data.train_data.task_id), ) @@ -378,7 +378,7 @@ def test_evaluate_and_infer_task_inference_failed( train_data=train_data.train_data, training_dataset_id=train_data.dataset_id, train_inference_subset_id=TRAIN_INFERENCE_SUBSET_ID, - model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) mocked_pipeline_infer_on_unannotated.assert_not_called() @@ -453,19 +453,19 @@ def test_evaluate_and_infer_pipeline_inference_failed( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) 
mocked_post_model_acceptance.assert_called_once_with( train_data=train_data.train_data, base_model_id=train_data.train_output_model_ids.base, - inference_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + inference_model_id=train_data.train_output_model_ids.mo_with_xai, ) mocked_register_models.assert_called_once_with( project_id=ID(train_data.train_data.project_id), model_id=ID(train_data.train_output_model_ids.base), - optimized_model_id=ID(train_data.train_output_model_ids.mo_fp32_with_xai), + optimized_model_id=ID(train_data.train_output_model_ids.mo_with_xai), model_storage_id=ID(train_data.train_data.model_storage_id), task_id=ID(train_data.train_data.task_id), ) @@ -473,7 +473,7 @@ def test_evaluate_and_infer_pipeline_inference_failed( train_data=train_data.train_data, training_dataset_id=train_data.dataset_id, train_inference_subset_id=TRAIN_INFERENCE_SUBSET_ID, - model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) mocked_pipeline_infer_on_unannotated.assert_called_once_with( @@ -555,7 +555,7 @@ def test_evaluate_and_infer_model_accepted( train_data=train_data.train_data, dataset_id=train_data.dataset_id, base_model_id=train_data.train_output_model_ids.base, - mo_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + mo_model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) @@ -568,12 +568,12 @@ def test_evaluate_and_infer_model_accepted( mocked_post_model_acceptance.assert_called_once_with( train_data=train_data.train_data, base_model_id=train_data.train_output_model_ids.base, - inference_model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + inference_model_id=train_data.train_output_model_ids.mo_with_xai, ) mocked_register_models.assert_called_once_with( project_id=ID(train_data.train_data.project_id), model_id=ID(train_data.train_output_model_ids.base), - 
optimized_model_id=ID(train_data.train_output_model_ids.mo_fp32_with_xai), + optimized_model_id=ID(train_data.train_output_model_ids.mo_with_xai), model_storage_id=ID(train_data.train_data.model_storage_id), task_id=ID(train_data.train_data.task_id), ) @@ -581,7 +581,7 @@ def test_evaluate_and_infer_model_accepted( train_data=train_data.train_data, training_dataset_id=train_data.dataset_id, train_inference_subset_id=TRAIN_INFERENCE_SUBSET_ID, - model_id=train_data.train_output_model_ids.mo_fp32_with_xai, + model_id=train_data.train_output_model_ids.mo_with_xai, progress_callback=ANY, ) report_progress_calls.extend( diff --git a/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/prepare_and_train/test_train_helpers.py b/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/prepare_and_train/test_train_helpers.py index 4d8171454..aebffbb61 100644 --- a/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/prepare_and_train/test_train_helpers.py +++ b/interactive_ai/workflows/geti_domain/train/tests/unit/tasks/prepare_and_train/test_train_helpers.py @@ -6,7 +6,8 @@ import pytest from geti_types import ID -from iai_core.entities.model import ModelStatus +from iai_core.entities.model import ModelPrecision, ModelStatus +from jobs_common.features.feature_flag_provider import FeatureFlag from jobs_common_extras.mlflow.adapters.geti_otx_interface import GetiOTXInterfaceAdapter from job.tasks.prepare_and_train.train_helpers import finalize_train, prepare_train @@ -14,16 +15,22 @@ @pytest.mark.JobsComponent class TestTrainHelpers: + @pytest.mark.parametrize( + "feature_flag_setting", [pytest.param(True, id="fp16-enabled"), pytest.param(False, id="fp16-disabled")] + ) @patch("job.tasks.prepare_and_train.train_helpers.GetiOTXInterfaceAdapter") @patch("job.tasks.prepare_and_train.train_helpers.ModelRepo") def test_prepare_train( self, mock_model_repo, mock_geti_otx_interface_adapter, + feature_flag_setting, mock_train_data, fxt_dataset_with_images, + 
monkeypatch, ) -> None: # Arrange + monkeypatch.setenv(FeatureFlag.FEATURE_FLAG_FP16_INFERENCE.name, str(feature_flag_setting).lower()) mock_model_repo.generate_id.side_effect = [ID(str(i)) for i in range(5)] # Act @@ -35,7 +42,10 @@ def test_prepare_train( # Assert assert output_model_ids.base == "0" - assert output_model_ids.mo_fp32_with_xai == "1" + assert output_model_ids.mo_with_xai == "1" + assert ( + ModelPrecision.FP16 if feature_flag_setting else ModelPrecision.FP32 + ) in train_output_models.mo_with_xai.precision assert output_model_ids.mo_fp32_without_xai == "2" assert output_model_ids.mo_fp16_without_xai == "3" assert output_model_ids.onnx == "4" diff --git a/interactive_ai/workflows/geti_domain/train/tests/unit/workflows/test_train_workflow.py b/interactive_ai/workflows/geti_domain/train/tests/unit/workflows/test_train_workflow.py index c3caf7064..97387ea81 100644 --- a/interactive_ai/workflows/geti_domain/train/tests/unit/workflows/test_train_workflow.py +++ b/interactive_ai/workflows/geti_domain/train/tests/unit/workflows/test_train_workflow.py @@ -71,7 +71,7 @@ def test_train_workflow( job_id="job_id", train_output_model_ids=TrainOutputModelIds( base=BASE_MODEL_ID, - mo_fp32_with_xai=MO_MODEL_ID, + mo_with_xai=MO_MODEL_ID, mo_fp32_without_xai=None, mo_fp16_without_xai=None, onnx=None,