Skip to content

Commit 7edab8d

Browse files
authored
[model server] Add Canary rollout and concurrency scaling tests (#170)
* Create size-labeler.yml * Delete .github/workflows/size-labeler.yml * model mesh - add auth tests * xx * feat: test serverless canary rollout * feat: add canary and concurrency tests
1 parent 38fc0de commit 7edab8d

File tree

10 files changed

+402
-15
lines changed

10 files changed

+402
-15
lines changed

tests/model_serving/model_server/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ def s3_models_inference_service(
148148
if (enable_auth := request.param.get("enable-auth")) is not None:
149149
isvc_kwargs["enable_auth"] = enable_auth
150150

151+
if (scale_metric := request.param.get("scale-metric")) is not None:
152+
isvc_kwargs["scale_metric"] = scale_metric
153+
154+
if (scale_target := request.param.get("scale-target")) is not None:
155+
isvc_kwargs["scale_target"] = scale_target
156+
151157
with create_isvc(**isvc_kwargs) as isvc:
152158
yield isvc
153159

tests/model_serving/model_server/serverless/conftest.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
from typing import Any, Generator
2+
13
import pytest
24
from _pytest.fixtures import FixtureRequest
35
from ocp_resources.inference_service import InferenceService
46
from ocp_resources.resource import ResourceEditor
57

8+
from tests.model_serving.model_server.serverless.utils import wait_for_canary_rollout
9+
from tests.model_serving.model_server.utils import run_inference_multiple_times
10+
from utilities.constants import ModelFormat, Protocols
11+
from utilities.inference_utils import Inference
12+
from utilities.manifests.caikit_tgis import CAIKIT_TGIS_INFERENCE_CONFIG
13+
614

715
@pytest.fixture(scope="class")
816
def inference_service_patched_replicas(
@@ -19,3 +27,35 @@ def inference_service_patched_replicas(
1927
).update()
2028

2129
return ovms_serverless_inference_service
30+
31+
32+
@pytest.fixture
def inference_service_updated_canary_config(
    request: FixtureRequest, s3_models_inference_service: InferenceService
) -> Generator[InferenceService, Any, Any]:
    """Patch an InferenceService with a canary rollout and yield it.

    ``request.param`` must provide ``canary-traffic-percent``; it may also
    provide ``model-path`` to point the canary revision at a new model.
    The patch is reverted when the fixture finalizes (ResourceEditor exits).
    """
    percent = request.param["canary-traffic-percent"]

    predictor = {"canaryTrafficPercent": percent}
    new_model_path = request.param.get("model-path")
    if new_model_path:
        predictor["model"] = {"storage": {"path": new_model_path}}

    patch_body = {
        "spec": {
            "predictor": predictor,
        }
    }

    with ResourceEditor(patches={s3_models_inference_service: patch_body}):
        # Block until the latest revision actually carries the requested
        # traffic share before handing the service to the test.
        wait_for_canary_rollout(isvc=s3_models_inference_service, percentage=percent)
        yield s3_models_inference_service
49+
50+
51+
@pytest.fixture
def multiple_tgis_inference_requests(s3_models_inference_service: InferenceService) -> None:
    # Fire a burst of 50 parallel all-tokens inference requests at the caikit
    # model; used to generate enough concurrent load to trigger KPA autoscaling.
    run_inference_multiple_times(
        isvc=s3_models_inference_service,
        inference_config=CAIKIT_TGIS_INFERENCE_CONFIG,
        inference_type=Inference.ALL_TOKENS,
        protocol=Protocols.HTTPS,
        model_name=ModelFormat.CAIKIT,
        iterations=50,
        run_in_parallel=True,
    )
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import pytest
2+
3+
from tests.model_serving.model_server.serverless.utils import verify_canary_traffic
4+
from tests.model_serving.model_server.utils import verify_inference_response
5+
from utilities.constants import (
6+
KServeDeploymentType,
7+
ModelAndFormat,
8+
ModelName,
9+
ModelStoragePath,
10+
Protocols,
11+
RuntimeTemplates,
12+
)
13+
from utilities.inference_utils import Inference
14+
from utilities.manifests.pytorch import PYTORCH_TGIS_INFERENCE_CONFIG
15+
from utilities.manifests.tgis_grpc import TGIS_INFERENCE_CONFIG
16+
17+
pytestmark = [pytest.mark.serverless, pytest.mark.sanity]
18+
19+
20+
@pytest.mark.polarion("ODS-2371")
@pytest.mark.parametrize(
    "model_namespace, serving_runtime_from_template, s3_models_inference_service",
    [
        pytest.param(
            # Namespace for the test class
            {"name": "serverless-canary-rollout"},
            # TGIS serving runtime, gRPC only
            {
                "name": "tgis-runtime",
                "template-name": RuntimeTemplates.TGIS_GRPC_SERVING,
                "multi-model": False,
                "enable-http": False,
                "enable-grpc": True,
            },
            # Serverless InferenceService initially serving the Bloom model
            {
                "name": f"{ModelName.BLOOM_560M}-model",
                "deployment-mode": KServeDeploymentType.SERVERLESS,
                "model-dir": f"{ModelStoragePath.BLOOM_560M_CAIKIT}/artifacts",
                "external-route": True,
            },
        )
    ],
    indirect=True,
)
class TestServerlessCanaryRollout:
    # Canary rollout scenario: start on Bloom, shift 30% of traffic to a
    # flan-t5 revision, then promote it to 100%.
    # NOTE(review): the three tests appear to rely on running in declaration
    # order within the class (each stage patches the same service) — confirm
    # they are not run individually / reordered.

    def test_serverless_before_model_update(
        self,
        s3_models_inference_service,
    ):
        """Test inference with Bloom before model is updated."""
        verify_inference_response(
            inference_service=s3_models_inference_service,
            inference_config=PYTORCH_TGIS_INFERENCE_CONFIG,
            inference_type=Inference.ALL_TOKENS,
            protocol=Protocols.GRPC,
            model_name=ModelAndFormat.BLOOM_560M_CAIKIT,
            use_default_query=True,
        )

    @pytest.mark.parametrize(
        "inference_service_updated_canary_config",
        [
            pytest.param(
                # Route 30% of traffic to the new flan-t5 model path
                {"canary-traffic-percent": 30, "model-path": ModelStoragePath.FLAN_T5_SMALL_HF},
            )
        ],
        indirect=True,
    )
    def test_serverless_during_canary_rollout(self, inference_service_updated_canary_config):
        """Test inference during canary rollout"""
        # With 20 requests and a 30% split, allow +/-10% distribution noise
        verify_canary_traffic(
            isvc=inference_service_updated_canary_config,
            inference_config=TGIS_INFERENCE_CONFIG,
            model_name=ModelAndFormat.FLAN_T5_SMALL_CAIKIT,
            inference_type=Inference.ALL_TOKENS,
            protocol=Protocols.GRPC,
            iterations=20,
            expected_percentage=30,
            tolerance=10,
        )

    @pytest.mark.parametrize(
        "inference_service_updated_canary_config",
        [
            pytest.param(
                # Promote the canary revision to receive all traffic
                {"canary-traffic-percent": 100},
            )
        ],
        indirect=True,
    )
    def test_serverless_after_canary_rollout(self, inference_service_updated_canary_config):
        """Test inference after canary rollout"""
        # All requests must hit the new model, so no tolerance is given
        verify_canary_traffic(
            isvc=inference_service_updated_canary_config,
            inference_config=TGIS_INFERENCE_CONFIG,
            model_name=ModelAndFormat.FLAN_T5_SMALL_CAIKIT,
            inference_type=Inference.ALL_TOKENS,
            protocol=Protocols.GRPC,
            iterations=5,
            expected_percentage=100,
        )
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
3+
from tests.model_serving.model_server.serverless.utils import (
4+
inference_service_pods_sampler,
5+
)
6+
from utilities.constants import (
7+
KServeDeploymentType,
8+
ModelFormat,
9+
ModelInferenceRuntime,
10+
ModelStoragePath,
11+
RuntimeTemplates,
12+
Timeout,
13+
)
14+
15+
pytestmark = [
16+
pytest.mark.serverless,
17+
pytest.mark.sanity,
18+
pytest.mark.usefixtures("valid_aws_config"),
19+
]
20+
21+
22+
@pytest.mark.parametrize(
23+
"model_namespace, serving_runtime_from_template, s3_models_inference_service",
24+
[
25+
pytest.param(
26+
{"name": "serverless-auto-scale"},
27+
{
28+
"name": f"{ModelInferenceRuntime.CAIKIT_TGIS_RUNTIME}",
29+
"template-name": RuntimeTemplates.CAIKIT_TGIS_SERVING,
30+
"multi-model": False,
31+
"enable-http": True,
32+
},
33+
{
34+
"name": f"{ModelFormat.CAIKIT}-auto-scale",
35+
"deployment-mode": KServeDeploymentType.SERVERLESS,
36+
"model-dir": ModelStoragePath.FLAN_T5_SMALL_CAIKIT,
37+
"scale-metric": "concurrency",
38+
"scale-target": 1,
39+
},
40+
)
41+
],
42+
indirect=True,
43+
)
44+
class TestConcurrencyAutoScale:
45+
@pytest.mark.dependency(name="test_auto_scale_using_concurrency")
46+
def test_auto_scale_using_concurrency(
47+
self,
48+
admin_client,
49+
s3_models_inference_service,
50+
multiple_tgis_inference_requests,
51+
):
52+
"""Verify model is successfully scaled up based on concurrency metrics (KPA)"""
53+
for pods in inference_service_pods_sampler(
54+
client=admin_client,
55+
isvc=s3_models_inference_service,
56+
timeout=Timeout.TIMEOUT_1MIN,
57+
):
58+
if pods:
59+
if len(pods) > 1 and all([pod.status == pod.Status.RUNNING for pod in pods]):
60+
return
61+
62+
@pytest.mark.dependency(requires=["test_auto_scale_using_concurrency"])
63+
def test_pods_scaled_down_when_no_requests(self, admin_client, s3_models_inference_service):
64+
"""Verify auto-scaled pods are deleted when there are no inference requests"""
65+
for pods in inference_service_pods_sampler(
66+
client=admin_client,
67+
isvc=s3_models_inference_service,
68+
timeout=Timeout.TIMEOUT_4MIN,
69+
):
70+
if pods and len(pods) == 1:
71+
return

tests/model_serving/model_server/serverless/utils.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
15
from kubernetes.dynamic import DynamicClient
26
from ocp_resources.inference_service import InferenceService
37
from simple_logger.logger import get_logger
48
from timeout_sampler import TimeoutSampler
9+
from timeout_sampler import TimeoutExpiredError
510

11+
from tests.model_serving.model_server.utils import verify_inference_response
612
from utilities.constants import Timeout
13+
from utilities.exceptions import InferenceCanaryTrafficError
714
from utilities.infra import get_pods_by_isvc_label
815

916

@@ -38,3 +45,117 @@ def verify_no_inference_pods(client: DynamicClient, isvc: InferenceService) -> N
3845
except TimeoutError:
3946
LOGGER.error(f"{[pod.name for pod in pods]} were not deleted")
4047
raise
48+
49+
50+
def wait_for_canary_rollout(isvc: InferenceService, percentage: int, timeout: int = Timeout.TIMEOUT_5MIN) -> None:
    """
    Wait for inference service to be updated with canary rollout.

    Polls the predictor's ``status.components.predictor.traffic`` list until an
    entry for the latest revision reports the requested traffic percentage.

    Args:
        isvc (InferenceService): InferenceService object
        percentage (int): Percentage of canary rollout
        timeout (int): Timeout in seconds

    Raises:
        TimeoutExpiredError: If canary rollout is not updated within `timeout`

    """
    # Kept outside the loop so the last observed traffic list can be logged on timeout
    sample = None

    try:
        for sample in TimeoutSampler(
            wait_timeout=timeout,
            sleep=5,
            func=lambda: isvc.instance.status.components.predictor.get("traffic", []),
        ):
            if sample:
                for traffic_info in sample:
                    # Done once the latest revision carries the expected traffic share
                    if traffic_info.get("latestRevision") and traffic_info.get("percent") == percentage:
                        return

    except TimeoutExpiredError:
        LOGGER.error(
            f"InferenceService {isvc.name} canary rollout is not updated to {percentage}. Traffic info:\n{sample}"
        )
        raise
81+
82+
83+
def verify_canary_traffic(
    isvc: InferenceService,
    inference_config: dict[str, Any],
    inference_type: str,
    protocol: str,
    model_name: str,
    iterations: int,
    expected_percentage: int,
    tolerance: int = 0,
) -> None:
    """
    Verify canary traffic percentage against inference_config.

    Sends `iterations` inference requests; a request that validates against the
    new model's `inference_config` counts as new-model traffic, a failing one is
    assumed to have been routed to the previous revision.

    Args:
        isvc (InferenceService): Inference service.
        inference_config (dict[str, Any]): Inference config.
        inference_type (str): Inference type.
        protocol (str): Protocol.
        model_name (str): Model name.
        iterations (int): Number of iterations; must be positive.
        expected_percentage (int): Percentage of canary rollout.
        tolerance (int): Tolerance of traffic percentage distribution;
            difference between actual and expected percentage.

    Raises:
        ValueError: If iterations is not positive.
        InferenceCanaryTrafficError: If the observed traffic split does not
            match the expected percentage within tolerance.

    """
    # Fix: guard against ZeroDivisionError in the percentage computation below.
    if iterations <= 0:
        raise ValueError(f"iterations must be positive, got {iterations}")

    successful_inferences = 0

    for iteration in range(iterations):
        try:
            verify_inference_response(
                inference_service=isvc,
                inference_config=inference_config,
                inference_type=inference_type,
                protocol=protocol,
                model_name=model_name,
                use_default_query=True,
            )
            LOGGER.info(f"Successful inference. Iteration: {iteration + 1}")

            successful_inferences += 1

        except Exception as ex:
            # Requests routed to the previous revision fail verification for the
            # new model's config; treat them as old-model traffic, not errors.
            LOGGER.warning(f"Inference failed. Error: {ex}. Previous model was used.")

    LOGGER.info(f"Number of inference requests to the new model: {successful_inferences}")
    successful_inferences_percentage = successful_inferences / iterations * 100

    diff_percentage = abs(expected_percentage - successful_inferences_percentage)

    # Fix: the original unconditionally failed on successful_inferences == 0,
    # which is wrong when expected_percentage is 0 (all traffic legitimately
    # stays on the old model). Only treat zero successes as a failure when
    # some canary traffic was actually expected.
    if diff_percentage > tolerance or (successful_inferences == 0 and expected_percentage > 0):
        raise InferenceCanaryTrafficError(
            f"Percentage of inference requests {successful_inferences_percentage} "
            f"to the new model does not match the expected percentage {expected_percentage}. "
        )
140+
141+
142+
def inference_service_pods_sampler(client: DynamicClient, isvc: InferenceService, timeout: int) -> TimeoutSampler:
    """
    Build a TimeoutSampler that repeatedly fetches the pods backing an inference service.

    Args:
        client (DynamicClient): DynamicClient object
        isvc (InferenceService): InferenceService object
        timeout (int): Timeout in seconds

    Returns:
        TimeoutSampler: sampler that yields the current isvc pods once per second

    """
    sampler_kwargs = {
        "wait_timeout": timeout,
        "sleep": 1,
        "func": get_pods_by_isvc_label,
        # Remaining kwargs are forwarded to get_pods_by_isvc_label on each call
        "client": client,
        "isvc": isvc,
    }
    return TimeoutSampler(**sampler_kwargs)

0 commit comments

Comments
 (0)