diff --git a/tests/model_registry/conftest.py b/tests/model_registry/conftest.py index 96876d4f6..5684ca517 100644 --- a/tests/model_registry/conftest.py +++ b/tests/model_registry/conftest.py @@ -1,6 +1,11 @@ import pytest -import schemathesis from typing import Generator, Any + +from _pytest.config import Config +from _pytest.fixtures import FixtureRequest + +from ocp_resources.data_science_cluster import DataScienceCluster +from ocp_resources.resource import ResourceEditor from ocp_resources.secret import Secret from ocp_resources.namespace import Namespace from ocp_resources.service import Service @@ -9,96 +14,116 @@ from ocp_resources.model_registry import ModelRegistry from simple_logger.logger import get_logger from kubernetes.dynamic import DynamicClient +from model_registry import ModelRegistry as ModelRegistryClient -from tests.model_registry.utils import get_endpoint_from_mr_service, get_mr_service_by_label -from utilities.infra import create_ns -from utilities.constants import Annotations, Protocols - +from tests.model_registry.constants import DB_RESOURCES_NAME, MR_INSTANCE_NAME, MR_OPERATOR_NAME, DEFAULT_LABEL_DICT_DB +from tests.model_registry.utils import ( + get_endpoint_from_mr_service, + get_mr_service_by_label, + ModelRegistryV1Alpha1, + wait_for_pods_running, +) +from utilities.constants import Annotations, Protocols, DscComponents +from pytest_testconfig import config as py_config LOGGER = get_logger(name=__name__) -DB_RESOURCES_NAME: str = "model-registry-db" -MR_INSTANCE_NAME: str = "model-registry" -MR_OPERATOR_NAME: str = "model-registry-operator" -MR_NAMESPACE: str = "rhoai-model-registries" -DEFAULT_LABEL_DICT_DB: dict[str, str] = { - Annotations.KubernetesIo.NAME: DB_RESOURCES_NAME, - Annotations.KubernetesIo.INSTANCE: DB_RESOURCES_NAME, - Annotations.KubernetesIo.PART_OF: DB_RESOURCES_NAME, -} + +@pytest.fixture(scope="session") +def teardown_resources(pytestconfig: pytest.Config) -> bool: + delete_resources = True + + if 
pytestconfig.option.pre_upgrade: + if delete_resources := pytestconfig.option.delete_pre_upgrade_resources: + LOGGER.warning("Upgrade resources will be deleted") + + return delete_resources @pytest.fixture(scope="class") -def model_registry_namespace(admin_client: DynamicClient) -> Generator[Namespace, Any, Any]: - # This namespace should exist after Model Registry is enabled, but it can also be deleted - # from the cluster and does not get reconciled. Fetch if it exists, create otherwise. - ns = Namespace(name=MR_NAMESPACE, client=admin_client) - if ns.exists: - yield ns - else: - LOGGER.warning(f"{MR_NAMESPACE} namespace was not present, creating it") - with create_ns( - name=MR_NAMESPACE, - admin_client=admin_client, - teardown=False, - ) as ns: - yield ns +def model_registry_namespace(updated_dsc_component_state_scope_class: DataScienceCluster) -> str: + return updated_dsc_component_state_scope_class.instance.spec.components.modelregistry.registriesNamespace @pytest.fixture(scope="class") def model_registry_db_service( - admin_client: DynamicClient, model_registry_namespace: Namespace + pytestconfig: Config, + admin_client: DynamicClient, + model_registry_namespace: str, + teardown_resources: bool, ) -> Generator[Service, Any, Any]: - with Service( - client=admin_client, - name=DB_RESOURCES_NAME, - namespace=model_registry_namespace.name, - ports=[ - { - "name": "mysql", - "nodePort": 0, - "port": 3306, - "protocol": "TCP", - "appProtocol": "tcp", - "targetPort": 3306, - } - ], - selector={ - "name": DB_RESOURCES_NAME, - }, - label=DEFAULT_LABEL_DICT_DB, - annotations={ - "template.openshift.io/expose-uri": r"mysql://{.spec.clusterIP}:{.spec.ports[?(.name==\mysql\)].port}", - }, - ) as mr_db_service: + if pytestconfig.option.post_upgrade: + mr_db_service = Service(name=DB_RESOURCES_NAME, namespace=model_registry_namespace, ensure_exists=True) yield mr_db_service + mr_db_service.delete(wait=True) + else: + with Service( + client=admin_client, + 
name=DB_RESOURCES_NAME,
+            namespace=model_registry_namespace,
+            ports=[
+                {
+                    "name": "mysql",
+                    "nodePort": 0,
+                    "port": 3306,
+                    "protocol": "TCP",
+                    "appProtocol": "tcp",
+                    "targetPort": 3306,
+                }
+            ],
+            selector={
+                "name": DB_RESOURCES_NAME,
+            },
+            label=DEFAULT_LABEL_DICT_DB,
+            annotations={
+                "template.openshift.io/expose-uri": r"mysql://{.spec.clusterIP}:{.spec.ports[?(.name==\mysql\)].port}",
+            },
+            teardown=teardown_resources,
+        ) as mr_db_service:
+            yield mr_db_service
 
 
 @pytest.fixture(scope="class")
 def model_registry_db_pvc(
+    pytestconfig: Config,
     admin_client: DynamicClient,
-    model_registry_namespace: Namespace,
+    model_registry_namespace: str,
+    teardown_resources: bool,
 ) -> Generator[PersistentVolumeClaim, Any, Any]:
-    with PersistentVolumeClaim(
-        accessmodes="ReadWriteOnce",
-        name=DB_RESOURCES_NAME,
-        namespace=model_registry_namespace.name,
-        client=admin_client,
-        size="5Gi",
-        label=DEFAULT_LABEL_DICT_DB,
-    ) as pvc:
-        yield pvc
+    if pytestconfig.option.post_upgrade:
+        mr_db_pvc = PersistentVolumeClaim(
+            name=DB_RESOURCES_NAME, namespace=model_registry_namespace, ensure_exists=True
+        )
+        yield mr_db_pvc
+        mr_db_pvc.delete(wait=True)
+    else:
+        with PersistentVolumeClaim(
+            accessmodes="ReadWriteOnce",
+            name=DB_RESOURCES_NAME,
+            namespace=model_registry_namespace,
+            client=admin_client,
+            size="5Gi",
+            label=DEFAULT_LABEL_DICT_DB,
+            teardown=teardown_resources,
+        ) as pvc:
+            yield pvc
 
 
 @pytest.fixture(scope="class")
 def model_registry_db_secret(
+    pytestconfig: Config,
     admin_client: DynamicClient,
-    model_registry_namespace: Namespace,
+    model_registry_namespace: str,
+    teardown_resources: bool,
 ) -> Generator[Secret, Any, Any]:
