Skip to content

Commit 6ca7af8

Browse files
authored
test(model-validation): run raw and serverless execution in parallel (#693)
Signed-off-by: Snomaan6846 <syedali@redhat.com> rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED
1 parent 49cb01b commit 6ca7af8

File tree

5 files changed

+100
-16
lines changed

5 files changed

+100
-16
lines changed

conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,11 @@ def pytest_sessionstart(session: Session) -> None:
280280
pathlib.Path(tests_log_file).unlink()
281281
if session.config.getoption("--collect-must-gather"):
282282
session.config.option.must_gather_db = Database()
283+
thread_name = os.environ.get("PYTEST_XDIST_WORKER", "master")
283284
session.config.option.log_listener = setup_logging(
284285
log_file=tests_log_file,
285286
log_level=session.config.getoption("log_cli_level") or logging.INFO,
287+
thread_name=thread_name,
286288
)
287289
must_gather_dict = set_must_gather_collector_values()
288290
shutil.rmtree(

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ markers =
77
polarion: Store polarion test ID
88
jira: Store jira bug ID
99
skip_on_disconnected: Mark tests that can only be run in deployments with Internet access i.e. not on disconnected clusters.
10+
parallel: marks tests that can run in parallel along with pytest-xdist
1011

1112
# CI
1213
smoke: Mark tests as smoke tests; covers core functionality of the product. Aims to ensure that the build is stable enough for further testing.

tests/model_serving/model_runtime/model_validation/conftest.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Generator
2+
from typing import Any, Generator, List
33

44
import pytest
55
import yaml
@@ -158,51 +158,89 @@ def deployment_config(request: FixtureRequest) -> dict[str, Any]:
158158

159159

160160
def build_raw_params(
161-
name: str, image: str, args: list[str], gpu_count: int, model_output_type: str = "text"
161+
name: str,
162+
image: str,
163+
args: list[str],
164+
gpu_count: int,
165+
execution_mode: str,
166+
model_output_type: str = "text",
162167
) -> tuple[Any, str]:
163168
test_id = f"{name}-raw"
169+
deployment_type = KServeDeploymentType.RAW_DEPLOYMENT
164170
param = pytest.param(
165171
{"name": "raw-model-validation"},
166-
{"deployment_type": KServeDeploymentType.RAW_DEPLOYMENT},
172+
{"deployment_type": deployment_type},
167173
{
168174
"model_name": name,
169175
"model_car_image_uri": image,
170176
},
171177
{
172-
"deployment_type": KServeDeploymentType.RAW_DEPLOYMENT,
178+
"deployment_type": deployment_type,
173179
"runtime_argument": args,
174180
"gpu_count": gpu_count,
175181
"model_output_type": model_output_type,
176182
},
177183
id=test_id,
178-
marks=[pytest.mark.rawdeployment],
184+
marks=build_pytest_markers(deployment_type=deployment_type, execution_mode=execution_mode),
179185
)
180186
return param, test_id
181187

182188

183189
def build_serverless_params(
184-
name: str, image: str, args: list[str], gpu_count: int, model_output_type: str = "text"
190+
name: str,
191+
image: str,
192+
args: list[str],
193+
gpu_count: int,
194+
execution_mode: str,
195+
model_output_type: str = "text",
185196
) -> tuple[Any, str]:
186197
test_id = f"{name}-serverless"
198+
deployment_type = KServeDeploymentType.SERVERLESS
187199
param = pytest.param(
188200
{"name": "serverless-model-validation"},
189-
{"deployment_type": KServeDeploymentType.SERVERLESS},
201+
{"deployment_type": deployment_type},
190202
{
191203
"model_name": name,
192204
"model_car_image_uri": image,
193205
},
194206
{
195-
"deployment_type": KServeDeploymentType.SERVERLESS,
207+
"deployment_type": deployment_type,
196208
"runtime_argument": args,
197209
"gpu_count": gpu_count,
198210
"model_output_type": model_output_type,
199211
},
200212
id=test_id,
201-
marks=[pytest.mark.serverless],
213+
marks=build_pytest_markers(deployment_type=deployment_type, execution_mode=execution_mode),
202214
)
203215
return param, test_id
204216

205217

218+
def build_pytest_markers(deployment_type: str, execution_mode: str) -> List[Any]:
219+
"""
220+
Build a list of pytest markers based on deployment type, execution mode.
221+
222+
Args:
223+
deployment_type (str): Deployment type (e.g., RAW_DEPLOYMENT, SERVERLESS)
224+
execution_mode (str): "parallel" or "sequential"
225+
226+
Returns:
227+
List[Any]: List of pytest.mark objects to attach to the test
228+
"""
229+
markers: List[pytest.MarkDecorator] = []
230+
231+
if deployment_type == KServeDeploymentType.RAW_DEPLOYMENT:
232+
markers.append(pytest.mark.rawdeployment)
233+
elif deployment_type == KServeDeploymentType.SERVERLESS:
234+
markers.append(pytest.mark.serverless)
235+
236+
# Execution mode markers
237+
if execution_mode == "parallel":
238+
markers.append(pytest.mark.parallel)
239+
markers.append(pytest.mark.skip_must_gather)
240+
241+
return markers
242+
243+
206244
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
207245
yaml_config = None
208246
yaml_path = metafunc.config.getoption(name="model_car_yaml_path")
@@ -232,6 +270,10 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
232270

233271
name = model_car.get("name", "").strip()
234272
image = model_car.get("image", "").strip()
273+
execution_mode = (
274+
model_car.get("execution_mode", "").strip()
275+
or default_serving_config.get("execution_mode", "sequential").strip()
276+
)
235277

236278
if not name or not image:
237279
continue
@@ -243,11 +285,21 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
243285

244286
if metafunc.cls.__name__ == "TestVLLMModelCarRaw":
245287
param, test_id = build_raw_params(
246-
name=name, image=image, args=args, gpu_count=gpu_count, model_output_type=model_output_type
288+
name=name,
289+
image=image,
290+
args=args,
291+
gpu_count=gpu_count,
292+
execution_mode=execution_mode,
293+
model_output_type=model_output_type,
247294
)
248295
elif metafunc.cls.__name__ == "TestVLLMModelCarServerless":
249296
param, test_id = build_serverless_params(
250-
name=name, image=image, args=args, gpu_count=gpu_count, model_output_type=model_output_type
297+
name=name,
298+
image=image,
299+
args=args,
300+
gpu_count=gpu_count,
301+
execution_mode=execution_mode,
302+
model_output_type=model_output_type,
251303
)
252304
else:
253305
continue

tests/model_serving/model_runtime/model_validation/sample_modelcar_config.yaml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ model-car:
22
- name: granite-3.1-8b-base-quantized.w4a16
33
image: oci://registry.redhat.io/rhelai1/modelcar-granite-3-1-8b-base-quantized-w4a16:1.5
44
model_output_type: text
5+
execution_mode: parallel
56
serving_arguments:
67
args:
78
- "--uvicorn-log-level=info"
@@ -13,6 +14,7 @@ model-car:
1314
- name: whisper-large-v2-W4A16-G128
1415
image: oci://registry.redhat.io/rhelai1/modelcar-whisper-large-v2-w4a16-g128:1.5
1516
model_output_type: audio
17+
execution_mode: parallel
1618
serving_arguments:
1719
args:
1820
- "--uvicorn-log-level=info"
@@ -23,14 +25,17 @@ model-car:
2325
- name: granite-3.1-8b-starter-v2
2426
image: oci://registry.redhat.io/rhelai1/modelcar-granite-3-1-8b-starter-v2:1.5
2527
model_output_type: text
28+
execution_mode: sequential
2629

2730
- name: granite-3.1-8b-instruct-quantized.w8a8
2831
image: oci://registry.redhat.io/rhelai1/modelcar-granite-3-1-8b-instruct-quantized-w8a8:1.5
2932
model_output_type: text
33+
execution_mode: parallel
3034

3135
- name: Llama-3.1-8B-Instruct
3236
image: oci://registry.redhat.io/rhelai1/modelcar-llama-3-1-8b-instruct:1.5
3337
model_output_type: text
38+
execution_mode: sequential
3439
serving_arguments:
3540
args:
3641
- "--uvicorn-log-level=debug"
@@ -42,6 +47,7 @@ model-car:
4247
- name: Mistral-7B-Instruct-v0.3-quantized.w4a16
4348
image: oci://registry.redhat.io/rhelai1/modelcar-mistral-7b-instruct-v0-3-quantized-w4a16:1.5
4449
model_output_type: text
50+
execution_mode: parallel
4551
serving_arguments:
4652
args:
4753
- "--uvicorn-log-level=debug"
@@ -53,11 +59,12 @@ model-car:
5359
- name: Qwen2.5-7B-Instruct-quantized.w8a8
5460
image: oci://registry.redhat.io/rhelai1/modelcar-qwen2-5-7b-instruct-quantized-w8a8:1.5
5561
model_output_type: text
56-
62+
execution_mode: parallel
5763

5864
- name: Qwen2.5-7B-Instruct-FP8-dynamic
5965
image: oci://registry.redhat.io/rhelai1/modelcar-qwen2-5-7b-instruct-fp8-dynamic:1.5
6066
model_output_type: text
67+
execution_mode: parallel
6168
serving_arguments:
6269
args:
6370
- "--uvicorn-log-level=debug"
@@ -69,6 +76,7 @@ model-car:
6976
- name: DeepSeek-R1-Distill-Llama-8B-FP8-dynamic
7077
image: oci://registry.redhat.io/rhelai1/modelcar-deepseek-r1-distill-llama-8b-fp8-dynamic:1.5
7178
model_output_type: text
79+
execution_mode: parallel
7280
serving_arguments:
7381
args:
7482
- "--uvicorn-log-level=debug"
@@ -80,6 +88,7 @@ model-car:
8088
- name: phi-4-quantized.w4a16
8189
image: oci://registry.redhat.io/rhelai1/modelcar-phi-4-quantized-w4a16:1.5
8290
model_output_type: text
91+
execution_mode: parallel
8392
serving_arguments:
8493
args:
8594
- "--uvicorn-log-level=debug"
@@ -91,6 +100,7 @@ model-car:
91100
- name: phi-4-quantized.w8a8
92101
image: oci://registry.redhat.io/rhelai1/modelcar-phi-4-quantized-w8a8:1.5
93102
model_output_type: text
103+
execution_mode: sequential
94104
serving_arguments:
95105
args:
96106
- "--uvicorn-log-level=debug"
@@ -107,3 +117,4 @@ default:
107117
- "--trust-remote-code"
108118
- "--distributed-executor-backend=mp"
109119
gpu_count: 1
120+
execution_mode: sequential

utilities/logger.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@ def __repr__(self) -> str:
2121
return "'***REDACTED***'"
2222

2323

24-
def setup_logging(log_level: int, log_file: str = "/tmp/pytest-tests.log") -> QueueListener:
24+
def setup_logging(
25+
log_level: int, log_file: str = "/tmp/pytest-tests.log", thread_name: str | None = None
26+
) -> QueueListener:
2527
"""
2628
Setup basic/root logging using QueueHandler/QueueListener
2729
to consolidate log messages into a single stream to be written to multiple outputs.
2830
2931
Args:
3032
log_level (int): log level
3133
log_file (str): logging output file
34+
thread_name (str | None): optional thread_name id prefix, e.g., [gw0]
3235
3336
Returns:
3437
QueueListener: Process monitoring the log Queue
@@ -38,9 +41,16 @@ def setup_logging(log_level: int, log_file: str = "/tmp/pytest-tests.log") -> Qu
3841
├> Queue -> QueueListener ┤
3942
basic QueueHandler ┘ └> FileHandler
4043
"""
41-
basic_log_formatter = logging.Formatter(fmt="%(message)s")
44+
basic_fmt_str = "%(message)s"
45+
root_fmt_str = "%(asctime)s %(name)s %(log_color)s%(levelname)s%(reset)s %(message)s"
46+
47+
if thread_name:
48+
basic_fmt_str = f"[{thread_name}] {basic_fmt_str}"
49+
root_fmt_str = f"[{thread_name}] {root_fmt_str}"
50+
51+
basic_log_formatter = logging.Formatter(fmt=basic_fmt_str)
4252
root_log_formatter = WrapperLogFormatter(
43-
fmt="%(asctime)s %(name)s %(log_color)s%(levelname)s%(reset)s %(message)s",
53+
fmt=root_fmt_str,
4454
log_colors={
4555
"DEBUG": "cyan",
4656
"INFO": "green",
@@ -67,20 +77,28 @@ def setup_logging(log_level: int, log_file: str = "/tmp/pytest-tests.log") -> Qu
6777

6878
basic_logger = logging.getLogger(name="basic")
6979
basic_logger.setLevel(level=log_level)
80+
basic_logger.handlers.clear()
7081
basic_logger.addHandler(hdlr=basic_log_queue_handler)
7182

7283
root_log_queue_handler = QueueHandler(queue=log_queue)
7384
root_log_queue_handler.set_name(name="root")
7485
root_log_queue_handler.setFormatter(fmt=root_log_formatter)
7586

76-
root_logger = logging.getLogger()
87+
root_logger = logging.getLogger(name="root")
7788
root_logger.setLevel(level=log_level)
89+
root_logger.handlers.clear()
7890
root_logger.addHandler(hdlr=root_log_queue_handler)
7991
root_logger.addFilter(filter=DuplicateFilter())
8092

8193
root_logger.propagate = False
8294
basic_logger.propagate = False
8395

96+
for name, logger in logging.root.manager.loggerDict.items():
97+
if isinstance(logger, logging.Logger) and (name not in ("root", "basic")):
98+
logger.handlers.clear()
99+
logger.addHandler(hdlr=root_log_queue_handler)
100+
logger.propagate = False
101+
84102
log_listener.start()
85103
return log_listener
86104

0 commit comments

Comments
 (0)