Add advanced filterQuery model-catalog tests (RHOAIENG-39615): an indirect
`models_from_filter_query` fixture, AND/OR artifact-criteria validators, and
quieter `execute_get_call` param logging.

diff --git a/tests/model_registry/model_catalog/conftest.py b/tests/model_registry/model_catalog/conftest.py
index 86e373c3e..7ff602ef4 100644
--- a/tests/model_registry/model_catalog/conftest.py
+++ b/tests/model_registry/model_catalog/conftest.py
@@ -17,6 +17,7 @@
     CATALOG_CONTAINER,
     REDHAT_AI_CATALOG_ID,
 )
+from tests.model_registry.model_catalog.utils import get_models_from_catalog_api
 from tests.model_registry.constants import CUSTOM_CATALOG_ID1
 from tests.model_registry.utils import (
     get_rest_headers,
@@ -198,3 +199,29 @@ def catalog_openapi_schema() -> dict[Any, Any]:
     response = requests.get(OPENAPI_SCHEMA_URL, timeout=10)
     response.raise_for_status()
     return yaml.safe_load(response.text)
+
+
+@pytest.fixture
+def models_from_filter_query(
+    request,
+    model_catalog_rest_url: list[str],
+    model_registry_rest_headers: dict[str, str],
+) -> list[str]:
+    """
+    Fixture that runs get_models_from_catalog_api with the given filter_query,
+    asserts that models are returned, and returns list of model names.
+    """
+    filter_query = request.param
+
+    models = get_models_from_catalog_api(
+        model_catalog_rest_url=model_catalog_rest_url,
+        model_registry_rest_headers=model_registry_rest_headers,
+        additional_params=f"&filterQuery={filter_query}",
+    )["items"]
+
+    assert models, f"No models returned from filter query: {filter_query}"
+
+    model_names = [model["name"] for model in models]
+    LOGGER.info(f"Filter query '{filter_query}' returned {len(model_names)} models: {', '.join(model_names)}")
+
+    return model_names
diff --git a/tests/model_registry/model_catalog/test_model_search.py b/tests/model_registry/model_catalog/test_model_search.py
index 3d034685e..85a7dc62c 100644
--- a/tests/model_registry/model_catalog/test_model_search.py
+++ b/tests/model_registry/model_catalog/test_model_search.py
@@ -18,15 +18,15 @@
     validate_search_results_against_database,
     validate_filter_query_results_against_database,
     validate_performance_data_files_on_pod,
+    validate_model_artifacts_match_criteria_and,
+    validate_model_artifacts_match_criteria_or,
 )
 from tests.model_registry.utils import get_model_catalog_pod
 from kubernetes.dynamic import DynamicClient
 from kubernetes.dynamic.exceptions import ResourceNotFoundError
 
 LOGGER = get_logger(name=__name__)
