Skip to content

Commit 6c29e38

Browse files
authored
Add tests with artifacts property (#882)
* Add tests with artifacts property * default name for artifacts missing name * Add suggested code based on Federico's comment
1 parent ae34943 commit 6c29e38

File tree

4 files changed

+246
-4
lines changed

4 files changed

+246
-4
lines changed

tests/model_registry/model_catalog/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
CATALOG_CONTAINER,
1818
REDHAT_AI_CATALOG_ID,
1919
)
20+
from tests.model_registry.model_catalog.utils import get_models_from_catalog_api
2021
from tests.model_registry.constants import CUSTOM_CATALOG_ID1
2122
from tests.model_registry.utils import (
2223
get_rest_headers,
@@ -198,3 +199,29 @@ def catalog_openapi_schema() -> dict[Any, Any]:
198199
response = requests.get(OPENAPI_SCHEMA_URL, timeout=10)
199200
response.raise_for_status()
200201
return yaml.safe_load(response.text)
202+
203+
204+
@pytest.fixture
205+
def models_from_filter_query(
206+
request,
207+
model_catalog_rest_url: list[str],
208+
model_registry_rest_headers: dict[str, str],
209+
) -> list[str]:
210+
"""
211+
Fixture that runs get_models_from_catalog_api with the given filter_query,
212+
asserts that models are returned, and returns list of model names.
213+
"""
214+
filter_query = request.param
215+
216+
models = get_models_from_catalog_api(
217+
model_catalog_rest_url=model_catalog_rest_url,
218+
model_registry_rest_headers=model_registry_rest_headers,
219+
additional_params=f"&filterQuery={filter_query}",
220+
)["items"]
221+
222+
assert models, f"No models returned from filter query: {filter_query}"
223+
224+
model_names = [model["name"] for model in models]
225+
LOGGER.info(f"Filter query '{filter_query}' returned {len(model_names)} models: {', '.join(model_names)}")
226+
227+
return model_names

tests/model_registry/model_catalog/test_model_search.py

Lines changed: 101 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
validate_search_results_against_database,
1919
validate_filter_query_results_against_database,
2020
validate_performance_data_files_on_pod,
21+
validate_model_artifacts_match_criteria_and,
22+
validate_model_artifacts_match_criteria_or,
2123
)
2224
from tests.model_registry.utils import get_model_catalog_pod
2325
from kubernetes.dynamic import DynamicClient
2426
from kubernetes.dynamic.exceptions import ResourceNotFoundError
2527

2628
LOGGER = get_logger(name=__name__)
27-
pytestmark = [
28-
pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace", "test_idp_user")
29-
]
29+
pytestmark = [pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace")]
3030

3131

3232
class TestSearchModelCatalog:
@@ -579,3 +579,101 @@ def test_presence_performance_data_on_pod(
579579

580580
# Assert that all models have all required performance data files
581581
assert not validation_results, f"Models with missing performance data files: {validation_results}"
582+
583+
@pytest.mark.parametrize(
584+
"models_from_filter_query, expected_value, logic_type",
585+
[
586+
pytest.param(
587+
"artifacts.requests_per_second > 15.0",
588+
[{"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 15.0}],
589+
"and",
590+
id="performance_min_filter",
591+
),
592+
pytest.param(
593+
"artifacts.hardware_count = 8",
594+
[{"key_name": "hardware_count", "key_type": "int_value", "comparison": "exact", "value": 8}],
595+
"and",
596+
id="hardware_exact_filter",
597+
),
598+
pytest.param(
599+
"(artifacts.hardware_type LIKE 'H200') AND (artifacts.ttft_p95 < 50)",
600+
[
601+
{"key_name": "hardware_type", "key_type": "string_value", "comparison": "exact", "value": "H200"},
602+
{"key_name": "ttft_p95", "key_type": "double_value", "comparison": "max", "value": 50},
603+
],
604+
"and",
605+
id="test_combined_hardware_performance_filter_mixed_types",
606+
),
607+
pytest.param(
608+
"(artifacts.ttft_mean < 100) AND (artifacts.requests_per_second > 10)",
609+
[
610+
{"key_name": "ttft_mean", "key_type": "double_value", "comparison": "max", "value": 100},
611+
{"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 10},
612+
],
613+
"and",
614+
id="test_combined_hardware_performance_filter_numeric_types",
615+
),
616+
pytest.param(
617+
"(artifacts.tps_mean < 247) OR (artifacts.hardware_type LIKE 'A100-80')",
618+
[
619+
{"key_name": "tps_mean", "key_type": "double_value", "comparison": "max", "value": 247},
620+
{
621+
"key_name": "hardware_type",
622+
"key_type": "string_value",
623+
"comparison": "exact",
624+
"value": "A100-80",
625+
},
626+
],
627+
"or",
628+
id="performance_or_hardware_filter",
629+
),
630+
],
631+
indirect=["models_from_filter_query"],
632+
)
633+
def test_filter_query_advanced_model_search(
634+
self: Self,
635+
models_from_filter_query: list[str],
636+
expected_value: list[dict[str, Any]],
637+
logic_type: str,
638+
model_catalog_rest_url: list[str],
639+
model_registry_rest_headers: dict[str, str],
640+
):
641+
"""
642+
RHOAIENG-39615: Advanced filter query test for performance-based filtering with AND/OR logic
643+
"""
644+
errors = []
645+
646+
# Additional validation: ensure returned models match the filter criteria
647+
for model_name in models_from_filter_query:
648+
url = f"{model_catalog_rest_url[0]}sources/{VALIDATED_CATALOG_ID}/models/{model_name}/artifacts?pageSize"
649+
LOGGER.info(f"Validating model: {model_name} with {len(expected_value)} {logic_type.upper()} validation(s)")
650+
651+
# Fetch all artifacts with dynamic page size adjustment
652+
all_model_artifacts = fetch_all_artifacts_with_dynamic_paging(
653+
url_with_pagesize=url,
654+
headers=model_registry_rest_headers,
655+
page_size=200,
656+
)["items"]
657+
658+
validation_result = None
659+
# Select validation function based on logic type
660+
if logic_type == "and":
661+
validation_result = validate_model_artifacts_match_criteria_and(
662+
all_model_artifacts=all_model_artifacts, expected_validations=expected_value, model_name=model_name
663+
)
664+
elif logic_type == "or":
665+
validation_result = validate_model_artifacts_match_criteria_or(
666+
all_model_artifacts=all_model_artifacts, expected_validations=expected_value, model_name=model_name
667+
)
668+
else:
669+
raise ValueError(f"Invalid logic_type: {logic_type}. Must be 'and' or 'or'")
670+
671+
if validation_result:
672+
LOGGER.info(f"For Model: {model_name}, {logic_type.upper()} validation completed successfully")
673+
else:
674+
errors.append(model_name)
675+
676+
assert not errors, f"{logic_type.upper()} filter validations failed for {', '.join(errors)}"
677+
LOGGER.info(
678+
f"Advanced {logic_type.upper()} filter validation completed for {len(models_from_filter_query)} models"
679+
)

tests/model_registry/model_catalog/utils.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,3 +1005,119 @@ def validate_items_sorted_correctly(items: list[dict], field: str, order: str) -
10051005
return all(values[i] >= values[i + 1] for i in range(len(values) - 1))
10061006
else:
10071007
raise ValueError(f"Invalid sort order: {order}")
1008+
1009+
1010+
def _validate_single_criterion(
1011+
artifact_name: str, custom_properties: dict[str, Any], validation: dict[str, Any]
1012+
) -> tuple[bool, str]:
1013+
"""
1014+
Helper function to validate a single criterion against an artifact.
1015+
1016+
Args:
1017+
artifact_name: Name of the artifact being validated
1018+
custom_properties: Custom properties dictionary from the artifact
1019+
validation: Single validation criterion containing key_name, key_type, comparison, value
1020+
1021+
Returns:
1022+
tuple: (condition_met: bool, message: str)
1023+
"""
1024+
key_name = validation["key_name"]
1025+
key_type = validation["key_type"]
1026+
comparison_type = validation["comparison"]
1027+
expected_val = validation["value"]
1028+
1029+
raw_value = custom_properties.get(key_name, {}).get(key_type, None)
1030+
1031+
if raw_value is None:
1032+
return False, f"{key_name}: missing"
1033+
1034+
# Convert value to appropriate type
1035+
try:
1036+
if key_type == "int_value":
1037+
artifact_value = int(raw_value)
1038+
elif key_type == "double_value":
1039+
artifact_value = float(raw_value)
1040+
elif key_type == "string_value":
1041+
artifact_value = str(raw_value)
1042+
else:
1043+
LOGGER.warning(f"Unknown key_type: {key_type}")
1044+
return False, f"{key_name}: unknown type {key_type}"
1045+
except (ValueError, TypeError):
1046+
return False, f"{key_name}: conversion error"
1047+
1048+
# Perform comparison based on type
1049+
condition_met = False
1050+
if comparison_type == "exact":
1051+
condition_met = artifact_value == expected_val
1052+
elif comparison_type == "min":
1053+
condition_met = artifact_value >= expected_val
1054+
elif comparison_type == "max":
1055+
condition_met = artifact_value <= expected_val
1056+
elif comparison_type == "contains" and key_type == "string_value":
1057+
condition_met = expected_val in artifact_value
1058+
1059+
message = f"Artifact {artifact_name} {key_name}: {artifact_value} {comparison_type} {expected_val}"
1060+
return condition_met, message
1061+
1062+
1063+
def _get_artifact_validation_results(
1064+
artifact: dict[str, Any], expected_validations: list[dict[str, Any]]
1065+
) -> tuple[list[bool], list[str]]:
1066+
"""
1067+
Checks one artifact against all validations and returns the boolean outcomes and messages.
1068+
"""
1069+
artifact_name = artifact.get("name", "missing_artifact_name")
1070+
custom_properties = artifact["customProperties"]
1071+
1072+
# Store the boolean results and informative messages
1073+
bool_results = []
1074+
messages = []
1075+
1076+
for validation in expected_validations:
1077+
condition_met, message = _validate_single_criterion(
1078+
artifact_name=artifact_name, custom_properties=custom_properties, validation=validation
1079+
)
1080+
bool_results.append(condition_met)
1081+
messages.append(message)
1082+
1083+
return bool_results, messages
1084+
1085+
1086+
def validate_model_artifacts_match_criteria_and(
1087+
all_model_artifacts: list[dict[str, Any]], expected_validations: list[dict[str, Any]], model_name: str
1088+
) -> bool:
1089+
"""
1090+
Validates that at least one artifact in the model satisfies ALL expected validation criteria.
1091+
"""
1092+
for artifact in all_model_artifacts:
1093+
bool_results, messages = _get_artifact_validation_results(
1094+
artifact=artifact, expected_validations=expected_validations
1095+
)
1096+
# If ALL results are True
1097+
if all(bool_results):
1098+
validation_results = [f"{message}: passed" for message in messages]
1099+
LOGGER.info(
1100+
f"Model {model_name} passed all {len(bool_results)} validations with artifact: {validation_results}"
1101+
)
1102+
return True
1103+
1104+
return False
1105+
1106+
1107+
def validate_model_artifacts_match_criteria_or(
1108+
all_model_artifacts: list[dict[str, Any]], expected_validations: list[dict[str, Any]], model_name: str
1109+
) -> bool:
1110+
"""
1111+
Validates that at least one artifact in the model satisfies AT LEAST ONE of the expected validation criteria.
1112+
"""
1113+
for artifact in all_model_artifacts:
1114+
bool_results, messages = _get_artifact_validation_results(
1115+
artifact=artifact, expected_validations=expected_validations
1116+
)
1117+
if any(bool_results):
1118+
# Find the first passing message for logging
1119+
LOGGER.info(f"Model {model_name} passed OR validation with artifact: {messages[bool_results.index(True)]}")
1120+
return True
1121+
1122+
LOGGER.error(f"Model {model_name} failed all OR validations")
1123+
return False

tests/model_registry/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,8 @@ def execute_get_call(
698698
url: str, headers: dict[str, str], verify: bool | str = False, params: dict[str, Any] | None = None
699699
) -> requests.Response:
700700
LOGGER.info(f"Executing get call: {url}")
701-
LOGGER.info(f"params: {params}")
701+
if params:
702+
LOGGER.info(f"params: {params}")
702703
resp = requests.get(url=url, headers=headers, verify=verify, timeout=60, params=params)
703704
LOGGER.info(f"Encoded url from requests library: {resp.url}")
704705
if resp.status_code not in [200, 201]:

0 commit comments

Comments
 (0)