Skip to content

Commit 502551f

Browse files
committed
fix: address review comments and refactor
Signed-off-by: lugi0 <lgiorgi@redhat.com>
1 parent 94a60ba commit 502551f

File tree

3 files changed

+110
-57
lines changed

3 files changed

+110
-57
lines changed

tests/model_registry/conftest.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -338,39 +338,42 @@ def updated_dsc_component_state_scope_class(
338338
# if we are not tearing down resources or we are in post upgrade, we don't need to do anything
339339
# the pre_upgrade/post_upgrade fixtures will handle the rest
340340
yield dsc_resource
341-
return
342-
343-
original_components = dsc_resource.instance.spec.components
344-
component_patch = request.param["component_patch"]
345-
346-
with ResourceEditor(patches={dsc_resource: {"spec": {"components": component_patch}}}):
347-
for component_name in component_patch:
348-
dsc_resource.wait_for_condition(condition=DscComponents.COMPONENT_MAPPING[component_name], status="True")
349-
if component_patch.get(DscComponents.MODELREGISTRY):
350-
namespace = Namespace(
351-
name=dsc_resource.instance.spec.components.modelregistry.registriesNamespace, ensure_exists=True
341+
else:
342+
original_components = dsc_resource.instance.spec.components
343+
component_patch = request.param["component_patch"]
344+
345+
with ResourceEditor(patches={dsc_resource: {"spec": {"components": component_patch}}}):
346+
for component_name in component_patch:
347+
dsc_resource.wait_for_condition(
348+
condition=DscComponents.COMPONENT_MAPPING[component_name], status="True"
349+
)
350+
if component_patch.get(DscComponents.MODELREGISTRY):
351+
namespace = Namespace(
352+
name=dsc_resource.instance.spec.components.modelregistry.registriesNamespace, ensure_exists=True
353+
)
354+
namespace.wait_for_status(status=Namespace.Status.ACTIVE)
355+
wait_for_pods_running(
356+
admin_client=admin_client,
357+
namespace_name=py_config["applications_namespace"],
358+
number_of_consecutive_checks=6,
352359
)
353-
namespace.wait_for_status(status=Namespace.Status.ACTIVE)
354-
wait_for_pods_running(
355-
admin_client=admin_client,
356-
namespace_name=py_config["applications_namespace"],
357-
number_of_consecutive_checks=6,
358-
)
359-
yield dsc_resource
360-
361-
for component_name, value in component_patch.items():
362-
LOGGER.info(f"Waiting for component {component_name} to be updated.")
363-
if original_components[component_name]["managementState"] == DscComponents.ManagementState.MANAGED:
364-
dsc_resource.wait_for_condition(condition=DscComponents.COMPONENT_MAPPING[component_name], status="True")
365-
if (
366-
component_name == DscComponents.MODELREGISTRY
367-
and value.get("managementState") == DscComponents.ManagementState.MANAGED
368-
):
369-
# Since namespace specified in registriesNamespace is automatically created after setting
370-
# managementStateto Managed. We need to explicitly delete it on clean up.
371-
namespace = Namespace(name=value["registriesNamespace"], ensure_exists=True)
372-
if namespace:
373-
namespace.delete(wait=True)
360+
yield dsc_resource
361+
362+
for component_name, value in component_patch.items():
363+
LOGGER.info(f"Waiting for component {component_name} to be updated.")
364+
if original_components[component_name]["managementState"] == DscComponents.ManagementState.MANAGED:
365+
dsc_resource.wait_for_condition(
366+
condition=DscComponents.COMPONENT_MAPPING[component_name], status="True"
367+
)
368+
if (
369+
component_name == DscComponents.MODELREGISTRY
370+
and value.get("managementState") == DscComponents.ManagementState.MANAGED
371+
):
372+
# Since namespace specified in registriesNamespace is automatically created after setting
373+
# managementStateto Managed. We need to explicitly delete it on clean up.
374+
namespace = Namespace(name=value["registriesNamespace"], ensure_exists=True)
375+
if namespace:
376+
namespace.delete(wait=True)
374377

375378

376379
@pytest.fixture(scope="class")

tests/model_registry/upgrade/test_model_registry_upgrade.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ocp_resources.model_registry_modelregistry_opendatahub_io import ModelRegistry
77
from simple_logger.logger import get_logger
88
from tests.model_registry.rest_api.utils import ModelRegistryV1Alpha1
9+
from tests.model_registry.utils import get_and_validate_registered_model
910

1011
LOGGER = get_logger(name=__name__)
1112

@@ -27,19 +28,22 @@ def test_registering_model_pre_upgrade(
2728
model_registry_client: ModelRegistryClient,
2829
registered_model: RegisteredModel,
2930
):
30-
model = model_registry_client.get_registered_model(name=MODEL_NAME)
31-
expected_attrs = {
32-
"id": registered_model.id,
33-
"name": registered_model.name,
34-
"description": registered_model.description,
35-
"owner": registered_model.owner,
36-
"state": registered_model.state,
37-
}
38-
errors = [
39-
f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}"
40-
for attr, expected in expected_attrs.items()
41-
if getattr(model, attr) != expected
42-
]
31+
errors = get_and_validate_registered_model(
32+
model_registry_client=model_registry_client, model_name=MODEL_NAME, registered_model=registered_model
33+
)
34+
# model = model_registry_client.get_registered_model(name=MODEL_NAME)
35+
# expected_attrs = {
36+
# "id": registered_model.id,
37+
# "name": registered_model.name,
38+
# "description": registered_model.description,
39+
# "owner": registered_model.owner,
40+
# "state": registered_model.state,
41+
# }
42+
# errors = [
43+
# f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}"
44+
# for attr, expected in expected_attrs.items()
45+
# if getattr(model, attr) != expected
46+
# ]
4347
if errors:
4448
pytest.fail("errors found in model registry response validation:\n{}".format("\n".join(errors)))
4549

@@ -55,29 +59,44 @@ def test_retrieving_model_post_upgrade(
5559
model_registry_client: ModelRegistryClient,
5660
model_registry_instance: ModelRegistry,
5761
):
58-
model = model_registry_client.get_registered_model(name=MODEL_NAME)
59-
expected_attrs = {
60-
"name": MODEL_DICT["model_name"],
61-
}
62-
errors = [
63-
f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}"
64-
for attr, expected in expected_attrs.items()
65-
if getattr(model, attr) != expected
66-
]
62+
errors = get_and_validate_registered_model(
63+
model_registry_client=model_registry_client,
64+
model_name=MODEL_NAME,
65+
)
66+
# model = model_registry_client.get_registered_model(name=MODEL_NAME)
67+
# expected_attrs = {
68+
# "name": MODEL_DICT["model_name"],
69+
# }
70+
# errors = [
71+
# f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}"
72+
# for attr, expected in expected_attrs.items()
73+
# if getattr(model, attr) != expected
74+
# ]
6775
if errors:
68-
LOGGER.error(f"received model: {model}")
6976
pytest.fail("errors found in model registry response validation:\n{}".format("\n".join(errors)))
7077

