diff --git a/tests/model_registry/negative_tests/test_db_migration.py b/tests/model_registry/negative_tests/test_db_migration.py index d43721e51..d6abae820 100644 --- a/tests/model_registry/negative_tests/test_db_migration.py +++ b/tests/model_registry/negative_tests/test_db_migration.py @@ -7,8 +7,8 @@ from utilities.constants import DscComponents from tests.model_registry.constants import MR_INSTANCE_NAME from kubernetes.dynamic.client import DynamicClient -from utilities.general import wait_for_pods_by_labels, wait_for_container_status - +from utilities.general import wait_for_container_status +from tests.model_registry.utils import wait_for_new_running_mr_pod LOGGER = get_logger(name=__name__) @@ -36,6 +36,7 @@ def test_db_migration_negative( admin_client: DynamicClient, model_registry_db_instance_pod: Pod, set_mr_db_dirty: int, + model_registry_pod: Pod, delete_mr_deployment: None, ): """ @@ -43,23 +44,23 @@ def test_db_migration_negative( The test will: 1. Set the dirty flag to 1 for the latest migration version 2. Delete the model registry deployment - 3. Check the logs for the expected error + 3. Wait for the old pods to be terminated + 4. Check the logs for the expected error """ - mr_pods = wait_for_pods_by_labels( + LOGGER.info(f"Model registry pod: {model_registry_pod.name}") + mr_pod = wait_for_new_running_mr_pod( admin_client=admin_client, + orig_pod_name=model_registry_pod.name, namespace=py_config["model_registry_namespace"], - label_selector=f"app={MR_INSTANCE_NAME}", - expected_num_pods=1, + instance_name=MR_INSTANCE_NAME, ) - mr_pod = mr_pods[0] - LOGGER.info("Waiting for model registry pod to crash") + LOGGER.info(f"Pod that should contains the container in CrashLoopBackOff state: {mr_pod.name}") assert wait_for_container_status(mr_pod, "rest-container", Pod.Status.CRASH_LOOPBACK_OFF) LOGGER.info("Checking the logs for the expected error") - log_output = mr_pod.log(container="rest-container") expected_error = ( f"Error: {{{{ALERT}}}} error connecting to datastore: Dirty database version {set_mr_db_dirty}. " "Fix and force version." ) - assert expected_error in log_output, "Expected error message not found in logs!" + assert expected_error in log_output, f"Expected error message not found in logs!\n{log_output}" diff --git a/tests/model_registry/utils.py b/tests/model_registry/utils.py index 1fe5eb0be..ba7f833d9 100644 --- a/tests/model_registry/utils.py +++ b/tests/model_registry/utils.py @@ -8,12 +8,12 @@ from ocp_resources.model_registry_modelregistry_opendatahub_io import ModelRegistry from kubernetes.dynamic.exceptions import ResourceNotFoundError from simple_logger.logger import get_logger -from timeout_sampler import TimeoutExpiredError, TimeoutSampler +from timeout_sampler import TimeoutExpiredError, TimeoutSampler, retry from kubernetes.dynamic.exceptions import NotFoundError from tests.model_registry.constants import MR_DB_IMAGE_DIGEST from tests.model_registry.exceptions import ModelRegistryResourceNotFoundError from utilities.exceptions import ProtocolNotSupportedError, TooManyServicesError -from utilities.constants import Protocols, Annotations +from utilities.constants import Protocols, Annotations, Timeout from model_registry import ModelRegistry as ModelRegistryClient from model_registry.types import RegisteredModel @@ -235,6 +235,42 @@ def wait_for_pods_running( return None +@retry(exceptions_dict={TimeoutError: []}, wait_timeout=Timeout.TIMEOUT_2MIN, sleep=5) +def wait_for_new_running_mr_pod( + admin_client: DynamicClient, + orig_pod_name: str, + namespace: str, + instance_name: str, +) -> Pod: + """ + Wait for the model registry pod to be replaced. + + Args: + admin_client (DynamicClient): The admin client. + orig_pod_name (str): The name of the original pod. + namespace (str): The namespace of the pod. + instance_name (str): The name of the instance. + Returns: + Pod object. + + Raises: + TimeoutError: If the pods are not replaced. + + """ + LOGGER.info("Waiting for pod to be replaced") + pods = list( + Pod.get( + dyn_client=admin_client, + namespace=namespace, + label_selector=f"app={instance_name}", + ) + ) + if pods and len(pods) == 1: + if pods[0].name != orig_pod_name and pods[0].status == Pod.Status.RUNNING: + return pods[0] + raise TimeoutError(f"Timeout waiting for pod {orig_pod_name} to be replaced") + + def generate_namespace_name(file_path: str) -> str: return (file_path.removesuffix(".py").replace("/", "-").replace("_", "-"))[-63:].split("-", 1)[-1]