
Commit ea1f74c

Author: EC2 Default User
Add huggingface/vllm local mode tests with tiny-random-qwen3 model
1 parent 2d993e1 commit ea1f74c

File tree: 3 files changed (+44 / -47 lines)


test/sagemaker_tests/huggingface/vllm/integration/__init__.py

Lines changed: 18 additions & 2 deletions
@@ -13,12 +13,21 @@
 from __future__ import absolute_import
 
 import json
+import os
 import re
 
 import boto3
 
+# Path to test resources
+resources_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "resources"))
 
-ROLE = "SageMakerRole"
+# Model artifacts for local mode tests
+model_dir = os.path.join(resources_path, "tiny-random-qwen3")
+model_data = "tiny-random-qwen3.tar.gz"
+model_data_path = os.path.join(model_dir, model_data)
+
+# Role for local mode (not used but required by SageMaker SDK)
+ROLE = "dummy/unused-role"
 DEFAULT_TIMEOUT = 45
 
 
@@ -32,7 +41,8 @@ class SageMakerEndpointFailure(Exception):
 
 def dump_logs_from_cloudwatch(e, region="us-west-2"):
     """
-    Function to dump logs from cloudwatch during error handling
+    Function to dump logs from cloudwatch during error handling.
+    Gracefully handles missing log groups/streams.
     """
     error_hosting_endpoint_regex = re.compile(r"Error hosting endpoint ((\w|-)+):")
     endpoint_url_regex = re.compile(r"/aws/sagemaker/Endpoints/((\w|-)+)")
@@ -43,6 +53,7 @@ def dump_logs_from_cloudwatch(e, region="us-west-2"):
     logs_client = boto3.client("logs", region_name=region)
     endpoint = endpoint_match.group(1)
     log_group_name = f"/aws/sagemaker/Endpoints/{endpoint}"
+    try:
         log_stream_resp = logs_client.describe_log_streams(logGroupName=log_group_name)
         all_traffic_log_stream = ""
         for log_stream in log_stream_resp.get("logStreams", []):
@@ -60,3 +71,8 @@ def dump_logs_from_cloudwatch(e, region="us-west-2"):
         raise SageMakerEndpointFailure(
             f"Error from endpoint {endpoint}:\n{json.dumps(events, indent=4)}"
         ) from e
+    except logs_client.exceptions.ResourceNotFoundException:
+        # Log group doesn't exist yet - endpoint may have failed before creating logs
+        raise SageMakerEndpointFailure(
+            f"Endpoint {endpoint} failed. No CloudWatch logs available yet."
+        ) from e
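
The tiny-random-qwen3.tar.gz artifact that model_data_path points to is checked in as a binary resource in this commit (the "Binary file not shown" entry further down), so the diff does not show how it was built. As a minimal sketch, assuming the tiny model's files (config, tokenizer, safetensors weights) already exist in a local directory, such an artifact could be packaged as below; package_model and the paths are illustrative, not part of the commit:

```python
import os
import tarfile


def package_model(model_dir, output_path):
    """Create a model tar.gz with the model files at the archive root."""
    with tarfile.open(output_path, "w:gz") as tar:
        for name in os.listdir(model_dir):
            tar.add(os.path.join(model_dir, name), arcname=name)


# Hypothetical local paths:
# package_model("tiny-random-qwen3-files", "tiny-random-qwen3.tar.gz")
```

Keeping the files at the archive root matters because SageMaker extracts model_data directly into /opt/ml/model, which is the path SM_VLLM_MODEL is set to in the test below.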

test/sagemaker_tests/huggingface/vllm/integration/local/test_serving.py

Lines changed: 26 additions & 45 deletions
@@ -12,8 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import json
-import logging
 from contextlib import contextmanager
 
 import pytest
@@ -22,101 +20,84 @@
 from sagemaker.serializers import JSONSerializer
 from sagemaker.deserializers import JSONDeserializer
 
-from ...integration import ROLE
-
-LOGGER = logging.getLogger(__name__)
+from ...integration import ROLE, model_data_path
+from ...utils import local_mode_utils
 
 
 @contextmanager
-def _predictor(image, sagemaker_local_session, instance_type, model_id):
-    """Context manager for vLLM model deployment and cleanup."""
+def _predictor(image, sagemaker_local_session, instance_type):
+    """Context manager for vLLM model deployment and cleanup.
+
+    Model is extracted to /opt/ml/model by SageMaker from model_data tar.gz.
+    vLLM loads the model from this local path.
+    """
     env = {
-        "SM_VLLM_MODEL": model_id,
+        "SM_VLLM_MODEL": "/opt/ml/model",
         "SM_VLLM_MAX_MODEL_LEN": "512",
         "SM_VLLM_HOST": "0.0.0.0",
     }
 
     model = Model(
+        model_data=f"file://{model_data_path}",
         role=ROLE,
        image_uri=image,
         env=env,
         sagemaker_session=sagemaker_local_session,
         predictor_cls=Predictor,
     )
-
-    predictor = None
-    try:
-        predictor = model.deploy(1, instance_type)
-        yield predictor
-    finally:
-        if predictor is not None:
-            predictor.delete_endpoint()
+    with local_mode_utils.lock():
+        predictor = None
+        try:
+            predictor = model.deploy(1, instance_type)
+            yield predictor
+        finally:
+            if predictor is not None:
+                predictor.delete_endpoint()
 
 
 def _assert_vllm_prediction(predictor):
-    """Test vLLM inference using OpenAI-compatible API format."""
+    """Test vLLM inference using OpenAI-compatible completions API."""
     predictor.serializer = JSONSerializer()
     predictor.deserializer = JSONDeserializer()
 
-    # vLLM uses OpenAI-compatible API format
     data = {
         "prompt": "What is Deep Learning?",
         "max_tokens": 50,
         "temperature": 0.7,
     }
-
-    LOGGER.info(f"Running inference with data: {data}")
     output = predictor.predict(data)
-    LOGGER.info(f"Output: {json.dumps(output)}")
 
     assert output is not None
-    # vLLM returns OpenAI-compatible response with 'choices' field
-    assert "choices" in output or "text" in output
+    assert "choices" in output
 
 
 def _assert_vllm_chat_prediction(predictor):
-    """Test vLLM inference using chat completions format."""
+    """Test vLLM inference using OpenAI-compatible chat completions API."""
     predictor.serializer = JSONSerializer()
     predictor.deserializer = JSONDeserializer()
 
-    # vLLM chat completions format
     data = {
-        "messages": [
-            {"role": "user", "content": "What is Deep Learning?"}
-        ],
+        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
         "max_tokens": 50,
         "temperature": 0.7,
     }
-
-    LOGGER.info(f"Running chat inference with data: {data}")
    output = predictor.predict(data)
-    LOGGER.info(f"Output: {json.dumps(output)}")
 
     assert output is not None
     assert "choices" in output
 
 
-@pytest.mark.model("qwen3-0.6b")
-@pytest.mark.processor("gpu")
-@pytest.mark.gpu_test
+@pytest.mark.model("tiny-random-qwen3")
 @pytest.mark.team("sagemaker-1p-algorithms")
 def test_vllm_local_completions(ecr_image, sagemaker_local_session, instance_type):
     """Test vLLM local deployment with completions API."""
-    instance_type = instance_type if instance_type != "local" else "local_gpu"
-    with _predictor(
-        ecr_image, sagemaker_local_session, instance_type, "Qwen/Qwen3-0.6B"
-    ) as predictor:
+    with _predictor(ecr_image, sagemaker_local_session, instance_type) as predictor:
         _assert_vllm_prediction(predictor)
 
 
-@pytest.mark.model("qwen3-0.6b")
-@pytest.mark.processor("gpu")
-@pytest.mark.gpu_test
+@pytest.mark.model("tiny-random-qwen3")
 @pytest.mark.team("sagemaker-1p-algorithms")
 def test_vllm_local_chat(ecr_image, sagemaker_local_session, instance_type):
     """Test vLLM local deployment with chat completions API."""
-    instance_type = instance_type if instance_type != "local" else "local_gpu"
-    with _predictor(
-        ecr_image, sagemaker_local_session, instance_type, "Qwen/Qwen3-0.6B"
-    ) as predictor:
+    with _predictor(ecr_image, sagemaker_local_session, instance_type) as predictor:
         _assert_vllm_chat_prediction(predictor)
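
local_mode_utils.lock() is imported above but not defined in this diff. In these test suites it presumably serializes local mode deployments so that only one local container binds the serving port at a time. A minimal sketch of what such a helper could look like, assuming a fcntl-based file lock (the lock file path and exact semantics are assumptions, not taken from this commit):

```python
import fcntl
from contextlib import contextmanager

# Assumed lock file location; the real helper may use a different path.
LOCK_PATH = "/tmp/sagemaker_local_mode.lock"


@contextmanager
def lock(path=LOCK_PATH):
    """Block until an exclusive file lock is acquired, release it on exit."""
    with open(path, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            yield
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
```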
Binary file not shown.
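
For context on the `assert "choices" in output` checks in test_serving.py above: vLLM exposes an OpenAI-compatible API, and completions responses carry a top-level choices list. An illustrative response shape (made up, not captured from this test run) looks roughly like:

```python
# Illustrative OpenAI-compatible completions response; field values are invented.
example_completion_response = {
    "id": "cmpl-123",
    "object": "text_completion",
    "model": "/opt/ml/model",
    "choices": [
        {
            "index": 0,
            "text": " Deep Learning is a subfield of machine learning ...",
            "finish_reason": "length",  # max_tokens=50 reached
        }
    ],
    "usage": {"prompt_tokens": 6, "completion_tokens": 50, "total_tokens": 56},
}

assert "choices" in example_completion_response  # what the tests verify
```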