78+
def test_model_registry_instance_api_version_post_upgrade(
79+
self: Self,
80+
model_registry_instance: ModelRegistry,
81+
):
7182
# the following is valid for 2.22+
7283
api_version = model_registry_instance.instance.apiVersion
7384
expected_version = f"{ModelRegistry.ApiGroup.MODELREGISTRY_OPENDATAHUB_IO}/{ModelRegistry.ApiVersion.V1BETA1}"
7485
assert api_version == expected_version
7586

87+
def test_model_registry_instance_spec_post_upgrade(
88+
self: Self,
89+
model_registry_instance: ModelRegistry,
90+
):
7691
model_registry_instance_spec = model_registry_instance.instance.spec
7792
assert not model_registry_instance_spec.istio
7893
assert model_registry_instance_spec.oauthProxy.serviceRoute == "enabled"
7994

80-
# After v1alpha1 is removed (2.24?) this has to be removed
95+
def test_model_registry_instance_status_conversion_post_upgrade(
96+
self: Self,
97+
model_registry_instance: ModelRegistry,
98+
):
99+
# TODO: After v1alpha1 is removed (2.24?) this has to be removed
81100
mr_instance = ModelRegistryV1Alpha1(
82101
name=model_registry_instance.name, namespace=model_registry_instance.namespace, ensure_exists=True
83102
).instance

tests/model_registry/utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import uuid
2-
from typing import Any
2+
from typing import Any, List
33

44
from kubernetes.dynamic import DynamicClient
55
from ocp_resources.pod import Pod
@@ -12,6 +12,8 @@
1212
from tests.model_registry.constants import MR_DB_IMAGE_DIGEST
1313
from utilities.exceptions import ProtocolNotSupportedError, TooManyServicesError
1414
from utilities.constants import Protocols, Annotations
15+
from model_registry import ModelRegistry as ModelRegistryClient
16+
from model_registry.types import RegisteredModel
1517

1618
ADDRESS_ANNOTATION_PREFIX: str = "routing.opendatahub.io/external-address-"
1719

@@ -330,3 +332,32 @@ def apply_mysql_args_and_volume_mounts(
330332
my_sql_container["args"] = mysql_args
331333
my_sql_container["volumeMounts"] = volumes_mounts
332334
return my_sql_container
335+
336+
337+
def get_and_validate_registered_model(
338+
model_registry_client: ModelRegistryClient,
339+
model_name: str,
340+
registered_model: RegisteredModel = None,
341+
) -> List[str]:
342+
"""
343+
Get and validate a registered model.
344+
"""
345+
model = model_registry_client.get_registered_model(name=model_name)
346+
if registered_model is not None:
347+
expected_attrs = {
348+
"id": registered_model.id,
349+
"name": registered_model.name,
350+
"description": registered_model.description,
351+
"owner": registered_model.owner,
352+
"state": registered_model.state,
353+
}
354+
else:
355+
expected_attrs = {
356+
"name": model_name,
357+
}
358+
errors = [
359+
f"Unexpected {attr} expected: {expected}, received {getattr(model, attr)}"
360+
for attr, expected in expected_attrs.items()
361+
if getattr(model, attr) != expected
362+
]
363+
return errors

0 commit comments

Comments
 (0)