-pytestmark = [
-    pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace", "test_idp_user")
-]
+pytestmark = [pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace")]
 
 
 class TestSearchModelCatalog:
@@ -579,3 +579,101 @@ def test_presence_performance_data_on_pod(
 
         # Assert that all models have all required performance data files
         assert not validation_results, f"Models with missing performance data files: {validation_results}"
+
+    @pytest.mark.parametrize(
+        "models_from_filter_query, expected_value, logic_type",
+        [
+            pytest.param(
+                "artifacts.requests_per_second > 15.0",
+                [{"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 15.0}],
+                "and",
+                id="performance_min_filter",
+            ),
+            pytest.param(
+                "artifacts.hardware_count = 8",
+                [{"key_name": "hardware_count", "key_type": "int_value", "comparison": "exact", "value": 8}],
+                "and",
+                id="hardware_exact_filter",
+            ),
+            pytest.param(
+                "(artifacts.hardware_type LIKE 'H200') AND (artifacts.ttft_p95 < 50)",
+                [
+                    {"key_name": "hardware_type", "key_type": "string_value", "comparison": "exact", "value": "H200"},
+                    {"key_name": "ttft_p95", "key_type": "double_value", "comparison": "max", "value": 50},
+                ],
+                "and",
+                id="test_combined_hardware_performance_filter_mixed_types",
+            ),
+            pytest.param(
+                "(artifacts.ttft_mean < 100) AND (artifacts.requests_per_second > 10)",
+                [
+                    {"key_name": "ttft_mean", "key_type": "double_value", "comparison": "max", "value": 100},
+                    {"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 10},
+                ],
+                "and",
+                id="test_combined_hardware_performance_filter_numeric_types",
+            ),
+            pytest.param(
+                "(artifacts.tps_mean < 247) OR (artifacts.hardware_type LIKE 'A100-80')",
+                [
+                    {"key_name": "tps_mean", "key_type": "double_value", "comparison": "max", "value": 247},
+                    {
+                        "key_name": "hardware_type",
+                        "key_type": "string_value",
+                        "comparison": "exact",
+                        "value": "A100-80",
+                    },
+                ],
+                "or",
+                id="performance_or_hardware_filter",
+            ),
+        ],
+        indirect=["models_from_filter_query"],
+    )
+    def test_filter_query_advanced_model_search(
+        self: Self,
+        models_from_filter_query: list[str],
+        expected_value: list[dict[str, Any]],
+        logic_type: str,
+        model_catalog_rest_url: list[str],
+        model_registry_rest_headers: dict[str, str],
+    ):
+        """
+        RHOAIENG-39615: Advanced filter query test for performance-based filtering with AND/OR logic
+        """
+        errors = []
+
+        # Additional validation: ensure returned models match the filter criteria
+        for model_name in models_from_filter_query:
+            url = f"{model_catalog_rest_url[0]}sources/{VALIDATED_CATALOG_ID}/models/{model_name}/artifacts?pageSize"
+            LOGGER.info(f"Validating model: {model_name} with {len(expected_value)} {logic_type.upper()} validation(s)")
+
+            # Fetch all artifacts with dynamic page size adjustment
+            all_model_artifacts = fetch_all_artifacts_with_dynamic_paging(
+                url_with_pagesize=url,
+                headers=model_registry_rest_headers,
+                page_size=200,
+            )["items"]
+
+            validation_result = None
+            # Select validation function based on logic type
+            if logic_type == "and":
+                validation_result = validate_model_artifacts_match_criteria_and(
+                    all_model_artifacts=all_model_artifacts, expected_validations=expected_value, model_name=model_name
+                )
+            elif logic_type == "or":
+                validation_result = validate_model_artifacts_match_criteria_or(
+                    all_model_artifacts=all_model_artifacts, expected_validations=expected_value, model_name=model_name
+                )
+            else:
+                raise ValueError(f"Invalid logic_type: {logic_type}. Must be 'and' or 'or'")
+
+            if validation_result:
+                LOGGER.info(f"For Model: {model_name}, {logic_type.upper()} validation completed successfully")
+            else:
+                errors.append(model_name)
+
+        assert not errors, f"{logic_type.upper()} filter validations failed for {', '.join(errors)}"
+        LOGGER.info(
+            f"Advanced {logic_type.upper()} filter validation completed for {len(models_from_filter_query)} models"
+        )
diff --git a/tests/model_registry/model_catalog/utils.py b/tests/model_registry/model_catalog/utils.py
index 1a007534f..3e593eb9f 100644
--- a/tests/model_registry/model_catalog/utils.py
+++ b/tests/model_registry/model_catalog/utils.py
@@ -1005,3 +1005,119 @@ def validate_items_sorted_correctly(items: list[dict], field: str, order: str) -
         return all(values[i] >= values[i + 1] for i in range(len(values) - 1))
     else:
         raise ValueError(f"Invalid sort order: {order}")
+
+
+def _validate_single_criterion(
+    artifact_name: str, custom_properties: dict[str, Any], validation: dict[str, Any]
+) -> tuple[bool, str]:
+    """
+    Helper function to validate a single criterion against an artifact.
+
+    Args:
+        artifact_name: Name of the artifact being validated
+        custom_properties: Custom properties dictionary from the artifact
+        validation: Single validation criterion containing key_name, key_type, comparison, value
+
+    Returns:
+        tuple: (condition_met: bool, message: str)
+    """
+    key_name = validation["key_name"]
+    key_type = validation["key_type"]
+    comparison_type = validation["comparison"]
+    expected_val = validation["value"]
+
+    raw_value = custom_properties.get(key_name, {}).get(key_type, None)
+
+    if raw_value is None:
+        return False, f"{key_name}: missing"
+
+    # Convert value to appropriate type
+    try:
+        if key_type == "int_value":
+            artifact_value = int(raw_value)
+        elif key_type == "double_value":
+            artifact_value = float(raw_value)
+        elif key_type == "string_value":
+            artifact_value = str(raw_value)
+        else:
+            LOGGER.warning(f"Unknown key_type: {key_type}")
+            return False, f"{key_name}: unknown type {key_type}"
+    except (ValueError, TypeError):
+        return False, f"{key_name}: conversion error"
+
+    # Perform comparison based on type
+    condition_met = False
+    if comparison_type == "exact":
+        condition_met = artifact_value == expected_val
+    elif comparison_type == "min":
+        condition_met = artifact_value >= expected_val
+    elif comparison_type == "max":
+        condition_met = artifact_value <= expected_val
+    elif comparison_type == "contains" and key_type == "string_value":
+        condition_met = expected_val in artifact_value
+
+    message = f"Artifact {artifact_name} {key_name}: {artifact_value} {comparison_type} {expected_val}"
+    return condition_met, message
+
+
+def _get_artifact_validation_results(
+    artifact: dict[str, Any], expected_validations: list[dict[str, Any]]
+) -> tuple[list[bool], list[str]]:
+    """
+    Checks one artifact against all validations and returns the boolean outcomes and messages.
+    """
+    artifact_name = artifact.get("name", "missing_artifact_name")
+    custom_properties = artifact["customProperties"]
+
+    # Store the boolean results and informative messages
+    bool_results = []
+    messages = []
+
+    for validation in expected_validations:
+        condition_met, message = _validate_single_criterion(
+            artifact_name=artifact_name, custom_properties=custom_properties, validation=validation
+        )
+        bool_results.append(condition_met)
+        messages.append(message)
+
+    return bool_results, messages
+
+
+def validate_model_artifacts_match_criteria_and(
+    all_model_artifacts: list[dict[str, Any]], expected_validations: list[dict[str, Any]], model_name: str
+) -> bool:
+    """
+    Validates that at least one artifact in the model satisfies ALL expected validation criteria.
+    """
+    for artifact in all_model_artifacts:
+        bool_results, messages = _get_artifact_validation_results(
+            artifact=artifact, expected_validations=expected_validations
+        )
+        # If ALL results are True
+        if all(bool_results):
+            validation_results = [f"{message}: passed" for message in messages]
+            LOGGER.info(
+                f"Model {model_name} passed all {len(bool_results)} validations with artifact: {validation_results}"
+            )
+            return True
+
+    return False
+
+
+def validate_model_artifacts_match_criteria_or(
+    all_model_artifacts: list[dict[str, Any]], expected_validations: list[dict[str, Any]], model_name: str
+) -> bool:
+    """
+    Validates that at least one artifact in the model satisfies AT LEAST ONE of the expected validation criteria.
+    """
+    for artifact in all_model_artifacts:
+        bool_results, messages = _get_artifact_validation_results(
+            artifact=artifact, expected_validations=expected_validations
+        )
+        if any(bool_results):
+            # Find the first passing message for logging
+            LOGGER.info(f"Model {model_name} passed OR validation with artifact: {messages[bool_results.index(True)]}")
+            return True
+
+    LOGGER.error(f"Model {model_name} failed all OR validations")
+    return False
diff --git a/tests/model_registry/utils.py b/tests/model_registry/utils.py
index e15ecd6f5..004110642 100644
--- a/tests/model_registry/utils.py
+++ b/tests/model_registry/utils.py
@@ -698,7 +698,8 @@ def execute_get_call(
     url: str, headers: dict[str, str], verify: bool | str = False, params: dict[str, Any] | None = None
 ) -> requests.Response:
     LOGGER.info(f"Executing get call: {url}")
-    LOGGER.info(f"params: {params}")
+    if params:
+        LOGGER.info(f"params: {params}")
     resp = requests.get(url=url, headers=headers, verify=verify, timeout=60, params=params)
     LOGGER.info(f"Encoded url from requests library: {resp.url}")
    if resp.status_code not in [200, 201]: