Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/model_registry/model_catalog/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CATALOG_CONTAINER,
REDHAT_AI_CATALOG_ID,
)
from tests.model_registry.model_catalog.utils import get_models_from_catalog_api
from tests.model_registry.constants import CUSTOM_CATALOG_ID1
from tests.model_registry.utils import (
get_rest_headers,
Expand Down Expand Up @@ -198,3 +199,29 @@ def catalog_openapi_schema() -> dict[Any, Any]:
response = requests.get(OPENAPI_SCHEMA_URL, timeout=10)
response.raise_for_status()
return yaml.safe_load(response.text)


@pytest.fixture
def models_from_filter_query(
    request,
    model_catalog_rest_url: list[str],
    model_registry_rest_headers: dict[str, str],
):
    """
    Query the catalog API with the filter expression supplied via indirect
    parametrization (``request.param``) and return the matched model names.

    Fails the test immediately when the filter matches nothing, so dependent
    tests never iterate an empty result silently.
    """
    filter_query = request.param

    # NOTE(review): the filter expression is embedded unencoded in the query
    # string and may contain spaces/quotes/parens — confirm the helper (or the
    # underlying requests call) percent-encodes it.
    matched_models = get_models_from_catalog_api(
        model_catalog_rest_url=model_catalog_rest_url,
        model_registry_rest_headers=model_registry_rest_headers,
        additional_params=f"&filterQuery={filter_query}",
    )["items"]

    assert matched_models, f"No models returned from filter query: {filter_query}"

    names = [entry["name"] for entry in matched_models]
    LOGGER.info(f"Filter query '{filter_query}' returned {len(names)} models: {', '.join(names)}")

    return names
104 changes: 101 additions & 3 deletions tests/model_registry/model_catalog/test_model_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
validate_search_results_against_database,
validate_filter_query_results_against_database,
validate_performance_data_files_on_pod,
validate_model_artifacts_match_criteria_and,
validate_model_artifacts_match_criteria_or,
)
from tests.model_registry.utils import get_model_catalog_pod
from kubernetes.dynamic import DynamicClient
from kubernetes.dynamic.exceptions import ResourceNotFoundError

