Skip to content

Commit 9d5d969

Browse files
committed
test: add REST and gRPC test cases for Sklearn framework with MLServer runtime deployment
Signed-off-by: Snomaan6846 <syedali@redhat.com> rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED
1 parent de11de9 commit 9d5d969

16 files changed

+1156
-0
lines changed

conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def pytest_addoption(parser: Parser) -> None:
102102
default=os.environ.get("VLLM_RUNTIME_IMAGE"),
103103
help="Specify the runtime image to use for the tests",
104104
)
105+
runtime_group.addoption(
106+
"--mlserver-runtime-image",
107+
default=os.environ.get("MLSERVER_RUNTIME_IMAGE"),
108+
help="Specify the runtime image to use for the tests",
109+
)
105110

106111
# Upgrade options
107112
upgrade_group.addoption(

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ def vllm_runtime_image(pytestconfig: pytest.Config) -> str | None:
236236
return runtime_image
237237

238238

239+
@pytest.fixture(scope="session")
240+
def mlserver_runtime_image(pytestconfig: pytest.Config) -> str | None:
241+
runtime_image = pytestconfig.option.mlserver_runtime_image
242+
if not runtime_image:
243+
return None
244+
return runtime_image
245+
246+
239247
@pytest.fixture(scope="session")
240248
def use_unprivileged_client(pytestconfig: pytest.Config) -> bool:
241249
_use_unprivileged_client = py_config.get("use_unprivileged_client")

tests/model_serving/model_runtime/mlserver/__init__.py

Whitespace-only changes.

tests/model_serving/model_runtime/mlserver/basic_model_deployment/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"modelName": "sklearn-iris",
3+
"modelVersion": "v1.0.0",
4+
"outputs": [
5+
{
6+
"contents": {
7+
"int64Contents": [
8+
"1",
9+
"1"
10+
]
11+
},
12+
"datatype": "INT64",
13+
"name": "predict",
14+
"parameters": {
15+
"content_type": {
16+
"stringParam": "np"
17+
}
18+
},
19+
"shape": [
20+
"2",
21+
"1"
22+
]
23+
}
24+
]
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"id": "sklearn-iris",
3+
"model_name": "sklearn-iris",
4+
"model_version": "v1.0.0",
5+
"outputs": [
6+
{
7+
"data": [
8+
1,
9+
1
10+
],
11+
"datatype": "INT64",
12+
"name": "predict",
13+
"parameters": {
14+
"content_type": "np"
15+
},
16+
"shape": [
17+
2,
18+
1
19+
]
20+
}
21+
],
22+
"parameters": {}
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"modelName": "sklearn-iris",
3+
"modelVersion": "v1.0.0",
4+
"outputs": [
5+
{
6+
"contents": {
7+
"int64Contents": [
8+
"1",
9+
"1"
10+
]
11+
},
12+
"datatype": "INT64",
13+
"name": "predict",
14+
"parameters": {
15+
"content_type": {
16+
"stringParam": "np"
17+
}
18+
},
19+
"shape": [
20+
"2",
21+
"1"
22+
]
23+
}
24+
]
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"id": "sklearn-iris",
3+
"model_name": "sklearn-iris",
4+
"model_version": "v1.0.0",
5+
"outputs": [
6+
{
7+
"data": [
8+
1,
9+
1
10+
],
11+
"datatype": "INT64",
12+
"name": "predict",
13+
"parameters": {
14+
"content_type": "np"
15+
},
16+
"shape": [
17+
2,
18+
1
19+
]
20+
}
21+
],
22+
"parameters": {}
23+
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
Test module for sklearn model deployment using MLServer runtime.
3+
4+
This module contains parameterized tests that validate sklearn model inference
5+
across different protocols (REST/gRPC) and deployment types (raw/serverless).
6+
"""
7+
8+
from typing import Any
9+
10+
import pytest
11+
from simple_logger.logger import get_logger
12+
13+
from ocp_resources.inference_service import InferenceService
14+
from ocp_resources.pod import Pod
15+
16+
from utilities.constants import Protocols
17+
18+
from tests.model_serving.model_runtime.mlserver.constant import (
19+
BASE_RAW_DEPLOYMENT_CONFIG,
20+
BASE_SERVERLESS_DEPLOYMENT_CONFIG,
21+
MODEL_PATH_PREFIX,
22+
SKLEARN_GRPC_INPUT_QUERY,
23+
SKLEARN_REST_INPUT_QUERY,
24+
)
25+
from tests.model_serving.model_runtime.mlserver.utils import validate_inference_request
26+
27+
28+
LOGGER = get_logger(name=__name__)
29+
30+
MODEL_NAME: str = "sklearn-iris"
31+
32+
MODEL_VERSION: str = "v1.0.0"
33+
34+
MODEL_NAME_DICT: dict[str, str] = {"name": MODEL_NAME}
35+
36+
MODEL_STORAGE_URI_DICT: dict[str, str] = {"model-dir": f"{MODEL_PATH_PREFIX}/sklearn"}
37+
38+
39+
pytestmark = pytest.mark.usefixtures(
40+
"root_dir", "valid_aws_config", "mlserver_rest_serving_runtime_template", "mlserver_grpc_serving_runtime_template"
41+
)
42+
43+
44+
@pytest.mark.parametrize(
45+
"protocol, model_namespace, mlserver_inference_service, s3_models_storage_uri, serving_runtime",
46+
[
47+
pytest.param(
48+
{"protocol_type": Protocols.REST},
49+
{"name": "sklearn-iris-raw-rest"},
50+
{
51+
**BASE_RAW_DEPLOYMENT_CONFIG,
52+
**MODEL_NAME_DICT,
53+
},
54+
MODEL_STORAGE_URI_DICT,
55+
BASE_RAW_DEPLOYMENT_CONFIG,
56+
id="sklearn-iris-raw-rest-deployment",
57+
),
58+
pytest.param(
59+
{"protocol_type": Protocols.GRPC},
60+
{"name": "sklearn-iris-raw-grpc"},
61+
{
62+
**BASE_RAW_DEPLOYMENT_CONFIG,
63+
**MODEL_NAME_DICT,
64+
},
65+
MODEL_STORAGE_URI_DICT,
66+
BASE_RAW_DEPLOYMENT_CONFIG,
67+
id="sklearn-iris-raw-grpc-deployment",
68+
),
69+
pytest.param(
70+
{"protocol_type": Protocols.REST},
71+
{"name": "sklearn-iris-serverless-rest"},
72+
{**BASE_SERVERLESS_DEPLOYMENT_CONFIG, **MODEL_NAME_DICT},
73+
MODEL_STORAGE_URI_DICT,
74+
BASE_SERVERLESS_DEPLOYMENT_CONFIG,
75+
id="sklearn-iris-serverless-rest-deployment",
76+
),
77+
pytest.param(
78+
{"protocol_type": Protocols.GRPC},
79+
{"name": "sklearn-iris-serverless-grpc"},
80+
{**BASE_SERVERLESS_DEPLOYMENT_CONFIG, **MODEL_NAME_DICT},
81+
MODEL_STORAGE_URI_DICT,
82+
BASE_SERVERLESS_DEPLOYMENT_CONFIG,
83+
id="sklearn-iris-serverless-grpc-deployment",
84+
),
85+
],
86+
indirect=True,
87+
)
88+
class TestSkLearnModel:
89+
"""Test class for sklearn model inference with MLServer runtime.
90+
91+
Tests cover multiple deployment scenarios:
92+
- REST and gRPC protocols
93+
- Raw and serverless deployment modes
94+
- Response validation against snapshots
95+
"""
96+
97+
def test_sklearn_model_inference(
98+
self,
99+
mlserver_inference_service: InferenceService,
100+
mlserver_pod_resource: Pod,
101+
response_snapshot: Any,
102+
protocol: str,
103+
root_dir: str,
104+
):
105+
"""Test sklearn model inference across different protocols and deployment types.
106+
107+
Args:
108+
mlserver_inference_service: The deployed inference service
109+
mlserver_pod_resource: Pod running the model server
110+
response_snapshot: Expected response for validation
111+
protocol: Communication protocol (REST or gRPC)
112+
root_dir: Test root directory path
113+
"""
114+
115+
input_query = SKLEARN_REST_INPUT_QUERY if protocol == Protocols.REST else SKLEARN_GRPC_INPUT_QUERY
116+
117+
validate_inference_request(
118+
pod_name=mlserver_pod_resource.name,
119+
isvc=mlserver_inference_service,
120+
response_snapshot=response_snapshot,
121+
input_query=input_query,
122+
model_version=MODEL_VERSION,
123+
protocol=protocol,
124+
root_dir=root_dir,
125+
)

0 commit comments

Comments
 (0)