Skip to content
Merged
40 changes: 37 additions & 3 deletions tests/model_registry/model_catalog/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import random
from typing import Generator, Any
import requests

import yaml
from simple_logger.logger import get_logger

import yaml
import pytest
from kubernetes.dynamic import DynamicClient

Expand All @@ -13,18 +13,26 @@
from ocp_resources.route import Route
from ocp_resources.service_account import ServiceAccount
from tests.model_registry.constants import DEFAULT_MODEL_CATALOG
from tests.model_registry.model_catalog.constants import SAMPLE_MODEL_NAME3, CUSTOM_CATALOG_ID1, DEFAULT_CATALOG_ID
from tests.model_registry.model_catalog.constants import (
SAMPLE_MODEL_NAME3,
CUSTOM_CATALOG_ID1,
DEFAULT_CATALOG_ID,
DEFAULT_CATALOG_FILE,
CATALOG_CONTAINER,
)
from tests.model_registry.model_catalog.utils import (
is_model_catalog_ready,
wait_for_model_catalog_api,
get_model_str,
execute_get_command,
get_model_catalog_pod,
get_default_model_catalog_yaml,
)
from tests.model_registry.utils import get_rest_headers
from utilities.infra import get_openshift_token, login_with_user_password, create_inference_token
from utilities.user_utils import UserTestSession


LOGGER = get_logger(name=__name__)


Expand Down Expand Up @@ -149,3 +157,29 @@ def randomly_picked_model_from_default_catalog(
assert result, f"Expected Default models to be present. Actual: {result}"
LOGGER.info(f"{len(result)} models found")
return random.choice(seq=result)


@pytest.fixture(scope="class")
def default_model_catalog_yaml_content(admin_client: DynamicClient, model_registry_namespace: str) -> dict[Any, Any]:
    """
    Return the parsed default catalog file read from a model catalog pod.

    The file at DEFAULT_CATALOG_FILE is read with `cat` inside the CATALOG_CONTAINER container of the
    first pod returned by get_model_catalog_pod, and the output is parsed with yaml.safe_load.
    """
    catalog_pods = get_model_catalog_pod(client=admin_client, model_registry_namespace=model_registry_namespace)
    raw_catalog = catalog_pods[0].execute(command=["cat", DEFAULT_CATALOG_FILE], container=CATALOG_CONTAINER)
    return yaml.safe_load(raw_catalog)
Comment thread
fege marked this conversation as resolved.


@pytest.fixture(scope="class")
def default_catalog_api_response(
    model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str]
) -> dict[Any, Any]:
    """
    Fetch the default catalog's models from the catalog API (used for data validation tests).

    A single request is made against the first catalog REST URL with pageSize=100, filtered by
    DEFAULT_CATALOG_ID as the source.
    """
    models_url = f"{model_catalog_rest_url[0]}models?source={DEFAULT_CATALOG_ID}&pageSize=100"
    return execute_get_command(url=models_url, headers=model_registry_rest_headers)


@pytest.fixture(scope="class")
def catalog_openapi_schema() -> dict[Any, Any]:
    """
    Download and parse the catalog OpenAPI schema (class-scoped, so fetched once per test class).

    Raises whatever requests raises on a connection problem or timeout (10 seconds), and an HTTP
    error via raise_for_status() on a non-success response.
    """
    schema_url = "https://raw.githubusercontent.com/kubeflow/model-registry/main/api/openapi/catalog.yaml"
    schema_response = requests.get(schema_url, timeout=10)
    schema_response.raise_for_status()
    return yaml.safe_load(schema_response.text)
131 changes: 115 additions & 16 deletions tests/model_registry/model_catalog/test_default_model_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import yaml
import random
from kubernetes.dynamic import DynamicClient
from dictdiffer import diff
from ocp_resources.deployment import Deployment
Expand All @@ -10,15 +11,16 @@
from ocp_resources.config_map import ConfigMap
from ocp_resources.route import Route
from ocp_resources.service import Service
from tests.model_registry.model_catalog.constants import DEFAULT_CATALOG_ID, DEFAULT_CATALOG_FILE, CATALOG_CONTAINER
from tests.model_registry.model_catalog.constants import DEFAULT_CATALOG_ID
from tests.model_registry.model_catalog.utils import (
validate_model_catalog_enabled,
execute_get_command,
validate_model_catalog_resource,
validate_default_catalog,
get_validate_default_model_catalog_source,
extract_schema_fields,
)
from tests.model_registry.utils import get_rest_headers, get_model_catalog_pod
from tests.model_registry.utils import get_rest_headers
from utilities.user_utils import UserTestSession

