Skip to content

Commit 5b89ab6

Browse files
committed
feat: add canary and concurrency tests
1 parent e4fad15 commit 5b89ab6

File tree

7 files changed

+182
-31
lines changed

7 files changed

+182
-31
lines changed

tests/model_serving/model_server/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ def s3_models_inference_service(
148148
if (enable_auth := request.param.get("enable-auth")) is not None:
149149
isvc_kwargs["enable_auth"] = enable_auth
150150

151+
if (scale_metric := request.param.get("scale-metric")) is not None:
152+
isvc_kwargs["scale_metric"] = scale_metric
153+
154+
if (scale_target := request.param.get("scale-target")) is not None:
155+
isvc_kwargs["scale_target"] = scale_target
156+
151157
with create_isvc(**isvc_kwargs) as isvc:
152158
yield isvc
153159

tests/model_serving/model_server/serverless/conftest.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,39 @@
66
from ocp_resources.resource import ResourceEditor
77

88
from tests.model_serving.model_server.serverless.utils import wait_for_canary_rollout
9+
from tests.model_serving.model_server.utils import run_inference_multiple_times
10+
from utilities.constants import ModelFormat, Protocols
11+
from utilities.inference_utils import Inference
12+
from utilities.manifests.caikit_tgis import CAIKIT_TGIS_INFERENCE_CONFIG
913

1014

1115
@pytest.fixture
1216
def inference_service_updated_canary_config(
1317
request: FixtureRequest, s3_models_inference_service: InferenceService
1418
) -> Generator[InferenceService, Any, Any]:
15-
percent = request.param("canary-traffic-percent")
19+
canary_percent = request.param["canary-traffic-percent"]
1620
predictor_config = {
1721
"spec": {
18-
"predictor": {"canaryTrafficPercent": percent},
22+
"predictor": {"canaryTrafficPercent": canary_percent},
1923
}
2024
}
2125

2226
if model_path := request.param.get("model-path"):
23-
predictor_config["spec"]["predictor"]["model"]["storage_path"] = model_path
27+
predictor_config["spec"]["predictor"]["model"] = {"storage": {"path": model_path}}
2428

2529
with ResourceEditor(patches={s3_models_inference_service: predictor_config}):
26-
wait_for_canary_rollout(isvc=s3_models_inference_service, percentage=percent)
30+
wait_for_canary_rollout(isvc=s3_models_inference_service, percentage=canary_percent)
2731
yield s3_models_inference_service
32+
33+
34+
@pytest.fixture
35+
def multiple_tgis_inference_requests(s3_models_inference_service: InferenceService) -> None:
36+
run_inference_multiple_times(
37+
isvc=s3_models_inference_service,
38+
inference_config=CAIKIT_TGIS_INFERENCE_CONFIG,
39+
inference_type=Inference.ALL_TOKENS,
40+
protocol=Protocols.HTTPS,
41+
model_name=ModelFormat.CAIKIT,
42+
iterations=50,
43+
run_in_parallel=True,
44+
)

tests/model_serving/model_server/serverless/test_canary_rollout.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_serverless_before_model_update(
5959
"inference_service_updated_canary_config",
6060
[
6161
pytest.param(
62-
{"canary-traffic-percent": "10", "model-path": ModelStoragePath.FLAN_T5_SMALL_HF},
62+
{"canary-traffic-percent": 30, "model-path": ModelStoragePath.FLAN_T5_SMALL_HF},
6363
)
6464
],
6565
indirect=True,
@@ -72,15 +72,16 @@ def test_serverless_during_canary_rollout(self, inference_service_updated_canary
7272
model_name=ModelAndFormat.FLAN_T5_SMALL_CAIKIT,
7373
inference_type=Inference.ALL_TOKENS,
7474
protocol=Protocols.GRPC,
75-
iterations=5,
76-
percentage=10,
75+
iterations=20,
76+
expected_percentage=30,
77+
tolerance=10,
7778
)
7879

7980
@pytest.mark.parametrize(
8081
"inference_service_updated_canary_config",
8182
[
8283
pytest.param(
83-
{"canary-traffic-percent": "100"},
84+
{"canary-traffic-percent": 100},
8485
)
8586
],
8687
indirect=True,
@@ -94,5 +95,5 @@ def test_serverless_after_canary_rollout(self, inference_service_updated_canary_
9495
inference_type=Inference.ALL_TOKENS,
9596
protocol=Protocols.GRPC,
9697
iterations=5,
97-
percentage=100,
98+
expected_percentage=100,
9899
)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
3+
from tests.model_serving.model_server.serverless.utils import (
4+
inference_service_pods_sampler,
5+
)
6+
from utilities.constants import (
7+
KServeDeploymentType,
8+
ModelFormat,
9+
ModelInferenceRuntime,
10+
ModelStoragePath,
11+
RuntimeTemplates,
12+
Timeout,
13+
)
14+
15+
pytestmark = [
16+
pytest.mark.serverless,
17+
pytest.mark.sanity,
18+
pytest.mark.usefixtures("valid_aws_config"),
19+
]
20+
21+
22+
@pytest.mark.parametrize(
23+
"model_namespace, serving_runtime_from_template, s3_models_inference_service",
24+
[
25+
pytest.param(
26+
{"name": "serverless-auto-scale"},
27+
{
28+
"name": f"{ModelInferenceRuntime.CAIKIT_TGIS_RUNTIME}",
29+
"template-name": RuntimeTemplates.CAIKIT_TGIS_SERVING,
30+
"multi-model": False,
31+
"enable-http": True,
32+
},
33+
{
34+
"name": f"{ModelFormat.CAIKIT}-auto-scale",
35+
"deployment-mode": KServeDeploymentType.SERVERLESS,
36+
"model-dir": ModelStoragePath.FLAN_T5_SMALL_CAIKIT,
37+
"scale-metric": "concurrency",
38+
"scale-target": 1,
39+
},
40+
)
41+
],
42+
indirect=True,
43+
)
44+
class TestConcurrencyAutoScale:
45+
@pytest.mark.dependency(name="test_auto_scale_using_concurrency")
46+
def test_auto_scale_using_concurrency(
47+
self,
48+
admin_client,
49+
s3_models_inference_service,
50+
multiple_tgis_inference_requests,
51+
):
52+
"""Verify model is successfully scaled up based on concurrency metrics (KPA)"""
53+
for pods in inference_service_pods_sampler(
54+
client=admin_client,
55+
isvc=s3_models_inference_service,
56+
timeout=Timeout.TIMEOUT_1MIN,
57+
):
58+
if pods:
59+
if len(pods) > 1 and all([pod.status == pod.Status.RUNNING for pod in pods]):
60+
return
61+
62+
@pytest.mark.dependency(depends=["test_auto_scale_using_concurrency"])
63+
def test_pods_scaled_down_when_no_requests(self, admin_client, s3_models_inference_service):
64+
"""Verify auto-scaled pods are deleted when there are no inference requests"""
65+
for pods in inference_service_pods_sampler(
66+
client=admin_client,
67+
isvc=s3_models_inference_service,
68+
timeout=Timeout.TIMEOUT_4MIN,
69+
):
70+
if pods and len(pods) == 1:
71+
return

tests/model_serving/model_server/serverless/utils.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
from __future__ import annotations
2+
13
from typing import Any
24

5+
from kubernetes.dynamic import DynamicClient
36
from ocp_resources.inference_service import InferenceService
47
from simple_logger.logger import get_logger
58
from timeout_sampler import TimeoutExpiredError, TimeoutSampler
69

710
from tests.model_serving.model_server.utils import verify_inference_response
811
from utilities.constants import Timeout
912
from utilities.exceptions import InferenceCanaryTrafficError
13+
from utilities.infra import get_pods_by_isvc_label
1014

1115
LOGGER = get_logger(name=__name__)
1216

@@ -51,7 +55,8 @@ def verify_canary_traffic(
5155
protocol: str,
5256
model_name: str,
5357
iterations: int,
54-
percentage: int,
58+
expected_percentage: int,
59+
tolerance: int = 0,
5560
) -> None:
5661
"""
5762
Verify canary traffic percentage against inference_config.
@@ -63,15 +68,17 @@ def verify_canary_traffic(
6368
protocol (str): Protocol.
6469
model_name (str): Model name.
6570
iterations (int): Number of iterations.
66-
percentage (int): Percentage of canary rollout.
71+
expected_percentage (int): Percentage of canary rollout.
72+
tolerance (int): Tolerance of traffic percentage distribution;
73+
difference between actual and expected percentage.
6774
6875
Raises:
6976
InferenceCanaryTrafficError: If canary rollout is not updated
7077
7178
"""
7279
successful_inferences = 0
7380

74-
for _ in range(iterations):
81+
for iteration in range(iterations):
7582
try:
7683
verify_inference_response(
7784
inference_service=isvc,
@@ -81,16 +88,42 @@ def verify_canary_traffic(
8188
model_name=model_name,
8289
use_default_query=True,
8390
)
91+
LOGGER.info(f"Successful inference. Iteration: {iteration + 1}")
8492

8593
successful_inferences += 1
8694

87-
except Exception:
88-
continue
95+
except Exception as ex:
96+
LOGGER.warning(f"Inference failed. Error: {ex}. Previous model was used.")
8997

98+
LOGGER.info(f"Number of inference requests to the new model: {successful_inferences}")
9099
successful_inferences_percentage = successful_inferences / iterations * 100
91100

92-
if successful_inferences_percentage != percentage:
101+
diff_percentage = abs(expected_percentage - successful_inferences_percentage)
102+
103+
if successful_inferences == 0 or diff_percentage > tolerance:
93104
raise InferenceCanaryTrafficError(
94105
f"Percentage of inference requests {successful_inferences_percentage} "
95-
f"to the new model does not match the expected percentage {percentage}. "
106+
f"to the new model does not match the expected percentage {expected_percentage}. "
96107
)
108+
109+
110+
def inference_service_pods_sampler(client: DynamicClient, isvc: InferenceService, timeout: int) -> TimeoutSampler:
111+
"""
112+
Returns TimeoutSampler for inference service.
113+
114+
Args:
115+
client (DynamicClient): DynamicClient object
116+
isvc (InferenceService): InferenceService object
117+
timeout (int): Timeout in seconds
118+
119+
Returns:
120+
TimeoutSampler: TimeoutSampler object
121+
122+
"""
123+
return TimeoutSampler(
124+
wait_timeout=timeout,
125+
sleep=1,
126+
func=get_pods_by_isvc_label,
127+
client=client,
128+
isvc=isvc,
129+
)

tests/model_serving/model_server/utils.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,25 +132,32 @@ def verify_inference_response(
132132
res[inference.inference_response_key_name],
133133
re.MULTILINE,
134134
):
135-
assert "".join(output) == expected_response_text
135+
assert "".join(output) == expected_response_text, (
136+
f"Expected: {expected_response_text} does not match response: {output}"
137+
)
136138

137139
elif inference_type == inference.INFER or use_regex:
138140
formatted_res = json.dumps(res[inference.inference_response_text_key_name]).replace(" ", "")
139141
if use_regex:
140-
assert re.search(expected_response_text, formatted_res) # type: ignore[arg-type] # noqa: E501
142+
assert re.search(expected_response_text, formatted_res), ( # type: ignore[arg-type] # noqa: E501
143+
f"Expected: {expected_response_text} not found in: {formatted_res}"
144+
)
141145

142146
else:
143-
assert (
144-
json.dumps(res[inference.inference_response_key_name]).replace(" ", "")
145-
== expected_response_text
147+
formatted_res = json.dumps(res[inference.inference_response_key_name]).replace(" ", "")
148+
assert formatted_res == expected_response_text, (
149+
f"Expected: {expected_response_text} does not match output: {formatted_res}"
146150
)
147151

148152
else:
149153
response = res[inference.inference_response_key_name]
150154
if isinstance(response, list):
151155
response = response[0]
152156

153-
assert response[inference.inference_response_text_key_name] == expected_response_text
157+
response_text = response[inference.inference_response_text_key_name]
158+
assert response_text == expected_response_text, (
159+
f"Expected: {expected_response_text} does not match response: {response_text}"
160+
)
154161

155162
else:
156163
raise InferenceResponseError(f"Inference response output not found in response. Response: {res}")

utilities/inference_utils.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
)
3737
import portforward
3838

39-
from utilities.jira import is_jira_open
40-
4139
LOGGER = get_logger(name=__name__)
4240

4341

@@ -65,12 +63,12 @@ def get_deployment_type(self) -> str:
6563
Returns:
6664
deployment type
6765
"""
68-
deployment_type = self.inference_service.instance.metadata.annotations.get("serving.kserve.io/deploymentMode")
69-
70-
if is_jira_open(jira_id="RHOAIENG-16954", admin_client=get_client()) and not deployment_type:
71-
return KServeDeploymentType.SERVERLESS
66+
if deployment_type := self.inference_service.instance.metadata.annotations.get(
67+
"serving.kserve.io/deploymentMode"
68+
):
69+
return deployment_type
7270

73-
return deployment_type
71+
return self.inference_service.instance.status.deploymentMode
7472

7573
def get_inference_url(self) -> str:
7674
"""
@@ -524,6 +522,8 @@ def create_isvc(
524522
autoscaler_mode: str | None = None,
525523
multi_node_worker_spec: dict[str, int] | None = None,
526524
timeout: int = Timeout.TIMEOUT_15MIN,
525+
scale_metric: str | None = None,
526+
scale_target: int | None = None,
527527
) -> Generator[InferenceService, Any, Any]:
528528
"""
529529
Create InferenceService object.
@@ -553,6 +553,8 @@ def create_isvc(
553553
multi_node_worker_spec (dict[str, int]): Multi node worker spec
554554
wait_for_predictor_pods (bool): Wait for predictor pods
555555
timeout (int): Time to wait for the model inference,deployment to be ready
556+
scale_metric (str): Scale metric
557+
scale_target (int): Scale target
556558
557559
Yields:
558560
InferenceService: InferenceService object
@@ -625,6 +627,12 @@ def create_isvc(
625627
if multi_node_worker_spec is not None:
626628
predictor_dict["workerSpec"] = multi_node_worker_spec
627629

630+
if scale_metric is not None:
631+
predictor_dict["scaleMetric"] = scale_metric
632+
633+
if scale_target is not None:
634+
predictor_dict["scaleTarget"] = scale_target
635+
628636
with InferenceService(
629637
client=client,
630638
name=name,
@@ -634,9 +642,17 @@ def create_isvc(
634642
label=labels,
635643
) as inference_service:
636644
if wait_for_predictor_pods:
637-
verify_no_failed_pods(client=client, isvc=inference_service, runtime_name=runtime, timeout=timeout)
645+
verify_no_failed_pods(
646+
client=client,
647+
isvc=inference_service,
648+
runtime_name=runtime,
649+
timeout=timeout,
650+
)
638651
wait_for_inference_deployment_replicas(
639-
client=client, isvc=inference_service, runtime_name=runtime, timeout=timeout
652+
client=client,
653+
isvc=inference_service,
654+
runtime_name=runtime,
655+
timeout=timeout,
640656
)
641657

642658
if wait:

0 commit comments

Comments
 (0)