forked from opendatahub-io/opendatahub-tests
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
97 lines (85 loc) · 3.69 KB
/
utils.py
File metadata and controls
97 lines (85 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Any
import requests
import json
from simple_logger.logger import get_logger
from tests.model_registry.exceptions import (
ModelRegistryResourceNotCreated,
ModelRegistryResourceNotFoundError,
ModelRegistryResourceNotUpdated,
)
from tests.model_registry.rest_api.constants import MODEL_REGISTRY_BASE_URI
# Module-level logger used by the REST helper functions in this module.
LOGGER = get_logger(name=__name__)
def execute_model_registry_patch_command(
    url: str,
    headers: dict[str, str],
    data_json: dict[str, Any],
    verify: bool | str = False,
) -> dict[Any, Any]:
    """Update a Model Registry resource via HTTP PATCH and return the decoded JSON response.

    Args:
        url: Full URL of the resource to update.
        headers: HTTP headers sent with the request.
        data_json: Payload serialized as the JSON request body.
        verify: Forwarded to ``requests`` as its TLS verification setting. ``False`` (the default,
            matching the previous hard-coded behavior) skips certificate verification. ``True``
            uses the default CA bundle. A string is treated as a path to a CA bundle.

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        ModelRegistryResourceNotUpdated: If the response status code is not 200.
        json.JSONDecodeError: If the response body is not valid JSON (logged, then re-raised).
    """
    resp = requests.patch(url=url, json=data_json, headers=headers, verify=verify, timeout=60)
    LOGGER.info(f"url: {url}, status code: {resp.status_code}, rep: {resp.text}")
    if resp.status_code != 200:
        raise ModelRegistryResourceNotUpdated(
            f"Failed to update ModelRegistry resource: {url}, {resp.status_code}: {resp.text}"
        )
    try:
        return json.loads(resp.text)
    except json.JSONDecodeError:
        # Log the raw body so the failure is diagnosable, then propagate the original error.
        LOGGER.error(f"Unable to parse {resp.text}")
        raise
def execute_model_registry_post_command(
    url: str,
    headers: dict[str, str],
    data_json: dict[str, Any],
    verify: bool | str = False,
) -> dict[Any, Any]:
    """Create a Model Registry resource via HTTP POST and return the decoded JSON response.

    Args:
        url: Full URL of the collection/endpoint to POST to.
        headers: HTTP headers sent with the request.
        data_json: Payload serialized as the JSON request body.
        verify: Forwarded to ``requests`` as its TLS verification setting (``False`` skips
            certificate verification; a string is a CA bundle path).

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        ModelRegistryResourceNotCreated: If the response status code is neither 200 nor 201.
        json.JSONDecodeError: If the response body is not valid JSON (logged, then re-raised).
    """
    response = requests.post(url=url, json=data_json, headers=headers, verify=verify, timeout=60)
    LOGGER.info(f"url: {url}, status code: {response.status_code}, rep: {response.text}")

    accepted = response.status_code in (200, 201)
    if not accepted:
        raise ModelRegistryResourceNotCreated(
            f"Failed to create ModelRegistry resource: {url}, {response.status_code}: {response.text}"
        )

    raw_body = response.text
    try:
        return json.loads(raw_body)
    except json.JSONDecodeError:
        # Surface the unparsable body in the logs before propagating the decode error.
        LOGGER.error(f"Unable to parse {raw_body}")
        raise
def execute_model_registry_get_command(url: str, headers: dict[str, str]) -> dict[Any, Any]:  # skip-unused-code
    """Fetch a Model Registry resource via HTTP GET and return the decoded JSON response.

    TLS certificate verification is disabled (``verify=False``) for this request.

    Args:
        url: Full URL of the resource to fetch.
        headers: HTTP headers sent with the request.

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        ModelRegistryResourceNotFoundError: If the response status code is neither 200 nor 201.
        json.JSONDecodeError: If the response body is not valid JSON (logged, then re-raised).
    """
    # requests applies no timeout by default and would wait indefinitely on a hung endpoint;
    # use the same 60 s bound as the PATCH/POST helpers so the caller fails instead of stalling.
    resp = requests.get(url=url, headers=headers, verify=False, timeout=60)
    LOGGER.info(f"url: {url}, status code: {resp.status_code}, rep: {resp.text}")
    if resp.status_code not in [200, 201]:
        raise ModelRegistryResourceNotFoundError(
            f"Failed to get ModelRegistry resource: {url}, {resp.status_code}: {resp.text}"
        )
    try:
        return json.loads(resp.text)
    except json.JSONDecodeError:
        # Log the raw body so the failure is diagnosable, then propagate the original error.
        LOGGER.error(f"Unable to parse {resp.text}")
        raise
def register_model_rest_api(
    model_registry_rest_url: str,
    model_registry_rest_headers: dict[str, str],
    data_dict: dict[str, Any],
    verify: bool | str = False,
) -> dict[str, Any]:
    """Register a model, create a version for it, and attach an artifact to that version.

    Three POST requests are issued in order. Each later request depends on the ID returned by
    the previous one.

    Args:
        model_registry_rest_url: Base REST URL of the Model Registry; ``MODEL_REGISTRY_BASE_URI``
            is appended to it.
        model_registry_rest_headers: HTTP headers sent with every request.
        data_dict: Request payloads under the keys ``"register_model_data"``,
            ``"model_version_data"`` and ``"model_artifact_data"``. It is not modified.
        verify: TLS verification setting forwarded to every POST request.

    Returns:
        A dict with keys ``"register_model"``, ``"model_version"`` and ``"model_artifact"``,
        each holding the decoded JSON response of the corresponding POST.

    Raises:
        ModelRegistryResourceNotCreated: If any of the POST requests is rejected.
        KeyError: If ``data_dict`` lacks one of the required keys or a response lacks ``"id"``.
    """
    base_url = f"{model_registry_rest_url}{MODEL_REGISTRY_BASE_URI}"

    # Register the model.
    register_model = execute_model_registry_post_command(
        url=f"{base_url}registered_models",
        headers=model_registry_rest_headers,
        data_json=data_dict["register_model_data"],
        verify=verify,
    )

    # Create the model version linked to the registered model. Build a new dict instead of writing
    # into data_dict["model_version_data"], so the caller's (possibly shared) input is not mutated.
    model_data = {**data_dict["model_version_data"], "registeredModelId": register_model["id"]}
    model_version = execute_model_registry_post_command(
        url=f"{base_url}model_versions",
        headers=model_registry_rest_headers,
        data_json=model_data,
        verify=verify,
    )

    # Attach the artifact to the newly created model version.
    model_artifact = execute_model_registry_post_command(
        url=f"{base_url}model_versions/{model_version['id']}/artifacts",
        headers=model_registry_rest_headers,
        data_json=data_dict["model_artifact_data"],
        verify=verify,
    )
    LOGGER.info(
        f"Successfully registered model: {register_model}, with version: {model_version} and "
        f"associated artifact: {model_artifact}"
    )
    return {"register_model": register_model, "model_version": model_version, "model_artifact": model_artifact}