52 changes: 31 additions & 21 deletions tests/model_registry/model_catalog/conftest.py
@@ -180,42 +180,52 @@ def randomly_picked_model_from_catalog_api_by_source(
model_registry_rest_headers: dict[str, str],
request: pytest.FixtureRequest,
) -> tuple[dict[Any, Any], str, str]:
"""Pick a random model from a specific catalog (function-scoped for test isolation)
"""
Pick a random model from a specific catalog if a model name is not provided. If model name is provided, verify
that it exists and is associated with a given catalog and return the same.

Supports parameterized headers via 'header_type':
- 'user_token': Uses user_token_for_api_calls (default for user-specific tests)
- 'registry': Uses model_registry_rest_headers (for catalog/registry tests)
- 'model_name': Name of the model

Accepts 'catalog_id' or 'source' (alias) to specify the catalog.
Accepts 'catalog_id' or 'source' (alias) to specify the catalog. Accepts 'model_name' to specify the model to
look for.
"""
param = getattr(request, "param", {})
# Support both 'catalog_id' and 'source' for backward compatibility
catalog_id = param.get("catalog_id") or param.get("source", REDHAT_AI_CATALOG_ID)
header_type = param.get("header_type", "user_token")

model_name = param.get("model_name")
random_model = None
# Select headers based on header_type
if header_type == "registry":
headers = model_registry_rest_headers
else:
headers = get_rest_headers(token=user_token_for_api_calls)

LOGGER.info(f"Picking random model from catalog: {catalog_id} with header_type: {header_type}")

models_response = execute_get_command(
url=f"{model_catalog_rest_url[0]}models?source={catalog_id}&pageSize=100",
headers=headers,
)
models = models_response.get("items", [])
assert models, f"No models found for catalog: {catalog_id}"
LOGGER.info(f"{len(models)} models found in catalog {catalog_id}")

random_model = random.choice(seq=models)

model_name = random_model.get("name")
assert model_name, "Model name not found in random model"
assert random_model.get("source_id") == catalog_id, f"Catalog ID (source_id) mismatch for model {model_name}"
LOGGER.info(f"Testing model '{model_name}' from catalog '{catalog_id}'")

if not model_name:
LOGGER.info(f"Picking random model from catalog: {catalog_id} with header_type: {header_type}")
models_response = execute_get_command(
url=f"{model_catalog_rest_url[0]}models?source={catalog_id}&pageSize=100",
headers=headers,
)
models = models_response.get("items", [])
assert models, f"No models found for catalog: {catalog_id}"
LOGGER.info(f"{len(models)} models found in catalog {catalog_id}")
random_model = random.choice(seq=models)
model_name = random_model.get("name")
assert model_name, "Model name not found in random model"
assert random_model.get("source_id") == catalog_id, f"Catalog ID (source_id) mismatch for model {model_name}"
LOGGER.info(f"Testing model '{model_name}' from catalog '{catalog_id}'")
else:
LOGGER.info(f"Looking for pre-selected model: {model_name} from catalog: {catalog_id}")
        # Verify the model exists and belongs to the given catalog
random_model = execute_get_command(
url=f"{model_catalog_rest_url[0]}sources/{catalog_id}/models/{model_name}",
headers=headers,
)
assert random_model["source_id"] == catalog_id, f"Catalog ID (source_id) mismatch for model {model_name}"
LOGGER.info(f"Using model '{model_name}' from catalog '{catalog_id}'")
return random_model, model_name, catalog_id
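For reference, a minimal sketch of how a test can drive this fixture through indirect parametrization, mirroring the usage in the test file below; the class name, catalog ID, and model name here are hypothetical placeholders:

```python
import pytest
from typing import Self


class TestCatalogModelExample:
    @pytest.mark.parametrize(
        "randomly_picked_model_from_catalog_api_by_source",
        [
            pytest.param(
                {
                    "catalog_id": "sample-catalog",
                    "header_type": "registry",
                    "model_name": "SampleOrg/sample-model",
                },
                id="lookup_preselected_model",
            ),
        ],
        indirect=["randomly_picked_model_from_catalog_api_by_source"],
    )
    def test_model_lookup(
        self: Self,
        randomly_picked_model_from_catalog_api_by_source: tuple[dict, str, str],
    ):
        # The fixture returns the model payload plus the resolved name and catalog ID
        model, model_name, catalog_id = randomly_picked_model_from_catalog_api_by_source
        assert model["source_id"] == catalog_id
```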


206 changes: 206 additions & 0 deletions tests/model_registry/model_catalog/test_model_artifact_search.py
@@ -0,0 +1,206 @@
import pytest
from typing import Self, Any
import random
from ocp_resources.config_map import ConfigMap
from tests.model_registry.model_catalog.utils import (
fetch_all_artifacts_with_dynamic_paging,
validate_model_artifacts_match_criteria_and,
validate_model_artifacts_match_criteria_or,
)
from tests.model_registry.model_catalog.constants import (
VALIDATED_CATALOG_ID,
)
from kubernetes.dynamic.exceptions import ResourceNotFoundError
from simple_logger.logger import get_logger

LOGGER = get_logger(name=__name__)
pytestmark = [pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace")]
MODEL_NAMES_ARTIFACT_SEARCH: list[str] = [
"RedHatAI/Llama-3.1-8B-Instruct",
"RedHatAI/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic",
"RedHatAI/Mistral-Small-3.1-24B-Instruct-2503-quantized.w4a16",
"RedHatAI/Mistral-Small-3.1-24B-Instruct-2503-quantized.w8a8",
"RedHatAI/Mixtral-8x7B-Instruct-v0.1",
]


class TestSearchArtifactsByFilterQuery:
@pytest.mark.parametrize(
"randomly_picked_model_from_catalog_api_by_source, invalid_filter_query",
[
pytest.param(
{"catalog_id": VALIDATED_CATALOG_ID, "header_type": "registry"},
"fake IN ('test', 'fake'))",
id="test_invalid_artifact_filter_query_malformed",
),
pytest.param(
{"catalog_id": VALIDATED_CATALOG_ID, "header_type": "registry"},
"ttft_p90.double_value < abc",
id="test_invalid_artifact_filter_query_data_type_mismatch",
),
pytest.param(
{"catalog_id": VALIDATED_CATALOG_ID, "header_type": "registry"},
"hardware_type.string_value = 5.0",
id="test_invalid_artifact_filter_query_data_type_mismatch_equality",
),
],
indirect=["randomly_picked_model_from_catalog_api_by_source"],
)
def test_search_artifacts_by_invalid_filter_query(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict, str, str],
invalid_filter_query: str,
):
"""
        Tests the API's response to invalid filter query syntax when searching artifacts.
        Verifies that an invalid filter query raises the expected error.
"""
_, model_name, catalog_id = randomly_picked_model_from_catalog_api_by_source

LOGGER.info(f"Testing invalid artifact filter query: '{invalid_filter_query}' for model: {model_name}")
with pytest.raises(ResourceNotFoundError, match="invalid filter query"):
fetch_all_artifacts_with_dynamic_paging(
url_with_pagesize=(
f"{model_catalog_rest_url[0]}sources/{catalog_id}/models/{model_name}/artifacts?"
f"filterQuery={invalid_filter_query}&pageSize"
),
headers=model_registry_rest_headers,
page_size=1,
)

LOGGER.info(
f"Successfully validated that invalid artifact filter query '{invalid_filter_query}' raises an error"
)

@pytest.mark.parametrize(
"randomly_picked_model_from_catalog_api_by_source, filter_query, expected_value, logic_type",
[
pytest.param(
{
"catalog_id": VALIDATED_CATALOG_ID,
"header_type": "registry",
"model_name": random.choice(MODEL_NAMEs_ARTIFACT_SEARCH),
},
"hardware_type.string_value = 'ABC-1234'",
None,
None,
id="test_valid_artifact_filter_query_no_results",
),
pytest.param(
{
"catalog_id": VALIDATED_CATALOG_ID,
"header_type": "registry",
"model_name": random.choice(MODEL_NAMEs_ARTIFACT_SEARCH),
},
"requests_per_second.double_value > 15.0",
[{"key_name": "requests_per_second", "key_type": "double_value", "comparison": "min", "value": 15.0}],
"and",
id="test_performance_min_filter",
),
pytest.param(
{
"catalog_id": VALIDATED_CATALOG_ID,
"header_type": "registry",
"model_name": random.choice(MODEL_NAMEs_ARTIFACT_SEARCH),
},
"hardware_count.int_value = 8",
[{"key_name": "hardware_count", "key_type": "int_value", "comparison": "exact", "value": 8}],
"and",
id="test_hardware_exact_filter",
),
pytest.param(
{
"catalog_id": VALIDATED_CATALOG_ID,
"header_type": "registry",
"model_name": random.choice(MODEL_NAMEs_ARTIFACT_SEARCH),
},
"(hardware_type.string_value = 'H100') AND (ttft_p99.double_value < 200)",
[
{"key_name": "hardware_type", "key_type": "string_value", "comparison": "exact", "value": "H100"},
{"key_name": "ttft_p99", "key_type": "double_value", "comparison": "max", "value": 199},
],
"and",
id="test_combined_hardware_performance_filter_and_operation",
),
pytest.param(
{
"catalog_id": VALIDATED_CATALOG_ID,
"header_type": "registry",
"model_name": random.choice(MODEL_NAMEs_ARTIFACT_SEARCH),
},
"(tps_mean.double_value <260) OR (hardware_type.string_value = 'A100-80')",
[
{"key_name": "tps_mean", "key_type": "double_value", "comparison": "max", "value": 260},
{
"key_name": "hardware_type",
"key_type": "string_value",
"comparison": "exact",
"value": "A100-80",
},
],
"or",
id="performance_or_hardware_filter_or_operation",
),
],
indirect=["randomly_picked_model_from_catalog_api_by_source"],
)
def test_filter_query_advanced_artifact_search(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict, str, str],
filter_query: str,
expected_value: list[dict[str, Any]] | None,
logic_type: str | None,
):
"""
Advanced filter query test for artifact-based filtering with AND/OR logic
"""
_, model_name, catalog_id = randomly_picked_model_from_catalog_api_by_source

LOGGER.info(f"Testing artifact filter query: '{filter_query}' for model: {model_name}")

result = fetch_all_artifacts_with_dynamic_paging(
url_with_pagesize=(
f"{model_catalog_rest_url[0]}sources/{catalog_id}/models/{model_name}/artifacts?"
f"filterQuery={filter_query}&pageSize"
),
headers=model_registry_rest_headers,
page_size=100,
)

if expected_value is None:
            # The filter query should match no artifacts
            assert result["items"] == [], f"Filter query '{filter_query}' should return no results"
            assert result["size"] == 0, f"Size should be 0 for filter query '{filter_query}'"
            LOGGER.info(f"Successfully validated that filter query '{filter_query}' returns no artifacts")
else:
# Advanced validation using criteria matching
all_artifacts = result["items"]

validation_result = None
# Select validation function based on logic type
if logic_type == "and":
validation_result = validate_model_artifacts_match_criteria_and(
all_model_artifacts=all_artifacts, expected_validations=expected_value, model_name=model_name
)
elif logic_type == "or":
validation_result = validate_model_artifacts_match_criteria_or(
all_model_artifacts=all_artifacts, expected_validations=expected_value, model_name=model_name
)
else:
raise ValueError(f"Invalid logic_type: {logic_type}. Must be 'and' or 'or'")

if validation_result:
LOGGER.info(
f"For Model: {model_name}, {logic_type} validation completed successfully"
f" for {len(all_artifacts)} artifacts"
)
else:
pytest.fail(f"{logic_type} filter validation failed for model {model_name}")