LOGGER = get_logger(name=__name__)
Expand Down Expand Up @@ -152,28 +154,125 @@ def test_model_default_catalog_get_model_artifact(
assert result, f"No artifacts found for {model_name}"
assert result[0]["uri"]


@pytest.mark.skip_must_gather
class TestModelCatalogDefaultData:
    """Test class for validating default catalog data (not user-specific)"""

    def test_model_default_catalog_number_of_models(
        self: Self,
        default_catalog_api_response: dict[Any, Any],
        default_model_catalog_yaml_content: dict[Any, Any],
    ):
        """
        RHOAIENG-33667: Validate number of models in default catalog
        """

        count = len(default_model_catalog_yaml_content.get("models", []))

        assert count == default_catalog_api_response["size"], (
            f"Expected count: {count}, Actual size: {default_catalog_api_response['size']}"
        )
        LOGGER.info("Model count matches")

    def test_model_default_catalog_correspondence_of_model_name(
        self: Self,
        default_catalog_api_response: dict[Any, Any],
        default_model_catalog_yaml_content: dict[Any, Any],
        catalog_openapi_schema: dict[Any, Any],
    ):
        """
        RHOAIENG-35260: Validate the correspondence of model parameters in default catalog yaml and model catalog api
        """

        all_model_fields, required_model_fields = extract_schema_fields(
            openapi_schema=catalog_openapi_schema, schema_name="CatalogModel"
        )
        LOGGER.info(f"All model fields from OpenAPI schema: {all_model_fields}")
        LOGGER.info(f"Required model fields from OpenAPI schema: {required_model_fields}")

        # Index the API models by name so each YAML model can be looked up directly
        api_models = {model["name"]: model for model in default_catalog_api_response.get("items", [])}
        assert api_models, "No models returned by the default catalog API"

        models_with_differences = {}

        for model in default_model_catalog_yaml_content.get("models", []):
            LOGGER.info(f"Validating model: {model['name']}")

            api_model = api_models.get(model["name"])
            assert api_model, f"Model {model['name']} not found in API response"

            # Check required fields are present in both YAML and API
            yaml_missing_required = required_model_fields - set(model.keys())
            api_missing_required = required_model_fields - set(api_model.keys())

            assert not yaml_missing_required, (
                f"Model {model['name']} missing REQUIRED fields in YAML: {yaml_missing_required}"
            )
            assert not api_missing_required, (
                f"Model {model['name']} missing REQUIRED fields in API: {api_missing_required}"
            )

            # Filter to only schema-defined fields for value comparison
            model_filtered = {k: v for k, v in model.items() if k in all_model_fields}
            api_model_filtered = {k: v for k, v in api_model.items() if k in all_model_fields}

            # Differences are collected per model so a single run reports every mismatching model
            differences = list(diff(model_filtered, api_model_filtered))
            if differences:
                models_with_differences[model["name"]] = differences
                LOGGER.warning(f"Found value differences for {model['name']}: {differences}")

        # FAILS for null-valued properties in YAML model until https://issues.redhat.com/browse/RHOAIENG-35322 is fixed
        assert not models_with_differences, (
            f"Found differences in {len(models_with_differences)} model(s): {models_with_differences}"
        )
        LOGGER.info("Model correspondence matches")

    def test_model_default_catalog_random_artifact(
        self: Self,
        default_model_catalog_yaml_content: dict[Any, Any],
        model_catalog_rest_url: list[str],
        model_registry_rest_headers: dict[str, str],
        catalog_openapi_schema: dict[Any, Any],
    ):
        """
        RHOAIENG-35260: Validate the random artifact in default catalog yaml matches API response
        """

        all_artifact_fields, required_artifact_fields = extract_schema_fields(
            openapi_schema=catalog_openapi_schema, schema_name="CatalogModelArtifact"
        )
        LOGGER.info(f"All artifact fields from OpenAPI schema: {all_artifact_fields}")
        LOGGER.info(f"Required artifact fields from OpenAPI schema: {required_artifact_fields}")

        yaml_models = default_model_catalog_yaml_content.get("models", [])
        # random.choice raises a bare IndexError on an empty sequence; fail with a clear message instead
        assert yaml_models, "No models found in default catalog YAML"
        random_model = random.choice(seq=yaml_models)
        LOGGER.info(f"Random model: {random_model['name']}")

        api_model_artifacts = execute_get_command(
            url=f"{model_catalog_rest_url[0]}sources/{DEFAULT_CATALOG_ID}/models/{random_model['name']}/artifacts",
            headers=model_registry_rest_headers,
        )["items"]

        yaml_artifacts = random_model.get("artifacts", [])
        assert api_model_artifacts, f"No artifacts found in API for {random_model['name']}"
        assert yaml_artifacts, f"No artifacts found in YAML for {random_model['name']}"

        # Validate all required fields are present in both YAML and API artifact
        # FAILS artifactType is not in YAML nor in API until https://issues.redhat.com/browse/RHOAIENG-35569 is fixed
        for field in required_artifact_fields:
            for artifact in yaml_artifacts:
                assert field in artifact, f"YAML artifact for {random_model['name']} missing REQUIRED field: {field}"
            for artifact in api_model_artifacts:
                assert field in artifact, f"API artifact for {random_model['name']} missing REQUIRED field: {field}"

        # Filter artifacts to only include schema-defined fields for comparison
        yaml_artifacts_filtered = [
            {k: v for k, v in artifact.items() if k in all_artifact_fields} for artifact in yaml_artifacts
        ]
        api_artifacts_filtered = [
            {k: v for k, v in artifact.items() if k in all_artifact_fields} for artifact in api_model_artifacts
        ]

        differences = list(diff(yaml_artifacts_filtered, api_artifacts_filtered))
        assert not differences, f"Artifacts mismatch for {random_model['name']}: {differences}"
        LOGGER.info("Artifacts match")
47 changes: 47 additions & 0 deletions tests/model_registry/model_catalog/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,50 @@ def get_validate_default_model_catalog_source(token: str, model_catalog_url: str

def get_default_model_catalog_yaml(config_map: ConfigMap) -> str:
    """
    Return the "catalogs" entry of the config map's "sources.yaml" document.

    NOTE(review): the annotation says str, but the value is whatever the parsed YAML holds under
    "catalogs" - confirm the intended return type against callers.
    """
    parsed_sources = yaml.safe_load(config_map.instance.data["sources.yaml"])
    return parsed_sources["catalogs"]


def extract_schema_fields(openapi_schema: dict[Any, Any], schema_name: str) -> tuple[set[str], set[str]]:
    """
    Extract all and required fields from an OpenAPI schema for validation.

    Properties and required fields are gathered from the named schema and, recursively, from every
    entry of its `allOf` list. `$ref` entries are resolved by their last path segment against
    `components.schemas`; a schema that is itself a bare `$ref` is resolved the same way.

    Args:
        openapi_schema: The parsed OpenAPI schema dictionary
        schema_name: Name of the schema to extract (e.g., "CatalogModel", "CatalogModelArtifact")

    Returns:
        Tuple of (all_fields, required_fields) excluding server-generated fields and timestamps.

    Raises:
        KeyError: If `components.schemas`, `schema_name` or a referenced schema is missing.
    """
    schemas = openapi_schema["components"]["schemas"]
    # Names already expanded; seeded with the target so a reference cycle back to it is cut short
    visited_refs: set[str] = {schema_name}

    def _extract_properties_and_required(schema: dict[Any, Any]) -> tuple[set[str], set[str]]:
        """Recursively extract properties and required fields from a schema."""
        # Follow references (possibly chained) to the concrete schema definition
        while "$ref" in schema:
            ref_schema_name = schema["$ref"].split("/")[-1]
            if ref_schema_name in visited_refs:
                # Already expanded elsewhere (diamond composition) or circular: nothing new to add
                return set(), set()
            visited_refs.add(ref_schema_name)
            schema = schemas[ref_schema_name]

        # `or` fallbacks tolerate keys that are present but null (an empty YAML key parses to None)
        props = set(schema.get("properties") or {})
        required = set(schema.get("required") or [])

        # Properties from allOf (inheritance/composition)
        for item in schema.get("allOf") or []:
            sub_props, sub_required = _extract_properties_and_required(schema=item)
            props.update(sub_props)
            required.update(sub_required)

        return props, required

    all_properties, required_fields = _extract_properties_and_required(schema=schemas[schema_name])

    # Exclude fields that shouldn't be compared
    excluded_fields = {
        "id",  # Server-generated
        "externalId",  # Server-generated
        "createTimeSinceEpoch",  # Timestamps may differ
        "lastUpdateTimeSinceEpoch",  # Timestamps may differ
        "artifacts",  # CatalogModel only
        "source_id",  # CatalogModel only
    }

    return all_properties - excluded_fields, required_fields - excluded_fields