Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
572 changes: 364 additions & 208 deletions tests/model_registry/conftest.py

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions tests/model_registry/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,18 @@
from utilities.constants import Annotations


class ModelRegistryEndpoints:
REGISTERED_MODELS: str = "/api/model_registry/v1alpha3/registered_models"


MODEL_NAME: str = "my-model"
MODEL_DESCRIPTION: str = "lorem ipsum"
DB_RESOURCES_NAME: str = "model-registry-db"
MR_INSTANCE_NAME: str = "model-registry"
MR_OPERATOR_NAME: str = "model-registry-operator"
MR_NAMESPACE: str = "rhoai-model-registries"
DEFAULT_LABEL_DICT_DB: dict[str, str] = {
Annotations.KubernetesIo.NAME: DB_RESOURCES_NAME,
Annotations.KubernetesIo.INSTANCE: DB_RESOURCES_NAME,
Annotations.KubernetesIo.PART_OF: DB_RESOURCES_NAME,
}
70 changes: 17 additions & 53 deletions tests/model_registry/test_model_registry_creation.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
import pytest
from typing import Self
from model_registry import ModelRegistry as ModelRegistryClient

from simple_logger.logger import get_logger

from ocp_resources.data_science_cluster import DataScienceCluster
from utilities.constants import Protocols, DscComponents, ModelFormat
from model_registry import ModelRegistry
from tests.model_registry.constants import MODEL_NAME, MR_NAMESPACE
from tests.model_registry.utils import register_model, get_and_validate_registered_model
from utilities.constants import DscComponents


LOGGER = get_logger(name=__name__)
MODEL_NAME: str = "my-model"