+    if pytestconfig.option.post_upgrade:
+        mr_db_secret = Secret(name=DB_RESOURCES_NAME, namespace=model_registry_namespace, ensure_exists=True)
+        yield mr_db_secret
+        mr_db_secret.delete(wait=True); return
     with Secret(
         client=admin_client,
         name=DB_RESOURCES_NAME,
-        namespace=model_registry_namespace.name,
+        
namespace=model_registry_namespace, string_data={ "database-name": "model_registry", "database-password": "TheBlurstOfTimes", # pragma: allowlist secret @@ -110,176 +135,199 @@ def model_registry_db_secret( "template.openshift.io/expose-password": "'{.data[''database-password'']}'", "template.openshift.io/expose-username": "'{.data[''database-user'']}'", }, + teardown=teardown_resources, ) as mr_db_secret: yield mr_db_secret @pytest.fixture(scope="class") def model_registry_db_deployment( + pytestconfig: Config, admin_client: DynamicClient, - model_registry_namespace: Namespace, + model_registry_namespace: str, model_registry_db_secret: Secret, model_registry_db_pvc: PersistentVolumeClaim, model_registry_db_service: Service, + teardown_resources: bool, ) -> Generator[Deployment, Any, Any]: - with Deployment( - name=DB_RESOURCES_NAME, - namespace=model_registry_namespace.name, - annotations={ - "template.alpha.openshift.io/wait-for-ready": "true", - }, - label=DEFAULT_LABEL_DICT_DB, - replicas=1, - revision_history_limit=0, - selector={"matchLabels": {"name": DB_RESOURCES_NAME}}, - strategy={"type": "Recreate"}, - template={ - "metadata": { - "labels": { - "name": DB_RESOURCES_NAME, - "sidecar.istio.io/inject": "false", - } + if pytestconfig.option.post_upgrade: + db_deployment = Deployment(name=DB_RESOURCES_NAME, namespace=model_registry_namespace, ensure_exists=True) + yield db_deployment + db_deployment.delete(wait=True) + else: + with Deployment( + name=DB_RESOURCES_NAME, + namespace=model_registry_namespace, + annotations={ + "template.alpha.openshift.io/wait-for-ready": "true", }, - "spec": { - "containers": [ - { - "env": [ - { - "name": "MYSQL_USER", - "valueFrom": { - "secretKeyRef": { - "key": "database-user", - "name": f"{model_registry_db_secret.name}", - } + label=DEFAULT_LABEL_DICT_DB, + replicas=1, + revision_history_limit=0, + selector={"matchLabels": {"name": DB_RESOURCES_NAME}}, + strategy={"type": "Recreate"}, + template={ + "metadata": { + 
"labels": { + "name": DB_RESOURCES_NAME, + "sidecar.istio.io/inject": "false", + } + }, + "spec": { + "containers": [ + { + "env": [ + { + "name": "MYSQL_USER", + "valueFrom": { + "secretKeyRef": { + "key": "database-user", + "name": f"{model_registry_db_secret.name}", + } + }, }, - }, - { - "name": "MYSQL_PASSWORD", - "valueFrom": { - "secretKeyRef": { - "key": "database-password", - "name": f"{model_registry_db_secret.name}", - } + { + "name": "MYSQL_PASSWORD", + "valueFrom": { + "secretKeyRef": { + "key": "database-password", + "name": f"{model_registry_db_secret.name}", + } + }, }, - }, - { - "name": "MYSQL_ROOT_PASSWORD", - "valueFrom": { - "secretKeyRef": { - "key": "database-password", - "name": f"{model_registry_db_secret.name}", - } + { + "name": "MYSQL_ROOT_PASSWORD", + "valueFrom": { + "secretKeyRef": { + "key": "database-password", + "name": f"{model_registry_db_secret.name}", + } + }, }, - }, - { - "name": "MYSQL_DATABASE", - "valueFrom": { - "secretKeyRef": { - "key": "database-name", - "name": f"{model_registry_db_secret.name}", - } + { + "name": "MYSQL_DATABASE", + "valueFrom": { + "secretKeyRef": { + "key": "database-name", + "name": f"{model_registry_db_secret.name}", + } + }, }, + ], + "args": [ + "--datadir", + "/var/lib/mysql/datadir", + "--default-authentication-plugin=mysql_native_password", + ], + "image": "public.ecr.aws/docker/library/mysql@sha256:9de9d54fecee6253130e65154b930978b1fcc336bcc86dfd06e89b72a2588ebe", # noqa: E501 + "imagePullPolicy": "IfNotPresent", + "livenessProbe": { + "exec": { + "command": [ + "/bin/bash", + "-c", + "mysqladmin -u${MYSQL_USER} -p${MYSQL_ROOT_PASSWORD} ping", + ] + }, + "initialDelaySeconds": 15, + "periodSeconds": 10, + "timeoutSeconds": 5, }, - ], - "args": [ - "--datadir", - "/var/lib/mysql/datadir", - "--default-authentication-plugin=mysql_native_password", - ], - "image": "public.ecr.aws/docker/library/mysql@sha256:9de9d54fecee6253130e65154b930978b1fcc336bcc86dfd06e89b72a2588ebe", # noqa: E501 - 
"imagePullPolicy": "IfNotPresent", - "livenessProbe": { - "exec": { - "command": [ - "/bin/bash", - "-c", - "mysqladmin -u${MYSQL_USER} -p${MYSQL_ROOT_PASSWORD} ping", - ] - }, - "initialDelaySeconds": 15, - "periodSeconds": 10, - "timeoutSeconds": 5, - }, - "name": "mysql", - "ports": [{"containerPort": 3306, "protocol": "TCP"}], - "readinessProbe": { - "exec": { - "command": [ - "/bin/bash", - "-c", - 'mysql -D ${MYSQL_DATABASE} -u${MYSQL_USER} -p${MYSQL_ROOT_PASSWORD} -e "SELECT 1"', - ] + "name": "mysql", + "ports": [{"containerPort": 3306, "protocol": "TCP"}], + "readinessProbe": { + "exec": { + "command": [ + "/bin/bash", + "-c", + 'mysql -D ${MYSQL_DATABASE} -u${MYSQL_USER} -p${MYSQL_ROOT_PASSWORD} -e "SELECT 1"', # noqa: E501 + ] + }, + "initialDelaySeconds": 10, + "timeoutSeconds": 5, }, - "initialDelaySeconds": 10, - "timeoutSeconds": 5, - }, - "securityContext": {"capabilities": {}, "privileged": False}, - "terminationMessagePath": "/dev/termination-log", - "volumeMounts": [ - { - "mountPath": "/var/lib/mysql", - "name": f"{DB_RESOURCES_NAME}-data", - } - ], - } - ], - "dnsPolicy": "ClusterFirst", - "restartPolicy": "Always", - "volumes": [ - { - "name": f"{DB_RESOURCES_NAME}-data", - "persistentVolumeClaim": {"claimName": DB_RESOURCES_NAME}, - } - ], + "securityContext": {"capabilities": {}, "privileged": False}, + "terminationMessagePath": "/dev/termination-log", + "volumeMounts": [ + { + "mountPath": "/var/lib/mysql", + "name": f"{DB_RESOURCES_NAME}-data", + } + ], + } + ], + "dnsPolicy": "ClusterFirst", + "restartPolicy": "Always", + "volumes": [ + { + "name": f"{DB_RESOURCES_NAME}-data", + "persistentVolumeClaim": {"claimName": DB_RESOURCES_NAME}, + } + ], + }, }, - }, - wait_for_resource=True, - ) as mr_db_deployment: - mr_db_deployment.wait_for_replicas(deployed=True) - yield mr_db_deployment + wait_for_resource=True, + teardown=teardown_resources, + ) as mr_db_deployment: + mr_db_deployment.wait_for_replicas(deployed=True) + yield 
mr_db_deployment @pytest.fixture(scope="class") def model_registry_instance( + pytestconfig: Config, admin_client: DynamicClient, - model_registry_namespace: Namespace, + model_registry_namespace: str, model_registry_db_deployment: Deployment, model_registry_db_secret: Secret, model_registry_db_service: Service, + teardown_resources: bool, ) -> Generator[ModelRegistry, Any, Any]: - with ModelRegistry( - name=MR_INSTANCE_NAME, - namespace=model_registry_namespace.name, - label={ - Annotations.KubernetesIo.NAME: MR_INSTANCE_NAME, - Annotations.KubernetesIo.INSTANCE: MR_INSTANCE_NAME, - Annotations.KubernetesIo.PART_OF: MR_OPERATOR_NAME, - Annotations.KubernetesIo.CREATED_BY: MR_OPERATOR_NAME, - }, - grpc={}, - rest={}, - istio={ - "authProvider": "redhat-ods-applications-auth-provider", - "gateway": {"grpc": {"tls": {}}, "rest": {"tls": {}}}, - }, - mysql={ - "host": f"{model_registry_db_deployment.name}.{model_registry_db_deployment.namespace}.svc.cluster.local", - "database": model_registry_db_secret.string_data["database-name"], - "passwordSecret": {"key": "database-password", "name": DB_RESOURCES_NAME}, - "port": 3306, - "skipDBCreation": False, - "username": model_registry_db_secret.string_data["database-user"], - }, - wait_for_resource=True, - ) as mr: - mr.wait_for_condition(condition="Available", status="True") - yield mr + host = f"{model_registry_db_deployment.name}.{model_registry_db_deployment.namespace}.svc.cluster.local" + if pytestconfig.option.post_upgrade: + mr_instance = ModelRegistryV1Alpha1( + name=MR_INSTANCE_NAME, namespace=model_registry_namespace, ensure_exists=True + ) + yield mr_instance + mr_instance.delete(wait=True) + else: + with ModelRegistryV1Alpha1( + name=MR_INSTANCE_NAME, + namespace=model_registry_namespace, + label={ + Annotations.KubernetesIo.NAME: MR_INSTANCE_NAME, + Annotations.KubernetesIo.INSTANCE: MR_INSTANCE_NAME, + Annotations.KubernetesIo.PART_OF: MR_OPERATOR_NAME, + Annotations.KubernetesIo.CREATED_BY: MR_OPERATOR_NAME, 
+ }, + grpc={}, + rest={}, + istio={ + "authProvider": "redhat-ods-applications-auth-provider", + "gateway": {"grpc": {"tls": {}}, "rest": {"tls": {}}}, + }, + mysql={ + "host": host, + "database": model_registry_db_secret.string_data["database-name"], + "passwordSecret": {"key": "database-password", "name": DB_RESOURCES_NAME}, + "port": 3306, + "skipDBCreation": False, + "username": model_registry_db_secret.string_data["database-user"], + }, + wait_for_resource=True, + teardown=teardown_resources, + ) as mr: + mr.wait_for_condition(condition="Available", status="True") + wait_for_pods_running( + admin_client=admin_client, namespace_name=model_registry_namespace, number_of_consecutive_checks=6 + ) + yield mr @pytest.fixture(scope="class") def model_registry_instance_service( admin_client: DynamicClient, - model_registry_namespace: Namespace, + model_registry_namespace: str, model_registry_instance: ModelRegistry, ) -> Service: return get_mr_service_by_label( @@ -289,17 +337,125 @@ def model_registry_instance_service( @pytest.fixture(scope="class") def model_registry_instance_rest_endpoint( - admin_client: DynamicClient, model_registry_instance_service: Service, ) -> str: - return get_endpoint_from_mr_service( - client=admin_client, svc=model_registry_instance_service, protocol=Protocols.REST - ) + return get_endpoint_from_mr_service(svc=model_registry_instance_service, protocol=Protocols.REST) + + +@pytest.fixture(scope="class") +def updated_dsc_component_state_scope_class( + pytestconfig: Config, + admin_client: DynamicClient, + request: FixtureRequest, + dsc_resource: DataScienceCluster, + teardown_resources: bool, +) -> Generator[DataScienceCluster, Any, Any]: + if not teardown_resources or pytestconfig.option.post_upgrade: + # if we are not tearing down resources or we are in post upgrade, we don't need to do anything + # the pre_upgrade/post_upgrade fixtures will handle the rest + yield dsc_resource + else: + original_components = 
dsc_resource.instance.spec.components + component_patch = request.param["component_patch"] + + with ResourceEditor(patches={dsc_resource: {"spec": {"components": component_patch}}}): + for component_name in component_patch: + dsc_resource.wait_for_condition( + condition=DscComponents.COMPONENT_MAPPING[component_name], status="True" + ) + if component_patch.get(DscComponents.MODELREGISTRY): + namespace = Namespace( + name=dsc_resource.instance.spec.components.modelregistry.registriesNamespace, ensure_exists=True + ) + namespace.wait_for_status(status=Namespace.Status.ACTIVE) + wait_for_pods_running( + admin_client=admin_client, + namespace_name=py_config["applications_namespace"], + number_of_consecutive_checks=6, + ) + yield dsc_resource + + for component_name, value in component_patch.items(): + LOGGER.info(f"Waiting for component {component_name} to be updated.") + if original_components[component_name]["managementState"] == DscComponents.ManagementState.MANAGED: + dsc_resource.wait_for_condition( + condition=DscComponents.COMPONENT_MAPPING[component_name], status="True" + ) + if ( + component_name == DscComponents.MODELREGISTRY + and value.get("managementState") == DscComponents.ManagementState.MANAGED + ): + # Since namespace specified in registriesNamespace is automatically created after setting + # managementStateto Managed. We need to explicitly delete it on clean up. 
+ namespace = Namespace(name=value["registriesNamespace"], ensure_exists=True) + if namespace: + namespace.delete(wait=True) + + +@pytest.fixture(scope="class") +def pre_upgrade_dsc_patch( + dsc_resource: DataScienceCluster, + admin_client: DynamicClient, +) -> DataScienceCluster: + original_components = dsc_resource.instance.spec.components + component_patch = {DscComponents.MODELREGISTRY: {"managementState": DscComponents.ManagementState.MANAGED}} + if ( + original_components.get(DscComponents.MODELREGISTRY).get("managementState") + == DscComponents.ManagementState.MANAGED + ): + pytest.fail("Model Registry is already set to Managed before upgrade - was this intentional?") + else: + editor = ResourceEditor(patches={dsc_resource: {"spec": {"components": component_patch}}}) + editor.update() + dsc_resource.wait_for_condition(condition=DscComponents.COMPONENT_MAPPING["modelregistry"], status="True") + namespace = Namespace( + name=dsc_resource.instance.spec.components.modelregistry.registriesNamespace, ensure_exists=True + ) + namespace.wait_for_status(status=Namespace.Status.ACTIVE) + wait_for_pods_running( + admin_client=admin_client, + namespace_name=py_config["applications_namespace"], + number_of_consecutive_checks=6, + ) + return dsc_resource + + +@pytest.fixture(scope="class") +def post_upgrade_dsc_patch( + dsc_resource: DataScienceCluster, +) -> Generator[DataScienceCluster, Any, Any]: + # yield right away so that the rest of the fixture is executed at teardown time + yield dsc_resource + + # the state we found after the upgrade + original_components = dsc_resource.instance.spec.components + # We don't have an easy way to figure out the state of the components before the upgrade at runtime + # For now we know that MR has to go back to Removed after post upgrade tests are run + component_patch = {DscComponents.MODELREGISTRY: {"managementState": DscComponents.ManagementState.REMOVED}} + if ( + 
original_components.get(DscComponents.MODELREGISTRY).get("managementState") + == DscComponents.ManagementState.REMOVED + ): + pytest.fail("Model Registry is already set to Removed after upgrade - was this intentional?") + else: + editor = ResourceEditor(patches={dsc_resource: {"spec": {"components": component_patch}}}) + editor.update() + ns = original_components.get(DscComponents.MODELREGISTRY).get("registriesNamespace") + namespace = Namespace(name=ns, ensure_exists=True) + if namespace: + namespace.delete(wait=True) @pytest.fixture(scope="class") -def generated_schema(model_registry_instance_rest_endpoint: str) -> Any: - return schemathesis.from_uri( - uri="https://raw.githubusercontent.com/kubeflow/model-registry/main/api/openapi/model-registry.yaml", - base_url=f"https://{model_registry_instance_rest_endpoint}/", +def model_registry_client( + current_client_token: str, + model_registry_instance_rest_endpoint: str, +) -> ModelRegistryClient: + server, port = model_registry_instance_rest_endpoint.split(":") + return ModelRegistryClient( + server_address=f"{Protocols.HTTPS}://{server}", + port=int(port), + author="opendatahub-test", + user_token=current_client_token, + is_secure=False, ) diff --git a/tests/model_registry/constants.py b/tests/model_registry/constants.py index 59aa8ccbb..33fda6dd9 100644 --- a/tests/model_registry/constants.py +++ b/tests/model_registry/constants.py @@ -1,2 +1,18 @@ +from utilities.constants import Annotations + + class ModelRegistryEndpoints: REGISTERED_MODELS: str = "/api/model_registry/v1alpha3/registered_models" + + +MODEL_NAME: str = "my-model" +MODEL_DESCRIPTION: str = "lorem ipsum" +DB_RESOURCES_NAME: str = "model-registry-db" +MR_INSTANCE_NAME: str = "model-registry" +MR_OPERATOR_NAME: str = "model-registry-operator" +MR_NAMESPACE: str = "rhoai-model-registries" +DEFAULT_LABEL_DICT_DB: dict[str, str] = { + Annotations.KubernetesIo.NAME: DB_RESOURCES_NAME, + Annotations.KubernetesIo.INSTANCE: DB_RESOURCES_NAME, + 
Annotations.KubernetesIo.PART_OF: DB_RESOURCES_NAME, +} diff --git a/tests/model_registry/test_model_registry_creation.py b/tests/model_registry/test_model_registry_creation.py index f5114addf..ce283f5ce 100644 --- a/tests/model_registry/test_model_registry_creation.py +++ b/tests/model_registry/test_model_registry_creation.py @@ -1,28 +1,34 @@ import pytest from typing import Self +from model_registry import ModelRegistry as ModelRegistryClient + from simple_logger.logger import get_logger -from ocp_resources.data_science_cluster import DataScienceCluster -from utilities.constants import Protocols, DscComponents, ModelFormat -from model_registry import ModelRegistry +from tests.model_registry.constants import MODEL_NAME, MR_NAMESPACE +from tests.model_registry.utils import register_model, get_and_validate_registered_model +from utilities.constants import DscComponents LOGGER = get_logger(name=__name__) -MODEL_NAME: str = "my-model" @pytest.mark.parametrize( - "updated_dsc_component_state", + "updated_dsc_component_state_scope_class", [ pytest.param( { - "component_name": DscComponents.MODELREGISTRY, - "desired_state": DscComponents.ManagementState.MANAGED, + "component_patch": { + DscComponents.MODELREGISTRY: { + "managementState": DscComponents.ManagementState.MANAGED, + "registriesNamespace": MR_NAMESPACE, + }, + } }, ) ], indirect=True, ) +@pytest.mark.usefixtures("updated_dsc_component_state_scope_class") class TestModelRegistryCreation: """ Tests the creation of a model registry. 
If the component is set to 'Removed' it will be switched to 'Managed' @@ -32,51 +38,9 @@ class TestModelRegistryCreation: @pytest.mark.smoke def test_registering_model( self: Self, - model_registry_instance_rest_endpoint: str, - current_client_token: str, - updated_dsc_component_state: DataScienceCluster, + model_registry_client: ModelRegistryClient, ): - # address and port need to be split in the client instantiation - server, port = model_registry_instance_rest_endpoint.split(":") - registry = ModelRegistry( - server_address=f"{Protocols.HTTPS}://{server}", - port=port, - author="opendatahub-test", - user_token=current_client_token, - is_secure=False, - ) - model = registry.register_model( - name=MODEL_NAME, - uri="https://storage-place.my-company.com", - version="2.0.0", - description="lorem ipsum", - model_format_name=ModelFormat.ONNX, - model_format_version="1", - storage_key="my-data-connection", - storage_path="path/to/model", - metadata={ - "int_key": 1, - "bool_key": False, - "float_key": 3.14, - "str_key": "str_value", - }, + model = register_model(model_registry_client=model_registry_client) + get_and_validate_registered_model( + model_registry_client=model_registry_client, model_name=MODEL_NAME, registered_model=model ) - registered_model = registry.get_registered_model(MODEL_NAME) - errors = [] - if not registered_model.id == model.id: - errors.append(f"Unexpected id, received {registered_model.id}") - if not registered_model.name == model.name: - errors.append(f"Unexpected name, received {registered_model.name}") - if not registered_model.description == model.description: - errors.append(f"Unexpected description, received {registered_model.description}") - if not registered_model.owner == model.owner: - errors.append(f"Unexpected owner, received {registered_model.owner}") - if not registered_model.state == model.state: - errors.append(f"Unexpected state, received {registered_model.state}") - - assert not errors, "errors found in model registry 
response validation:\n{}".format("\n".join(errors)) - - # TODO: Edit a registered model - # TODO: Add additional versions for a model - # TODO: List all available models - # TODO: List all versions of a model diff --git a/tests/model_registry/test_rest_api.py b/tests/model_registry/test_rest_api.py deleted file mode 100644 index a85f58aca..000000000 --- a/tests/model_registry/test_rest_api.py +++ /dev/null @@ -1,18 +0,0 @@ -import schemathesis -import pytest -from simple_logger.logger import get_logger - -LOGGER = get_logger(name=__name__) - -schema = schemathesis.from_pytest_fixture("generated_schema") - - -# TODO: This is a Stateless test due to how the openAPI spec is currently defined in upstream. -# Once it is updated to support Stateful testing of the API we can enable this to run every time, -# but for now having it run manually to check the existing failures is more than enough. -@pytest.mark.skip(reason="Only run manually for now") -@schema.parametrize() -def test_mr_api(case, current_client_token): - case.headers["Authorization"] = f"Bearer {current_client_token}" - case.headers["Content-Type"] = "application/json" - case.call_and_validate(verify=False) diff --git a/tests/model_registry/upgrade/__init__.py b/tests/model_registry/upgrade/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/model_registry/upgrade/test_model_registry_upgrade.py b/tests/model_registry/upgrade/test_model_registry_upgrade.py new file mode 100644 index 000000000..0b150aca0 --- /dev/null +++ b/tests/model_registry/upgrade/test_model_registry_upgrade.py @@ -0,0 +1,31 @@ +import pytest +from typing import Self +from tests.model_registry.constants import MODEL_NAME +from model_registry import ModelRegistry as ModelRegistryClient +from simple_logger.logger import get_logger +from tests.model_registry.utils import get_and_validate_registered_model, register_model + +LOGGER = get_logger(name=__name__) + + +@pytest.mark.usefixtures("pre_upgrade_dsc_patch") 
+class TestPreUpgradeModelRegistry: + @pytest.mark.pre_upgrade + def test_registering_model_pre_upgrade( + self: Self, + model_registry_client: ModelRegistryClient, + ): + model = register_model(model_registry_client=model_registry_client) + get_and_validate_registered_model( + model_registry_client=model_registry_client, model_name=MODEL_NAME, registered_model=model + ) + + +@pytest.mark.usefixtures("post_upgrade_dsc_patch") +class TestPostUpgradeModelRegistry: + @pytest.mark.post_upgrade + def test_retrieving_model_post_upgrade( + self: Self, + model_registry_client: ModelRegistryClient, + ): + model_registry_client.get_registered_model(name=MODEL_NAME) diff --git a/tests/model_registry/utils.py b/tests/model_registry/utils.py index 74e43b433..c7b50b141 100644 --- a/tests/model_registry/utils.py +++ b/tests/model_registry/utils.py @@ -1,21 +1,29 @@ +from typing import Any +from simple_logger.logger import get_logger + from kubernetes.dynamic import DynamicClient -from ocp_resources.namespace import Namespace +from timeout_sampler import TimeoutSampler, TimeoutExpiredError +from model_registry import ModelRegistry as ModelRegistryClient +from model_registry.types import RegisteredModel +from ocp_resources.pod import Pod from ocp_resources.service import Service from ocp_resources.model_registry import ModelRegistry -from kubernetes.dynamic.exceptions import ResourceNotFoundError +from kubernetes.dynamic.exceptions import ResourceNotFoundError, NotFoundError +from tests.model_registry.constants import MODEL_NAME, MODEL_DESCRIPTION from utilities.exceptions import ProtocolNotSupportedError, TooManyServicesError -from utilities.constants import Protocols - +from utilities.constants import Protocols, ModelFormat ADDRESS_ANNOTATION_PREFIX: str = "routing.opendatahub.io/external-address-" +LOGGER = get_logger(name=__name__) + -def get_mr_service_by_label(client: DynamicClient, ns: Namespace, mr_instance: ModelRegistry) -> Service: +def get_mr_service_by_label(client: 
DynamicClient, ns: str, mr_instance: ModelRegistry) -> Service: """ Args: client (DynamicClient): OCP Client to use. - ns (Namespace): Namespace object where to find the Service + ns (str): Namespace name where to find the Service mr_instance (ModelRegistry): Model Registry instance Returns: @@ -28,7 +36,7 @@ def get_mr_service_by_label(client: DynamicClient, ns: Namespace, mr_instance: M svcs for svcs in Service.get( dyn_client=client, - namespace=ns.name, + namespace=ns, label_selector=f"app={mr_instance.name},component=model-registry", ) ]: @@ -38,8 +46,124 @@ def get_mr_service_by_label(client: DynamicClient, ns: Namespace, mr_instance: M raise ResourceNotFoundError(f"{mr_instance.name} has no Service") -def get_endpoint_from_mr_service(client: DynamicClient, svc: Service, protocol: str) -> str: +def get_endpoint_from_mr_service(svc: Service, protocol: str) -> str: if protocol in (Protocols.REST, Protocols.GRPC): return svc.instance.metadata.annotations[f"{ADDRESS_ANNOTATION_PREFIX}{protocol}"] else: raise ProtocolNotSupportedError(protocol) + + +def get_pod_container_error_status(pod: Pod) -> str | None: + """ + Check container error status for a given pod and if any containers is in waiting state, return that information + """ + pod_instance_status = pod.instance.status + for container_status in pod_instance_status.get("containerStatuses", []): + if waiting_container := container_status.get("state", {}).get("waiting"): + return waiting_container["reason"] if waiting_container.get("reason") else waiting_container + return "" + + +def get_not_running_pods(pods: list[Pod]) -> list[dict[str, Any]]: + # Gets all the non-running pods from a given namespace. + # Note: We need to keep track of pods marked for deletion as not running. 
This would ensure any + # pod that was spun up in place of pod marked for deletion, are not ignored + pods_not_running = [] + try: + for pod in pods: + pod_instance = pod.instance + if container_status_error := get_pod_container_error_status(pod=pod): + pods_not_running.append({pod.name: container_status_error}) + + if pod_instance.metadata.get("deletionTimestamp") or pod_instance.status.phase not in ( + pod.Status.RUNNING, + pod.Status.SUCCEEDED, + ): + pods_not_running.append({pod.name: pod.status}) + except (ResourceNotFoundError, NotFoundError) as exc: + LOGGER.warning("Ignoring pod that disappeared during cluster sanity check: %s", exc) + return pods_not_running + + +def wait_for_pods_running( + admin_client: DynamicClient, + namespace_name: str, + number_of_consecutive_checks: int = 1, +) -> bool | None: + """ + Waits for all pods in a given namespace to reach Running/Completed state. To avoid catching all pods in running + state too soon, use number_of_consecutive_checks with appropriate values. 
+    """
+    samples = TimeoutSampler(
+        wait_timeout=180,
+        sleep=5,
+        # re-list pods on every sampling iteration so pods created mid-wait are also checked
+        func=lambda: get_not_running_pods(pods=list(Pod.get(dyn_client=admin_client, namespace=namespace_name))),
+        exceptions_dict={NotFoundError: [], ResourceNotFoundError: []},
+    )
+    sample = None
+    try:
+        current_check = 0
+        for sample in samples:
+            if not sample:
+                current_check += 1
+                if current_check >= number_of_consecutive_checks:
+                    return True
+            else:
+                current_check = 0
+    except TimeoutExpiredError:
+        if sample:
+            LOGGER.error(
+                f"timeout waiting for all pods in namespace {namespace_name} to reach "
+                f"running state, following pods are in not running state: {sample}"
+            )
+        raise
+    return None
+
+
+def register_model(model_registry_client: ModelRegistryClient) -> RegisteredModel:
+    return model_registry_client.register_model(
+        name=MODEL_NAME,
+        uri="https://storage-place.my-company.com",
+        version="2.0.0",
+        description=MODEL_DESCRIPTION,
+        model_format_name=ModelFormat.ONNX,
+        model_format_version="1",
+        storage_key="my-data-connection",
+        storage_path="path/to/model",
+        metadata={
+            "int_key": 1,
+            "bool_key": False,
+            "float_key": 3.14,
+            "str_key": "str_value",
+        },
+    )
+
+
+def get_and_validate_registered_model(
+    model_registry_client: ModelRegistryClient,
+    model_name: str,
+    registered_model: RegisteredModel,
+) -> None:
+    """
+    Get and validate a registered model.
+ """ + model = model_registry_client.get_registered_model(name=model_name) + expected_attrs = { + "id": registered_model.id, + "name": registered_model.name, + "description": registered_model.description, + "owner": registered_model.owner, + "state": registered_model.state, + } + LOGGER.info(f"Expected: {expected_attrs}") + errors = [ + f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}" + for attr, expected in expected_attrs.items() + if getattr(model, attr) != expected + ] + assert not errors, f"Model Registry validation failed with error: {errors}" + + +class ModelRegistryV1Alpha1(ModelRegistry): + api_version = f"{ModelRegistry.ApiGroup.MODELREGISTRY_OPENDATAHUB_IO}/{ModelRegistry.ApiVersion.V1ALPHA1}"