Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion tests/model_registry/model_catalog/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
REDHAT_AI_CATALOG_ID,
)
from tests.model_registry.model_catalog.utils import get_models_from_catalog_api
from tests.model_registry.constants import CUSTOM_CATALOG_ID1, DEFAULT_CUSTOM_MODEL_CATALOG
from tests.model_registry.constants import (
CUSTOM_CATALOG_ID1,
DEFAULT_MODEL_CATALOG_CM,
DEFAULT_CUSTOM_MODEL_CATALOG,
)
from tests.model_registry.utils import (
get_rest_headers,
is_model_catalog_ready,
Expand All @@ -34,6 +38,46 @@
LOGGER = get_logger(name=__name__)


@pytest.fixture(scope="session")
def enabled_model_catalog_config_map(
admin_client: DynamicClient,
model_registry_namespace: str,
) -> ConfigMap:
"""
Enable all catalogs in the default model catalog configmap
"""
# Get operator-managed default sources ConfigMap
default_sources_cm = ConfigMap(
name=DEFAULT_MODEL_CATALOG_CM, client=admin_client, namespace=model_registry_namespace, ensure_exists=True
)

# Get the sources.yaml content from default sources
default_sources_yaml = default_sources_cm.instance.data.get("sources.yaml", "")

# Parse the YAML and extract only catalogs, enabling each one
parsed_yaml = yaml.safe_load(default_sources_yaml)
if not parsed_yaml or "catalogs" not in parsed_yaml:
raise RuntimeError("No catalogs found in default sources ConfigMap")

for catalog in parsed_yaml["catalogs"]:
catalog["enabled"] = True
enabled_yaml_dict = {"catalogs": parsed_yaml["catalogs"]}
enabled_sources_yaml = yaml.dump(enabled_yaml_dict, default_flow_style=False, sort_keys=False)

LOGGER.info("Adding enabled catalogs to model-catalog-sources ConfigMap")

# Get user-managed sources ConfigMap
user_sources_cm = ConfigMap(
name=DEFAULT_CUSTOM_MODEL_CATALOG, client=admin_client, namespace=model_registry_namespace, ensure_exists=True
)

patches = {"data": {"sources.yaml": enabled_sources_yaml}}

with ResourceEditor(patches={user_sources_cm: patches}):
is_model_catalog_ready(client=admin_client, model_registry_namespace=model_registry_namespace)
yield user_sources_cm


@pytest.fixture(scope="class")
def model_catalog_config_map(
request: pytest.FixtureRequest, admin_client: DynamicClient, model_registry_namespace: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class TestModelCatalogDefault:
def test_model_catalog_default_catalog_sources(
self,
pytestconfig: pytest.Config,
enabled_model_catalog_config_map: ConfigMap,
test_idp_user: UserTestSession,
model_catalog_rest_url: list[str],
user_token_for_api_calls: str,
Expand Down Expand Up @@ -199,6 +200,7 @@ def test_model_catalog_default_catalog_sources(

def test_model_default_catalog_get_models_by_source(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
):
Expand All @@ -211,6 +213,7 @@ def test_model_default_catalog_get_models_by_source(

def test_model_default_catalog_get_model_by_name(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
user_token_for_api_calls: str,
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
Expand All @@ -228,6 +231,7 @@ def test_model_default_catalog_get_model_by_name(

def test_model_default_catalog_get_model_artifact(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
user_token_for_api_calls: str,
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
Expand All @@ -253,6 +257,7 @@ class TestModelCatalogDefaultData:

def test_model_default_catalog_number_of_models(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
default_catalog_api_response: dict[Any, Any],
default_model_catalog_yaml_content: dict[Any, Any],
):
Expand All @@ -269,6 +274,7 @@ def test_model_default_catalog_number_of_models(

def test_model_default_catalog_correspondence_of_model_name(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
default_catalog_api_response: dict[Any, Any],
default_model_catalog_yaml_content: dict[Any, Any],
catalog_openapi_schema: dict[Any, Any],
Expand Down Expand Up @@ -322,6 +328,7 @@ def test_model_default_catalog_correspondence_of_model_name(

def test_model_default_catalog_random_artifact(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
default_model_catalog_yaml_content: dict[Any, Any],
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from typing import Self
from simple_logger.logger import get_logger

from ocp_resources.config_map import ConfigMap
from tests.model_registry.model_catalog.utils import (
validate_filter_options_structure,
execute_database_query,
Expand Down Expand Up @@ -45,6 +45,7 @@ class TestFilterOptionsEndpoint:
)
def test_filter_options_endpoint_validation(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
user_token_for_api_calls: str,
test_idp_user: UserTestSession,
Expand Down Expand Up @@ -97,6 +98,7 @@ def test_filter_options_endpoint_validation(
)
def test_comprehensive_coverage_against_database(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
user_token_for_api_calls: str,
model_registry_namespace: str,
Expand Down
24 changes: 21 additions & 3 deletions tests/model_registry/model_catalog/test_model_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from dictdiffer import diff

from ocp_resources.config_map import ConfigMap
from simple_logger.logger import get_logger
from typing import Self, Any
from tests.model_registry.model_catalog.constants import (
Expand Down Expand Up @@ -32,7 +32,10 @@
class TestSearchModelCatalog:
@pytest.mark.smoke
def test_search_model_catalog_source_label(
self: Self, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str]
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
):
"""
RHOAIENG-33656: Validate search model catalog by source label
Expand Down Expand Up @@ -61,7 +64,10 @@ def test_search_model_catalog_source_label(
assert redhat_ai_filter_moldels_size + redhat_ai_validated_filter_models_size == both_filtered_models_size

def test_search_model_catalog_invalid_source_label(
self: Self, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str]
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
):
"""
RHOAIENG-33656:
Expand Down Expand Up @@ -102,6 +108,7 @@ def test_search_model_catalog_invalid_source_label(
)
def test_search_model_catalog_match(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
Expand Down Expand Up @@ -145,6 +152,7 @@ class TestSearchModelArtifact:
)
def test_validate_model_artifacts_by_artifact_type(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
Expand Down Expand Up @@ -210,6 +218,7 @@ def test_validate_model_artifacts_by_artifact_type(
)
def test_error_handled_for_invalid_artifact_type(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
Expand Down Expand Up @@ -248,6 +257,7 @@ def test_error_handled_for_invalid_artifact_type(
)
def test_multiple_artifact_type_filtering(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
randomly_picked_model_from_catalog_api_by_source: tuple[dict[Any, Any], str, str],
Expand Down Expand Up @@ -299,6 +309,7 @@ class TestSearchModelCatalogQParameter:
)
def test_q_parameter_basic_search(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
search_term: str,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
Expand Down Expand Up @@ -338,6 +349,7 @@ def test_q_parameter_basic_search(
)
def test_q_parameter_case_insensitive(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
search_term: str,
case_variant: str,
model_catalog_rest_url: list[str],
Expand Down Expand Up @@ -388,6 +400,7 @@ def test_q_parameter_case_insensitive(

def test_q_parameter_no_results(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
model_registry_namespace: str,
Expand Down Expand Up @@ -417,6 +430,7 @@ def test_q_parameter_no_results(
def test_q_parameter_empty_query(
self: Self,
search_term,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
):
Expand All @@ -434,6 +448,7 @@ def test_q_parameter_empty_query(

def test_q_parameter_with_source_label_filter(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
):
Expand Down Expand Up @@ -479,6 +494,7 @@ def test_q_parameter_with_source_label_filter(
class TestSearchModelsByFilterQuery:
def test_search_models_by_filter_query(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
model_registry_namespace: str,
Expand Down Expand Up @@ -524,6 +540,7 @@ def test_search_models_by_filter_query(

def test_search_models_by_invalid_filter_query(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
model_catalog_rest_url: list[str],
model_registry_rest_headers: dict[str, str],
model_registry_namespace: str,
Expand Down Expand Up @@ -563,6 +580,7 @@ def test_search_models_by_invalid_filter_query(
@pytest.mark.downstream_only
def test_presence_performance_data_on_pod(
self: Self,
enabled_model_catalog_config_map: ConfigMap,
admin_client: DynamicClient,
model_registry_namespace: str,
):
Expand Down
26 changes: 22 additions & 4 deletions tests/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
LOGGER = get_logger(name=__name__)


class TransientUnauthorizedError(Exception):
"""Exception for transient 401 Unauthorized errors that should be retried."""

pass


def get_mr_service_by_label(client: DynamicClient, namespace_name: str, mr_instance: ModelRegistry) -> Service:
"""
Args:
Expand Down Expand Up @@ -702,13 +708,25 @@ def execute_get_call(
resp = requests.get(url=url, headers=headers, verify=verify, timeout=60, params=params)
LOGGER.info(f"Encoded url from requests library: {resp.url}")
if resp.status_code not in [200, 201]:
# Raise custom exception for 401 errors that can be retried (OAuth/kube-rbac-proxy initialization)
if resp.status_code == 401:
raise TransientUnauthorizedError(f"Get call failed for resource: {url}, 401: {resp.text}")
# Raise regular exception for other errors (400, 403, 404, etc.) that should fail immediately
raise ResourceNotFoundError(f"Get call failed for resource: {url}, {resp.status_code}: {resp.text}")
return resp


@retry(wait_timeout=60, sleep=5, exceptions_dict={ResourceNotFoundError: []})
@retry(wait_timeout=60, sleep=5, exceptions_dict={ResourceNotFoundError: [], TransientUnauthorizedError: []})
def wait_for_model_catalog_api(url: str, headers: dict[str, str], verify: bool | str = False) -> requests.Response:
return execute_get_call(url=f"{url}sources", headers=headers, verify=verify)
"""
Wait for model catalog API to be ready and fully initialized checks both /sources and /models endpoints
to ensure OAuth/kube-rbac-proxy is fully initialized.
"""
LOGGER.info(f"Waiting for model catalog API at {url}sources")
execute_get_call(url=f"{url}sources", headers=headers, verify=verify)
LOGGER.info(f"Verifying model catalog API readiness at {url}models")

return execute_get_call(url=f"{url}models", headers=headers, verify=verify)


def execute_get_command(
Expand Down Expand Up @@ -741,9 +759,9 @@ def validate_model_catalog_sources(
url=model_catalog_sources_url,
headers=rest_headers,
)["items"]
LOGGER.info(results)
LOGGER.info(f"Model catalog sources: {results}")
# this is for the default catalog:
assert len(results) == len(expected_catalog_values) + 2
assert len(results) == len(expected_catalog_values)
ids_from_query = [result_entry["id"] for result_entry in results]
ids_expected = [expected_entry["id"] for expected_entry in expected_catalog_values]
assert set(ids_expected).issubset(set(ids_from_query)), f"Expected: {expected_catalog_values}. Actual: {results}"
Expand Down