forked from kubeflow/hub
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_api.py
More file actions
117 lines (99 loc) · 4.34 KB
/
test_api.py
File metadata and controls
117 lines (99 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import secrets
from typing import Callable
import pytest
import requests
import schemathesis
from hypothesis import HealthCheck, settings
from .conftest import REGISTRY_URL
from .constants import ARTIFACT_STATES, ARTIFACT_TYPE_PARAMS, DEFAULT_API_TIMEOUT
# API schema resolved lazily from the ``generated_schema`` pytest fixture
# (defined outside this view, presumably in conftest.py — TODO confirm).
schema = schemathesis.pytest.from_fixture(
    "generated_schema"
)
@pytest.fixture
def auth_headers(setup_env_user_token):
    """Return the HTTP headers for authenticated JSON requests to the API.

    Args:
        setup_env_user_token: Bearer token supplied by the fixture of the
            same name.

    Returns:
        A fresh dict holding ``Content-Type`` and ``Authorization`` entries.
    """
    headers = {"Content-Type": "application/json"}
    headers["Authorization"] = f"Bearer {setup_env_user_token}"
    return headers
# Remove these two operations from the generated (fuzzed) cases. Each one is
# exercised instead by a dedicated direct test further down in this module.
schema = schema.exclude(
    path="/api/model_registry/v1alpha3/artifacts/{id}",
    method="PATCH",
)
schema = schema.exclude(
    path="/api/model_registry/v1alpha3/model_versions/{modelversionId}/artifacts",
    method="POST",
)
@schema.parametrize()
@settings(
    deadline=None,  # no per-example time limit for network round-trips
    max_examples=100,
    suppress_health_check=[
        HealthCheck.filter_too_much,
        HealthCheck.too_slow,
        HealthCheck.data_too_large,
    ],
)
@pytest.mark.fuzz
def test_mr_api_stateless(auth_headers: dict, case: schemathesis.Case):
    """Fuzz the Model Registry API operations remaining in ``schema``.

    Schemathesis generates request ``case`` objects from the schema. Each
    one is sent with the authenticated headers, and ``call_and_validate``
    checks the response against the schema.

    Args:
        auth_headers: Headers from the ``auth_headers`` fixture.
        case: A generated request for a single API operation.
    """
    case.call_and_validate(headers=auth_headers)
def _random_numeric_id(low: int = 100_000, high: int = 2_000_000_000) -> str:
    """Return a random integer in the inclusive range [low, high] as a string."""
    return str(low + secrets.randbelow(high - low + 1))


@pytest.mark.fuzz
@pytest.mark.parametrize(("artifact_type", "uri_prefix"), ARTIFACT_TYPE_PARAMS)
@pytest.mark.parametrize("state", ARTIFACT_STATES)
def test_post_model_version_artifacts(auth_headers: dict, artifact_type: str, uri_prefix: str, state: str, cleanup_artifacts: Callable):
    """Direct test for POST /api/model_registry/v1alpha3/model_versions/{modelversionId}/artifacts.

    Posts an artifact of ``artifact_type`` in ``state`` under a randomly
    chosen model-version id. It then checks the status code and that the
    response echoes ``name`` and ``artifactType``.

    Any returned artifact id is passed to ``cleanup_artifacts`` before the
    assertions run. That way the cleanup fixture receives it even when a
    later assertion fails.
    """
    model_version_id = _random_numeric_id()
    endpoint = f"{REGISTRY_URL}/api/model_registry/v1alpha3/model_versions/{model_version_id}/artifacts"
    payload = {
        "artifactType": artifact_type,
        "name": "my-test-model-artifact-post",
        "uri": f"{uri_prefix}my-test-model.pkl",
        "state": state,
        "description": "A test model artifact created via direct POST test.",
        "externalId": _random_numeric_id(),
    }
    response = requests.post(endpoint, headers=auth_headers, json=payload, timeout=DEFAULT_API_TIMEOUT)

    # Parse defensively. An error response may have a non-JSON body, or JSON
    # without "id". Indexing it directly would raise KeyError/JSONDecodeError
    # and mask the status-code assertion message below.
    try:
        response_json = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        response_json = {}
    if not isinstance(response_json, dict):
        response_json = {}

    artifact_id = response_json.get("id")
    if artifact_id:
        cleanup_artifacts(artifact_id)

    assert response.status_code in {200, 201}, f"Expected 200 or 201, got {response.status_code}: {response.text}"
    assert artifact_id, "Response body should contain 'id'"
    assert response_json.get("name") == payload["name"], "Response name should match payload name"
    assert response_json.get("artifactType") == payload["artifactType"], "Response artifactType should match payload"
@pytest.mark.fuzz
@pytest.mark.parametrize(("artifact_type", "uri_prefix"), ARTIFACT_TYPE_PARAMS)
def test_patch_artifact(auth_headers: dict, artifact_resource: Callable, artifact_type: str, uri_prefix: str):
    """Direct test for PATCH /api/model_registry/v1alpha3/artifacts/{id}.

    Creates an artifact in state ``PENDING`` through the
    ``artifact_resource`` context manager. It then PATCHes the artifact's
    description and moves its state to ``LIVE``. Finally it checks that the
    response reflects the patch and keeps the same id and artifact type.
    """
    initial_state = "PENDING"
    target_state = "LIVE"
    create_payload = {
        "artifactType": artifact_type,
        "name": "test-create-for-patch",
        # Use the prefix paired with this artifact type in ARTIFACT_TYPE_PARAMS,
        # as the POST test does; previously the parameter was ignored.
        "uri": f"{uri_prefix}initial-model.pkl",
        "state": initial_state,
    }
    if artifact_type == "model-artifact":
        create_payload["modelFormatName"] = "tensorflow"
        create_payload["modelFormatVersion"] = "1.0"

    with artifact_resource(auth_headers, create_payload) as artifact_id:
        patch_endpoint = f"{REGISTRY_URL}/api/model_registry/v1alpha3/artifacts/{artifact_id}"
        patch_payload = {
            "artifactType": artifact_type,
            "description": f"Updated description for {artifact_type} ({target_state})",
            "state": target_state,
        }
        patch_response = requests.patch(patch_endpoint, headers=auth_headers, json=patch_payload, timeout=DEFAULT_API_TIMEOUT)
        assert patch_response.status_code == 200, (
            f"Expected 200, got {patch_response.status_code}: {patch_response.text}"
        )

        patch_response_json = patch_response.json()
        assert patch_response_json.get("id") == artifact_id, "Patched artifact id should be unchanged"
        assert patch_response_json.get("description") == patch_payload["description"], "Description should be updated"
        assert patch_response_json.get("state") == patch_payload["state"], "State should be updated"
        assert patch_response_json.get("artifactType") == artifact_type, "artifactType should be unchanged"