Skip to content

Commit ff86bc8

Browse files
Add tests for model_artifact update validations (opendatahub-io#284)
* Add tests for model_artifact update validations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8dd9a42 commit ff86bc8

File tree

6 files changed

+283
-0
lines changed

6 files changed

+283
-0
lines changed

tests/model_registry/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class ModelRegistryResourceNotCreated(Exception):
2+
pass
3+
4+
5+
class ModelRegistryResourceNotFoundError(Exception):
6+
pass
7+
8+
9+
class ModelRegistryResourceNotUpdated(Exception):
10+
pass

tests/model_registry/rest_api/__init__.py

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
from tests.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI
6+
from tests.model_registry.rest_api.utils import register_model_rest_api, execute_model_registry_patch_command
7+
from utilities.constants import Protocols
8+
9+
10+
@pytest.fixture(scope="class")
11+
def model_registry_rest_url(model_registry_instance_rest_endpoint: str) -> str:
12+
# address and port need to be split in the client instantiation
13+
return f"{Protocols.HTTPS}://{model_registry_instance_rest_endpoint}"
14+
15+
16+
@pytest.fixture(scope="class")
17+
def model_registry_rest_headers(current_client_token: str) -> dict[str, str]:
18+
return {
19+
"Authorization": f"Bearer {current_client_token}",
20+
"accept": "application/json",
21+
"Content-Type": "application/json",
22+
}
23+
24+
25+
@pytest.fixture(scope="class")
26+
def registered_model_rest_api(
27+
request: pytest.FixtureRequest, model_registry_rest_url: str, model_registry_rest_headers: dict[str, str]
28+
) -> dict[str, Any]:
29+
return register_model_rest_api(
30+
model_registry_rest_url=model_registry_rest_url,
31+
model_registry_rest_headers=model_registry_rest_headers,
32+
data_dict=request.param,
33+
)
34+
35+
36+
@pytest.fixture(scope="class")
37+
def updated_model_artifact(
38+
request: pytest.FixtureRequest,
39+
model_registry_rest_url: str,
40+
model_registry_rest_headers: dict[str, str],
41+
registered_model_rest_api: dict[str, Any],
42+
) -> dict[str, Any]:
43+
model_artifact_id = registered_model_rest_api["model_artifact"]["id"]
44+
assert model_artifact_id, f"Model artifact id not found: {registered_model_rest_api['model_artifact']}"
45+
return execute_model_registry_patch_command(
46+
url=f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}model_artifacts/{model_artifact_id}",
47+
headers=model_registry_rest_headers,
48+
data_json=request.param,
49+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Any
2+
3+
from utilities.constants import ModelFormat
4+
5+
MODEL_REGISTER: dict[str, str] = {
6+
"name": "model-rest-api",
7+
"description": "Model created via rest call",
8+
"owner": "opendatahub-tests",
9+
}
10+
MODEL_VERSION: dict[str, Any] = {
11+
"name": "v0.0.1",
12+
"state": "LIVE",
13+
"author": "opendatahub-tests",
14+
"description": "Model version created via rest call",
15+
}
16+
17+
MODEL_ARTIFACT: dict[str, Any] = {
18+
"name": "model-artifact-rest-api",
19+
"description": "Model artifact created via rest call",
20+
"uri": "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx",
21+
"state": "UNKNOWN",
22+
"modelFormatName": ModelFormat.ONNX,
23+
"modelFormatVersion": "v1",
24+
"artifactType": "model-artifact",
25+
}
26+
MODEL_REGISTER_DATA = {
27+
"register_model_data": MODEL_REGISTER,
28+
"model_version_data": MODEL_VERSION,
29+
"model_artifact_data": MODEL_ARTIFACT,
30+
}
31+
MODEL_REGISTRY_BASE_URI = "/api/model_registry/v1alpha3/"
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from typing import Self, Any
2+
import pytest
3+
from tests.model_registry.rest_api.constants import MODEL_REGISTER, MODEL_ARTIFACT, MODEL_VERSION, MODEL_REGISTER_DATA
4+
from utilities.constants import DscComponents
5+
from tests.model_registry.constants import MR_NAMESPACE
6+
from simple_logger.logger import get_logger
7+
8+
LOGGER = get_logger(name=__name__)
9+
10+
11+
@pytest.mark.parametrize(
12+
"updated_dsc_component_state_scope_class, registered_model_rest_api",
13+
[
14+
pytest.param(
15+
{
16+
"component_patch": {
17+
DscComponents.MODELREGISTRY: {
18+
"managementState": DscComponents.ManagementState.MANAGED,
19+
"registriesNamespace": MR_NAMESPACE,
20+
},
21+
},
22+
},
23+
MODEL_REGISTER_DATA,
24+
),
25+
],
26+
indirect=True,
27+
)
28+
@pytest.mark.usefixtures("updated_dsc_component_state_scope_class", "registered_model_rest_api")
29+
class TestModelRegistryCreationRest:
30+
"""
31+
Tests the creation of a model registry. If the component is set to 'Removed' it will be switched to 'Managed'
32+
for the duration of this test module.
33+
"""
34+
35+
@pytest.mark.parametrize(
36+
"expected_params, data_key",
37+
[
38+
pytest.param(
39+
MODEL_REGISTER,
40+
"register_model",
41+
id="test_validate_registered_model",
42+
),
43+
pytest.param(
44+
MODEL_VERSION,
45+
"model_version",
46+
id="test_validate_model_version",
47+
),
48+
pytest.param(
49+
MODEL_ARTIFACT,
50+
"model_artifact",
51+
id="test_validate_model_artifact",
52+
),
53+
],
54+
)
55+
def test_validate_model_registry_resource(
56+
self: Self,
57+
registered_model_rest_api: dict[str, Any],
58+
expected_params: dict[str, str],
59+
data_key: str,
60+
):
61+
errors: list[str:Any]
62+
created_resource_data = registered_model_rest_api[data_key]
63+
if errors := [
64+
{key: [expected_params[key], created_resource_data.get(key)]}
65+
for key in expected_params.keys()
66+
if not created_resource_data.get(key) or created_resource_data[key] != expected_params[key]
67+
]:
68+
pytest.fail(f"Model did not get created with expected values: {errors}")
69+
LOGGER.info(f"Successfully validated: {created_resource_data['name']}")
70+
71+
@pytest.mark.parametrize(
72+
"updated_model_artifact, expected_param",
73+
[
74+
pytest.param(
75+
{"description": "updated description"},
76+
{"description": "updated description"},
77+
id="test_validate_updated_artifact_description",
78+
),
79+
pytest.param(
80+
{"modelFormatName": "tensorflow"},
81+
{"modelFormatName": "tensorflow"},
82+
id="test_validate_updated_artifact_model_format_name",
83+
),
84+
pytest.param(
85+
{"modelFormatVersion": "v2"},
86+
{"modelFormatVersion": "v2"},
87+
id="test_validate_updated_artifact_model_format_version",
88+
),
89+
],
90+
indirect=["updated_model_artifact"],
91+
)
92+
def test_create_update_model_artifact(
93+
self,
94+
updated_model_artifact: dict[str, Any],
95+
expected_param: dict[str, Any],
96+
):
97+
errors: list[dict[str, list[Any]]]
98+
if errors := [
99+
{key: [expected_param[key], updated_model_artifact.get(key)]}
100+
for key in expected_param.keys()
101+
if not updated_model_artifact.get(key) or updated_model_artifact[key] != expected_param[key]
102+
]:
103+
pytest.fail(f"Model did not get updated with expected values: {errors}")
104+
LOGGER.info(f"Successfully validated: {updated_model_artifact['name']}")
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from typing import Any
2+
import requests
3+
import json
4+
from simple_logger.logger import get_logger
5+
from tests.model_registry.exceptions import (
6+
ModelRegistryResourceNotCreated,
7+
ModelRegistryResourceNotFoundError,
8+
ModelRegistryResourceNotUpdated,
9+
)
10+
from tests.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI
11+
12+
LOGGER = get_logger(name=__name__)
13+
14+
15+
def execute_model_registry_patch_command(
16+
url: str, headers: dict[str, str], data_json: dict[str, Any]
17+
) -> dict[Any, Any]:
18+
resp = requests.patch(url=url, json=data_json, headers=headers, verify=False, timeout=60)
19+
LOGGER.info(f"url: {url}, status code: {resp.status_code}, rep: {resp.text}")
20+
21+
if resp.status_code != 200:
22+
raise ModelRegistryResourceNotUpdated(
23+
f"Failed to update ModelRegistry resource: {url}, {resp.status_code}: {resp.text}"
24+
)
25+
try:
26+
return json.loads(resp.text)
27+
except json.JSONDecodeError:
28+
LOGGER.error(f"Unable to parse {resp.text}")
29+
raise
30+
31+
32+
def execute_model_registry_post_command(url: str, headers: dict[str, str], data_json: dict[str, Any]) -> dict[Any, Any]:
33+
resp = requests.post(url=url, json=data_json, headers=headers, verify=False, timeout=60)
34+
LOGGER.info(f"url: {url}, status code: {resp.status_code}, rep: {resp.text}")
35+
36+
if resp.status_code not in [200, 201]:
37+
raise ModelRegistryResourceNotCreated(
38+
f"Failed to create ModelRegistry resource: {url}, {resp.status_code}: {resp.text}"
39+
)
40+
try:
41+
return json.loads(resp.text)
42+
except json.JSONDecodeError:
43+
LOGGER.error(f"Unable to parse {resp.text}")
44+
raise
45+
46+
47+
def execute_model_registry_get_command(url: str, headers: dict[str, str]) -> dict[Any, Any]: # skip-unused-code
48+
resp = requests.get(url=url, headers=headers, verify=False)
49+
LOGGER.info(f"url: {url}, status code: {resp.status_code}, rep: {resp.text}")
50+
if resp.status_code not in [200, 201]:
51+
raise ModelRegistryResourceNotFoundError(
52+
f"Failed to get ModelRegistry resource: {url}, {resp.status_code}: {resp.text}"
53+
)
54+
55+
try:
56+
return json.loads(resp.text)
57+
except json.JSONDecodeError:
58+
LOGGER.error(f"Unable to parse {resp.text}")
59+
raise
60+
61+
62+
def register_model_rest_api(
63+
model_registry_rest_url: str, model_registry_rest_headers: dict[str, str], data_dict: dict[str, Any]
64+
) -> dict[str, Any]:
65+
# register a model
66+
register_model = execute_model_registry_post_command(
67+
url=f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}registered_models",
68+
headers=model_registry_rest_headers,
69+
data_json=data_dict["register_model_data"],
70+
)
71+
# create associated model version:
72+
model_data = data_dict["model_version_data"]
73+
model_data["registeredModelId"] = register_model["id"]
74+
model_version = execute_model_registry_post_command(
75+
url=f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}model_versions",
76+
headers=model_registry_rest_headers,
77+
data_json=model_data,
78+
)
79+
# create associated model artifact
80+
model_artifact = execute_model_registry_post_command(
81+
url=f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}model_versions/{model_version['id']}/artifacts",
82+
headers=model_registry_rest_headers,
83+
data_json=data_dict["model_artifact_data"],
84+
)
85+
LOGGER.info(
86+
f"Successfully registered model: {register_model}, with version: {model_version} and "
87+
f"associated artifact: {model_artifact}"
88+
)
89+
return {"register_model": register_model, "model_version": model_version, "model_artifact": model_artifact}

0 commit comments

Comments
 (0)