@pytest.mark.parametrize(
"updated_dsc_component_state",
"updated_dsc_component_state_scope_class",
[
pytest.param(
{
"component_name": DscComponents.MODELREGISTRY,
"desired_state": DscComponents.ManagementState.MANAGED,
"component_patch": {
DscComponents.MODELREGISTRY: {
"managementState": DscComponents.ManagementState.MANAGED,
"registriesNamespace": MR_NAMESPACE,
},
}
},
)
],
indirect=True,
)
@pytest.mark.usefixtures("updated_dsc_component_state_scope_class")
class TestModelRegistryCreation:
"""
Tests the creation of a model registry. If the component is set to 'Removed' it will be switched to 'Managed'
Expand All @@ -32,51 +38,9 @@ class TestModelRegistryCreation:
@pytest.mark.smoke
def test_registering_model(
self: Self,
model_registry_instance_rest_endpoint: str,
current_client_token: str,
updated_dsc_component_state: DataScienceCluster,
model_registry_client: ModelRegistryClient,
):
# address and port need to be split in the client instantiation
server, port = model_registry_instance_rest_endpoint.split(":")
registry = ModelRegistry(
server_address=f"{Protocols.HTTPS}://{server}",
port=port,
author="opendatahub-test",
user_token=current_client_token,
is_secure=False,
)
model = registry.register_model(
name=MODEL_NAME,
uri="https://storage-place.my-company.com",
version="2.0.0",
description="lorem ipsum",
model_format_name=ModelFormat.ONNX,
model_format_version="1",
storage_key="my-data-connection",
storage_path="path/to/model",
metadata={
"int_key": 1,
"bool_key": False,
"float_key": 3.14,
"str_key": "str_value",
},
model = register_model(model_registry_client=model_registry_client)
get_and_validate_registered_model(
model_registry_client=model_registry_client, model_name=MODEL_NAME, registered_model=model
)
registered_model = registry.get_registered_model(MODEL_NAME)
errors = []
if not registered_model.id == model.id:
errors.append(f"Unexpected id, received {registered_model.id}")
if not registered_model.name == model.name:
errors.append(f"Unexpected name, received {registered_model.name}")
if not registered_model.description == model.description:
errors.append(f"Unexpected description, received {registered_model.description}")
if not registered_model.owner == model.owner:
errors.append(f"Unexpected owner, received {registered_model.owner}")
if not registered_model.state == model.state:
errors.append(f"Unexpected state, received {registered_model.state}")

assert not errors, "errors found in model registry response validation:\n{}".format("\n".join(errors))

# TODO: Edit a registered model
# TODO: Add additional versions for a model
# TODO: List all available models
# TODO: List all versions of a model
18 changes: 0 additions & 18 deletions tests/model_registry/test_rest_api.py

This file was deleted.

Empty file.
31 changes: 31 additions & 0 deletions tests/model_registry/upgrade/test_model_registry_upgrade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
from typing import Self
from tests.model_registry.constants import MODEL_NAME
from model_registry import ModelRegistry as ModelRegistryClient
from simple_logger.logger import get_logger
from tests.model_registry.utils import get_and_validate_registered_model, register_model

LOGGER = get_logger(name=__name__)


@pytest.mark.usefixtures("pre_upgrade_dsc_patch")
class TestPreUpgradeModelRegistry:
@pytest.mark.pre_upgrade
def test_registering_model_pre_upgrade(
self: Self,
model_registry_client: ModelRegistryClient,
):
model = register_model(model_registry_client=model_registry_client)
get_and_validate_registered_model(
model_registry_client=model_registry_client, model_name=MODEL_NAME, registered_model=model
)


@pytest.mark.usefixtures("post_upgrade_dsc_patch")
class TestPostUpgradeModelRegistry:
@pytest.mark.post_upgrade
def test_retrieving_model_post_upgrade(
self: Self,
model_registry_client: ModelRegistryClient,
):
model_registry_client.get_registered_model(name=MODEL_NAME)
140 changes: 132 additions & 8 deletions tests/model_registry/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
from typing import Any
from simple_logger.logger import get_logger

from kubernetes.dynamic import DynamicClient
from ocp_resources.namespace import Namespace
from timeout_sampler import TimeoutSampler, TimeoutExpiredError
from model_registry import ModelRegistry as ModelRegistryClient
from model_registry.types import RegisteredModel
from ocp_resources.pod import Pod
from ocp_resources.service import Service
from ocp_resources.model_registry import ModelRegistry
from kubernetes.dynamic.exceptions import ResourceNotFoundError
from kubernetes.dynamic.exceptions import ResourceNotFoundError, NotFoundError

from tests.model_registry.constants import MODEL_NAME, MODEL_DESCRIPTION
from utilities.exceptions import ProtocolNotSupportedError, TooManyServicesError
from utilities.constants import Protocols

from utilities.constants import Protocols, ModelFormat

ADDRESS_ANNOTATION_PREFIX: str = "routing.opendatahub.io/external-address-"

LOGGER = get_logger(name=__name__)


def get_mr_service_by_label(client: DynamicClient, ns: Namespace, mr_instance: ModelRegistry) -> Service:
def get_mr_service_by_label(client: DynamicClient, ns: str, mr_instance: ModelRegistry) -> Service:
"""
Args:
client (DynamicClient): OCP Client to use.
ns (Namespace): Namespace object where to find the Service
ns (str): Namespace name where to find the Service
mr_instance (ModelRegistry): Model Registry instance

