Skip to content

Commit b0cc478

Browse files
dbasunagadolfo-ab
authored andcommitted
Add tests to archive and unarchive models, versions (opendatahub-io#329)
* on rebase clean commented-by- labels * Add tests to archive and unarchive models, versions updates! 1b232c3 updates! 8195a8f
1 parent a9535b4 commit b0cc478

5 files changed

Lines changed: 213 additions & 33 deletions

File tree

tests/model_registry/rest_api/conftest.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tests.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI
66
from tests.model_registry.rest_api.utils import register_model_rest_api, execute_model_registry_patch_command
77
from utilities.constants import Protocols
8+
from utilities.exceptions import MissingParameter
89

910

1011
@pytest.fixture(scope="class")
@@ -33,17 +34,32 @@ def registered_model_rest_api(
3334
)
3435

3536

36-
@pytest.fixture(scope="class")
37-
def updated_model_artifact(
37+
@pytest.fixture()
38+
def updated_model_registry_resource(
3839
request: pytest.FixtureRequest,
3940
model_registry_rest_url: str,
4041
model_registry_rest_headers: dict[str, str],
4142
registered_model_rest_api: dict[str, Any],
4243
) -> dict[str, Any]:
43-
model_artifact_id = registered_model_rest_api["model_artifact"]["id"]
44-
assert model_artifact_id, f"Model artifact id not found: {registered_model_rest_api['model_artifact']}"
44+
"""
45+
Generic fixture to update any model registry resource via PATCH request.
46+
47+
Expects request.param to contain:
48+
- resource_name: Key to identify the resource in registered_model_rest_api
49+
- api_name: API endpoint name for the resource type
50+
- data: JSON data to send in the PATCH request
51+
52+
Returns:
53+
Dictionary containing the updated resource data
54+
"""
55+
resource_name = request.param.get("resource_name")
56+
api_name = request.param.get("api_name")
57+
if not (api_name and resource_name):
58+
raise MissingParameter("resource_name and api_name are required parameters for this fixture.")
59+
resource_id = registered_model_rest_api[resource_name]["id"]
60+
assert resource_id, f"Resource id not found: {registered_model_rest_api[resource_name]}"
4561
return execute_model_registry_patch_command(
46-
url=f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}model_artifacts/{model_artifact_id}",
62+
url=f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}{api_name}/{resource_id}",
4763
headers=model_registry_rest_headers,
48-
data_json=request.param,
64+
data_json=request.param["data"],
4965
)

tests/model_registry/rest_api/constants.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,24 @@
22

33
from utilities.constants import ModelFormat
44