LOGGER = get_logger(name=__name__)
pytestmark = [
pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace", "test_idp_user")
]
pytestmark = [pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace")]


class TestSearchModelCatalog:
Expand Down Expand Up @@ -579,3 +579,101 @@ def test_presence_performance_data_on_pod(

# Assert that all models have all required performance data files
assert not validation_results, f"Models with missing performance data files: {validation_results}"

@pytest.mark.parametrize(
    "models_from_filter_query, expected_value, logic_type",
    [
        pytest.param(
            "artifacts.requests_per_second > 15.0",
            [{"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 15.0}],
            "and",
            id="performance_min_filter",
        ),
        pytest.param(
            "artifacts.hardware_count = 8",
            [{"key_name": "hardware_count", "key_type": "int_value", "comparison": "exact", "value": 8}],
            "and",
            id="hardware_exact_filter",
        ),
        pytest.param(
            "(artifacts.hardware_type LIKE 'H200') AND (artifacts.ttft_p95 < 50)",
            [
                {"key_name": "hardware_type", "key_type": "string_value", "comparison": "exact", "value": "H200"},
                {"key_name": "ttft_p95", "key_type": "double_value", "comparison": "max", "value": 50},
            ],
            "and",
            id="test_combined_hardware_performance_filter_mixed_types",
        ),
        pytest.param(
            "(artifacts.ttft_mean < 100) AND (artifacts.requests_per_second > 10)",
            [
                {"key_name": "ttft_mean", "key_type": "double_value", "comparison": "max", "value": 100},
                {"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 10},
            ],
            "and",
            id="test_combined_hardware_performance_filter_numeric_types",
        ),
        pytest.param(
            "(artifacts.tps_mean < 247) OR (artifacts.hardware_type LIKE 'A100-80')",
            [
                {"key_name": "tps_mean", "key_type": "double_value", "comparison": "max", "value": 247},
                {
                    "key_name": "hardware_type",
                    "key_type": "string_value",
                    "comparison": "exact",
                    "value": "A100-80",
                },
            ],
            "or",
            id="performance_or_hardware_filter",
        ),
    ],
    indirect=["models_from_filter_query"],
)
def test_filter_query_advanced_model_search(
    self: Self,
    models_from_filter_query: list[str],
    expected_value: list[dict[str, Any]],
    logic_type: str,
    model_catalog_rest_url: list[str],
    model_registry_rest_headers: dict[str, str],
):
    """
    RHOAIENG-39615: Advanced filter query test for performance-based filtering with AND/OR logic

    The indirect fixture already asserted the filter returned models; here we
    cross-check each returned model's artifacts against the same criteria.
    """
    failed_models: list[str] = []

    # Dispatch table: logic keyword -> artifact-level validator.
    validator_by_logic = {
        "and": validate_model_artifacts_match_criteria_and,
        "or": validate_model_artifacts_match_criteria_or,
    }

    for model_name in models_from_filter_query:
        # NOTE(review): bare "?pageSize" with no value — the paging helper
        # presumably appends "=<n>"; confirm against
        # fetch_all_artifacts_with_dynamic_paging.
        url = f"{model_catalog_rest_url[0]}sources/{VALIDATED_CATALOG_ID}/models/{model_name}/artifacts?pageSize"
        LOGGER.info(f"Validating model: {model_name} with {len(expected_value)} {logic_type.upper()} validation(s)")

        # Pull every artifact for the model, letting the helper shrink the
        # page size if the server rejects the initial one.
        all_model_artifacts = fetch_all_artifacts_with_dynamic_paging(
            url_with_pagesize=url,
            headers=model_registry_rest_headers,
            page_size=100,
        )["items"]

        validator = validator_by_logic.get(logic_type)
        if validator is None:
            raise ValueError(f"Invalid logic_type: {logic_type}. Must be 'and' or 'or'")

        if validator(
            all_model_artifacts=all_model_artifacts, expected_validations=expected_value, model_name=model_name
        ):
            LOGGER.info(f"For Model: {model_name}, {logic_type.upper()} validation completed successfully")
        else:
            failed_models.append(model_name)

    assert not failed_models, f"{logic_type.upper()} filter validations failed for {', '.join(failed_models)}"
    LOGGER.info(
        f"Advanced {logic_type.upper()} filter validation completed for {len(models_from_filter_query)} models"
    )
137 changes: 137 additions & 0 deletions tests/model_registry/model_catalog/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,3 +1005,140 @@ def validate_items_sorted_correctly(items: list[dict], field: str, order: str) -
return all(values[i] >= values[i + 1] for i in range(len(values) - 1))
else:
raise ValueError(f"Invalid sort order: {order}")


def _validate_single_criterion(
artifact_name: str, custom_properties: dict[str, Any], validation: dict[str, Any]
) -> tuple[bool, str]:
"""
Helper function to validate a single criterion against an artifact.

Args:
artifact_name: Name of the artifact being validated
custom_properties: Custom properties dictionary from the artifact
validation: Single validation criterion containing key_name, key_type, comparison, value

Returns:
tuple: (condition_met: bool, message: str)
"""
key_name = validation["key_name"]
key_type = validation["key_type"]
comparison_type = validation["comparison"]
expected_val = validation["value"]

raw_value = custom_properties.get(key_name, {}).get(key_type, None)

if raw_value is None:
return False, f"{key_name}: missing"

# Convert value to appropriate type
try:
if key_type == "int_value":
artifact_value = int(raw_value)
elif key_type == "double_value":
artifact_value = float(raw_value)
elif key_type == "string_value":
artifact_value = str(raw_value)
else:
LOGGER.warning(f"Unknown key_type: {key_type}")
return False, f"{key_name}: unknown type {key_type}"
except (ValueError, TypeError):
return False, f"{key_name}: conversion error"

# Perform comparison based on type
condition_met = False
if comparison_type == "exact":
condition_met = artifact_value == expected_val
elif comparison_type == "min":
condition_met = artifact_value >= expected_val
elif comparison_type == "max":
condition_met = artifact_value <= expected_val
elif comparison_type == "contains" and key_type == "string_value":
condition_met = expected_val in artifact_value

message = f"Artifact {artifact_name} {key_name}: {artifact_value} {comparison_type} {expected_val}"
return condition_met, message


def validate_model_artifacts_match_criteria_and(
    all_model_artifacts: list[dict[str, Any]], expected_validations: list[dict[str, Any]], model_name: str
) -> bool:
    """
    Check whether any single artifact of the model satisfies ALL criteria.

    Args:
        all_model_artifacts: List of artifact dictionaries for a model.
        expected_validations: List of criterion dicts, each containing:
            - key_name: The property name to validate
            - key_type: The type of the property (int_value, double_value, string_value)
            - comparison: The comparison type (exact, min, max, contains)
            - value: The expected value for comparison
        model_name: Name of the model being validated (for logging).

    Returns:
        bool: True if at least one artifact satisfies every criterion, False otherwise.
    """
    required = len(expected_validations)

    for artifact in all_model_artifacts:
        artifact_name = artifact.get("name", "missing_artifact_name")
        props = artifact["customProperties"]
        passed_messages = []

        # AND semantics: stop probing this artifact at the first failed criterion.
        for criterion in expected_validations:
            met, message = _validate_single_criterion(
                artifact_name=artifact_name, custom_properties=props, validation=criterion
            )
            if not met:
                break
            passed_messages.append(f"{message}: passed")
        else:
            # Loop completed without a break: every criterion held, so the model qualifies.
            LOGGER.info(
                f"Model {model_name} passed all {required} validations with artifact: {passed_messages}"
            )
            return True

    return False


def validate_model_artifacts_match_criteria_or(
    all_model_artifacts: list[dict[str, Any]], expected_validations: list[dict[str, Any]], model_name: str
) -> bool:
    """
    Validates that at least one artifact in the model satisfies at least one of the expected validation criteria.

    Args:
        all_model_artifacts: List of artifact dictionaries for a model
        expected_validations: List of validation criteria dictionaries, each containing:
            - key_name: The property name to validate
            - key_type: The type of the property (int_value, double_value, string_value)
            - comparison: The comparison type (exact, min, max, contains)
            - value: The expected value for comparison
        model_name: Name of the model being validated (for logging)

    Returns:
        bool: True if at least one artifact satisfies at least one validation criterion, False otherwise
    """
    for artifact in all_model_artifacts:
        # Fallback mirrors validate_model_artifacts_match_criteria_and so log
        # messages never render "None" for artifacts missing a name.
        artifact_name = artifact.get("name", "missing_artifact_name")
        custom_properties = artifact["customProperties"]

        # Check if this artifact satisfies ANY validation (OR logic)
        for validation in expected_validations:
            condition_met, message = _validate_single_criterion(
                artifact_name=artifact_name, custom_properties=custom_properties, validation=validation
            )

            if condition_met:
                LOGGER.info(f"Model {model_name} passed OR validation with artifact: {message}")
                return True  # OR logic: return immediately on first success

    # No artifact passed any validation
    LOGGER.info(f"Model {model_name} failed all {len(expected_validations)} OR validations")
    return False
3 changes: 2 additions & 1 deletion tests/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,8 @@ def execute_get_call(
url: str, headers: dict[str, str], verify: bool | str = False, params: dict[str, Any] | None = None
) -> requests.Response:
LOGGER.info(f"Executing get call: {url}")
LOGGER.info(f"params: {params}")
if params:
LOGGER.info(f"params: {params}")
resp = requests.get(url=url, headers=headers, verify=verify, timeout=60, params=params)
LOGGER.info(f"Encoded url from requests library: {resp.url}")
if resp.status_code not in [200, 201]:
Expand Down