Returns:
Expand All @@ -28,7 +36,7 @@ def get_mr_service_by_label(client: DynamicClient, ns: Namespace, mr_instance: M
svcs
for svcs in Service.get(
dyn_client=client,
namespace=ns.name,
namespace=ns,
label_selector=f"app={mr_instance.name},component=model-registry",
)
]:
Expand All @@ -38,8 +46,124 @@ def get_mr_service_by_label(client: DynamicClient, ns: Namespace, mr_instance: M
raise ResourceNotFoundError(f"{mr_instance.name} has no Service")


def get_endpoint_from_mr_service(client: DynamicClient, svc: Service, protocol: str) -> str:
def get_endpoint_from_mr_service(svc: Service, protocol: str) -> str:
if protocol in (Protocols.REST, Protocols.GRPC):
return svc.instance.metadata.annotations[f"{ADDRESS_ANNOTATION_PREFIX}{protocol}"]
else:
raise ProtocolNotSupportedError(protocol)


def get_pod_container_error_status(pod: Pod) -> str | None:
"""
Check container error status for a given pod and if any containers is in waiting state, return that information
"""
pod_instance_status = pod.instance.status
for container_status in pod_instance_status.get("containerStatuses", []):
if waiting_container := container_status.get("state", {}).get("waiting"):
return waiting_container["reason"] if waiting_container.get("reason") else waiting_container
return ""


def get_not_running_pods(pods: list[Pod]) -> list[dict[str, Any]]:
# Gets all the non-running pods from a given namespace.
# Note: We need to keep track of pods marked for deletion as not running. This would ensure any
# pod that was spun up in place of pod marked for deletion, are not ignored
pods_not_running = []
try:
for pod in pods:
pod_instance = pod.instance
if container_status_error := get_pod_container_error_status(pod=pod):
pods_not_running.append({pod.name: container_status_error})

if pod_instance.metadata.get("deletionTimestamp") or pod_instance.status.phase not in (
pod.Status.RUNNING,
pod.Status.SUCCEEDED,
):
pods_not_running.append({pod.name: pod.status})
except (ResourceNotFoundError, NotFoundError) as exc:
LOGGER.warning("Ignoring pod that disappeared during cluster sanity check: %s", exc)
return pods_not_running


def wait_for_pods_running(
admin_client: DynamicClient,
namespace_name: str,
number_of_consecutive_checks: int = 1,
) -> bool | None:
"""
Waits for all pods in a given namespace to reach Running/Completed state. To avoid catching all pods in running
state too soon, use number_of_consecutive_checks with appropriate values.
"""
samples = TimeoutSampler(
wait_timeout=180,
sleep=5,
func=get_not_running_pods,
pods=list(Pod.get(dyn_client=admin_client, namespace=namespace_name)),
exceptions_dict={NotFoundError: [], ResourceNotFoundError: []},
)
sample = None
try:
current_check = 0
for sample in samples:
if not sample:
current_check += 1
if current_check >= number_of_consecutive_checks:
return True
else:
current_check = 0
except TimeoutExpiredError:
if sample:
LOGGER.error(
f"timeout waiting for all pods in namespace {namespace_name} to reach "
f"running state, following pods are in not running state: {sample}"
)
raise
return None


def register_model(model_registry_client: ModelRegistryClient) -> RegisteredModel:
return model_registry_client.register_model(
name=MODEL_NAME,
uri="https://storage-place.my-company.com",
version="2.0.0",
description=MODEL_DESCRIPTION,
model_format_name=ModelFormat.ONNX,
model_format_version="1",
storage_key="my-data-connection",
storage_path="path/to/model",
metadata={
"int_key": 1,
"bool_key": False,
"float_key": 3.14,
"str_key": "str_value",
},
)


def get_and_validate_registered_model(
model_registry_client: ModelRegistryClient,
model_name: str,
registered_model: RegisteredModel,
) -> None:
"""
Get and validate a registered model.
"""
model = model_registry_client.get_registered_model(name=model_name)
expected_attrs = {
"id": registered_model.id,
"name": registered_model.name,
"description": registered_model.description,
"owner": registered_model.owner,
"state": registered_model.state,
}
LOGGER.info(f"Expected: {expected_attrs}")
errors = [
f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}"
for attr, expected in expected_attrs.items()
if getattr(model, attr) != expected
]
assert not errors, f"Model Registry validation failed with error: {errors}"


class ModelRegistryV1Alpha1(ModelRegistry):
api_version = f"{ModelRegistry.ApiGroup.MODELREGISTRY_OPENDATAHUB_IO}/{ModelRegistry.ApiVersion.V1ALPHA1}"
Loading