5-
MODEL_REGISTER: dict[str, str] = {
5+
MODEL_REGISTER: dict[str, Any] = {
66
"name": "model-rest-api",
77
"description": "Model created via rest call",
88
"owner": "opendatahub-tests",
9+
"customProperties": {
10+
"test_rm_bool_property": {"bool_value": False, "metadataType": "MetadataBoolValue"},
11+
"test_rm_str_property": {"string_value": "my_value", "metadataType": "MetadataStringValue"},
12+
},
913
}
1014
MODEL_VERSION: dict[str, Any] = {
1115
"name": "v0.0.1",
1216
"state": "LIVE",
1317
"author": "opendatahub-tests",
1418
"description": "Model version created via rest call",
19+
"customProperties": {
20+
"test_mv_bool_property": {"bool_value": True, "metadataType": "MetadataBoolValue"},
21+
"test_mv_str_property": {"string_value": "my_value", "metadataType": "MetadataStringValue"},
22+
},
1523
}
1624

1725
MODEL_ARTIFACT: dict[str, Any] = {
@@ -22,6 +30,10 @@
2230
"modelFormatName": ModelFormat.ONNX,
2331
"modelFormatVersion": "v1",
2432
"artifactType": "model-artifact",
33+
"customProperties": {
34+
"test_ma_bool_property": {"bool_value": True, "metadataType": "MetadataBoolValue"},
35+
"test_ma_str_property": {"string_value": "my_value", "metadataType": "MetadataStringValue"},
36+
},
2537
}
2638
MODEL_REGISTER_DATA = {
2739
"register_model_data": MODEL_REGISTER,

tests/model_registry/rest_api/test_model_registry_rest_api.py

Lines changed: 140 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,25 @@
33
from pytest_testconfig import config as py_config
44

55
from tests.model_registry.rest_api.constants import MODEL_REGISTER, MODEL_ARTIFACT, MODEL_VERSION, MODEL_REGISTER_DATA
6+
from tests.model_registry.rest_api.utils import validate_resource_attributes
67
from utilities.constants import DscComponents
78
from simple_logger.logger import get_logger
89

910
LOGGER = get_logger(name=__name__)
11+
CUSTOM_PROPERTY = {
12+
"customProperties": {
13+
"my_bool_property": {"bool_value": True, "metadataType": "MetadataBoolValue"},
14+
"my_str_property": {"string_value": "my_value", "metadataType": "MetadataStringValue"},
15+
"my_double_property": {"double_value": 500.01, "metadataType": "MetadataDoubleValue"},
16+
}
17+
}
18+
MODEL_VERSION_DESCRIPTION = {"description": "updated model version description"}
19+
STATE_ARCHIVED = {"state": "ARCHIVED"}
20+
STATE_LIVE = {"state": "LIVE"}
21+
REGISTERED_MODEL_DESCRIPTION = {"description": "updated registered model description"}
22+
MODEL_FORMAT_VERSION = {"modelFormatVersion": "v2"}
23+
MODEL_FORMAT_NAME = {"modelFormatName": "tensorflow"}
24+
MODEL_ARTIFACT_DESCRIPTION = {"description": "updated artifact description"}
1025

1126

1227
@pytest.mark.parametrize(
@@ -59,47 +74,146 @@ def test_validate_model_registry_resource(
5974
expected_params: dict[str, str],
6075
data_key: str,
6176
):
62-
errors: list[str:Any]
63-
created_resource_data = registered_model_rest_api[data_key]
64-
if errors := [
65-
{key: [expected_params[key], created_resource_data.get(key)]}
66-
for key in expected_params.keys()
67-
if not created_resource_data.get(key) or created_resource_data[key] != expected_params[key]
68-
]:
69-
pytest.fail(f"Model did not get created with expected values: {errors}")
70-
LOGGER.info(f"Successfully validated: {created_resource_data['name']}")
77+
validate_resource_attributes(
78+
expected_params=expected_params,
79+
actual_resource_data=registered_model_rest_api[data_key],
80+
resource_name=data_key,
81+
)
7182

7283
@pytest.mark.parametrize(
73-
"updated_model_artifact, expected_param",
84+
"updated_model_registry_resource, expected_param",
7485
[
7586
pytest.param(
76-
{"description": "updated description"},
77-
{"description": "updated description"},
87+
{
88+
"resource_name": "model_artifact",
89+
"api_name": "model_artifacts",
90+
"data": MODEL_ARTIFACT_DESCRIPTION,
91+
},
92+
MODEL_ARTIFACT_DESCRIPTION,
7893
id="test_validate_updated_artifact_description",
7994
),
8095
pytest.param(
81-
{"modelFormatName": "tensorflow"},
82-
{"modelFormatName": "tensorflow"},
96+
{
97+
"resource_name": "model_artifact",
98+
"api_name": "model_artifacts",
99+
"data": MODEL_FORMAT_NAME,
100+
},
101+
MODEL_FORMAT_NAME,
83102
id="test_validate_updated_artifact_model_format_name",
84103
),
85104
pytest.param(
86-
{"modelFormatVersion": "v2"},
87-
{"modelFormatVersion": "v2"},
105+
{
106+
"resource_name": "model_artifact",
107+
"api_name": "model_artifacts",
108+
"data": MODEL_FORMAT_VERSION,
109+
},
110+
MODEL_FORMAT_VERSION,
88111
id="test_validate_updated_artifact_model_format_version",
89112
),
90113
],
91-
indirect=["updated_model_artifact"],
114+
indirect=["updated_model_registry_resource"],
92115
)
93116
def test_create_update_model_artifact(
94117
self,
95-
updated_model_artifact: dict[str, Any],
118+
updated_model_registry_resource: dict[str, Any],
119+
expected_param: dict[str, Any],
120+
):
121+
"""
122+
Update model artifacts and ensure the updated values are reflected on the artifact
123+
"""
124+
125+
validate_resource_attributes(
126+
expected_params=expected_param,
127+
actual_resource_data=updated_model_registry_resource,
128+
resource_name="model artifact",
129+
)
130+
131+
@pytest.mark.parametrize(
132+
"updated_model_registry_resource, expected_param",
133+
[
134+
pytest.param(
135+
{
136+
"resource_name": "model_version",
137+
"api_name": "model_versions",
138+
"data": MODEL_VERSION_DESCRIPTION,
139+
},
140+
MODEL_VERSION_DESCRIPTION,
141+
id="test_validate_updated_version_description",
142+
),
143+
pytest.param(
144+
{"resource_name": "model_version", "api_name": "model_versions", "data": STATE_ARCHIVED},
145+
STATE_ARCHIVED,
146+
id="test_validate_updated_version_state_archived",
147+
),
148+
pytest.param(
149+
{"resource_name": "model_version", "api_name": "model_versions", "data": STATE_LIVE},
150+
STATE_LIVE,
151+
id="test_validate_updated_version_state_unarchived",
152+
),
153+
pytest.param(
154+
{"resource_name": "model_version", "api_name": "model_versions", "data": CUSTOM_PROPERTY},
155+
CUSTOM_PROPERTY,
156+
id="test_validate_updated_version_custom_properties",
157+
),
158+
],
159+
indirect=["updated_model_registry_resource"],
160+
)
161+
def test_updated_model_version(
162+
self,
163+
updated_model_registry_resource: dict[str, Any],
164+
expected_param: dict[str, Any],
165+
):
166+
"""
167+
Update, [RHOAIENG-24371] archive, unarchive model versions and ensure the updated values
168+
are reflected on the model version
169+
"""
170+
validate_resource_attributes(
171+
expected_params=expected_param,
172+
actual_resource_data=updated_model_registry_resource,
173+
resource_name="model version",
174+
)
175+
176+
@pytest.mark.parametrize(
177+
"updated_model_registry_resource, expected_param",
178+
[
179+
pytest.param(
180+
{
181+
"resource_name": "register_model",
182+
"api_name": "registered_models",
183+
"data": REGISTERED_MODEL_DESCRIPTION,
184+
},
185+
REGISTERED_MODEL_DESCRIPTION,
186+
id="test_validate_updated_model_description",
187+
),
188+
pytest.param(
189+
{"resource_name": "register_model", "api_name": "registered_models", "data": STATE_ARCHIVED},
190+
STATE_ARCHIVED,
191+
id="test_validate_updated_model_state_archived",
192+
),
193+
pytest.param(
194+
{"resource_name": "register_model", "api_name": "registered_models", "data": STATE_LIVE},
195+
STATE_LIVE,
196+
id="test_validate_updated_model_state_unarchived",
197+
),
198+
pytest.param(
199+
{"resource_name": "register_model", "api_name": "registered_models", "data": CUSTOM_PROPERTY},
200+
CUSTOM_PROPERTY,
201+
id="test_validate_updated_registered_model_custom_properties",
202+
),
203+
],
204+
indirect=["updated_model_registry_resource"],
205+
)
206+
def test_updated_registered_model(
207+
self,
208+
updated_model_registry_resource: dict[str, Any],
96209
expected_param: dict[str, Any],
97210
):
98-
errors: list[dict[str, list[Any]]]
99-
if errors := [
100-
{key: [expected_param[key], updated_model_artifact.get(key)]}
101-
for key in expected_param.keys()
102-
if not updated_model_artifact.get(key) or updated_model_artifact[key] != expected_param[key]
103-
]:
104-
pytest.fail(f"Model did not get updated with expected values: {errors}")
105-
LOGGER.info(f"Successfully validated: {updated_model_artifact['name']}")
211+
"""
212+
Update, [RHOAIENG-24371] archive, unarchive registered models and ensure the updated values
213+
are reflected on the registered model
214+
"""
215+
validate_resource_attributes(
216+
expected_params=expected_param,
217+
actual_resource_data=updated_model_registry_resource,
218+
resource_name="registered model",
219+
)

tests/model_registry/rest_api/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from typing import Any
22
import requests
33
import json
4+
45
from simple_logger.logger import get_logger
56
from tests.model_registry.exceptions import (
67
ModelRegistryResourceNotCreated,
78
ModelRegistryResourceNotFoundError,
89
ModelRegistryResourceNotUpdated,
910
)
1011
from tests.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI
12+
from utilities.exceptions import ResourceValueMismatch
1113

1214
LOGGER = get_logger(name=__name__)
1315

@@ -87,3 +89,27 @@ def register_model_rest_api(
8789
f"associated artifact: {model_artifact}"
8890
)
8991
return {"register_model": register_model, "model_version": model_version, "model_artifact": model_artifact}
92+
93+
94+
def validate_resource_attributes(
95+
expected_params: dict[str, Any], actual_resource_data: dict[str, Any], resource_name: str
96+
) -> None:
97+
"""
98+
Validate that expected parameters match actual resource data.
99+
Args:
100+
expected_params: Dictionary of expected attribute values
101+
actual_resource_data: Dictionary of actual resource data from API
102+
resource_name: Name of the resource being validated for error messages
103+
104+
Raises:
105+
ResourceValueMismatch: When expected and actual values don't match
106+
107+
"""
108+
errors: list[dict[str, list[Any]]]
109+
if errors := [
110+
{key: [f"Expected value: {expected_params[key]}, actual value: {actual_resource_data.get(key)}"]}
111+
for key in expected_params.keys()
112+
if (not actual_resource_data.get(key) or actual_resource_data[key] != expected_params[key])
113+
]:
114+
raise ResourceValueMismatch(f"Resource: {resource_name} has mismatched data: {errors}")
115+
LOGGER.info(f"Successfully validated resource: {resource_name}: {actual_resource_data['name']}")

utilities/exceptions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,15 @@ class UnexpectedFailureError(Exception):
120120

121121
class UnexpectedResourceCountError(Exception):
122122
"""Unexpected number of API resources found"""
123+
124+
125+
class ResourceValueMismatch(Exception):
126+
"""Resource value mismatch"""
127+
128+
pass
129+
130+
131+
class MissingParameter(Exception):
132+
"""Raised required argument is not passed."""
133+
134+
pass

0 commit comments

Comments
 (0)