Skip to content

Commit b5a604b

Browse files
author
EC2 Default User
committed
updated test to sagemaker v3
1 parent 181d61b commit b5a604b

File tree

1 file changed

+19
-15
lines changed
  • test/sagemaker_tests/huggingface/vllm/integration/sagemaker

1 file changed

+19
-15
lines changed

test/sagemaker_tests/huggingface/vllm/integration/sagemaker/test_vllm.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
import pytest
1919
import sagemaker
20-
from sagemaker.huggingface import HuggingFaceModel
20+
from sagemaker.model import Model
21+
from sagemaker.predictor import Predictor
2122
from sagemaker.serializers import JSONSerializer
2223
from sagemaker.deserializers import JSONDeserializer
2324

@@ -33,7 +34,6 @@
3334
@pytest.mark.gpu_test
3435
@pytest.mark.team("sagemaker-1p-algorithms")
3536
def test_vllm_bloom(framework_version, ecr_image, instance_type, sagemaker_regions):
36-
instance_type = "ml.g6.12xlarge"
3737
invoke_sm_endpoint_helper_function(
3838
ecr_image=ecr_image,
3939
sagemaker_regions=sagemaker_regions,
@@ -50,7 +50,6 @@ def test_vllm_bloom(framework_version, ecr_image, instance_type, sagemaker_regio
5050
@pytest.mark.gpu_test
5151
@pytest.mark.team("sagemaker-1p-algorithms")
5252
def test_vllm_qwen(framework_version, ecr_image, instance_type, sagemaker_regions):
53-
instance_type = "ml.g6.12xlarge"
5453
invoke_sm_endpoint_helper_function(
5554
ecr_image=ecr_image,
5655
sagemaker_regions=sagemaker_regions,
@@ -72,13 +71,15 @@ def _test_vllm_model(
7271
):
7372
"""Test vLLM model deployment and inference using OpenAI-compatible API format
7473
75-
Note: Parameters must match what invoke_sm_endpoint_helper_function passes:
76-
- image_uri: ECR image URI
77-
- sagemaker_session: SageMaker session
78-
- instance_type: ML instance type (passed via **test_function_args)
79-
- model_id: HuggingFace model ID (passed via **test_function_args)
80-
- framework_version: Optional version info (passed via **test_function_args)
81-
- **kwargs: Additional args from helper (boto_session, sagemaker_client, etc.)
74+
Uses sagemaker.model.Model for SDK v3 compatibility instead of HuggingFaceModel.
75+
76+
Args:
77+
image_uri: ECR image URI
78+
sagemaker_session: SageMaker session
79+
instance_type: ML instance type
80+
model_id: HuggingFace model ID
81+
framework_version: Optional version info
82+
**kwargs: Additional args from helper (boto_session, sagemaker_client, etc.)
8283
"""
8384
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-hf-vllm-serving")
8485

@@ -88,19 +89,22 @@ def _test_vllm_model(
8889
"SM_VLLM_HOST": "0.0.0.0",
8990
}
9091

91-
hf_model = HuggingFaceModel(
92-
env=env,
93-
role="SageMakerRole",
92+
model = Model(
93+
name=endpoint_name,
9494
image_uri=image_uri,
95+
role="SageMakerRole",
96+
env=env,
9597
sagemaker_session=sagemaker_session,
98+
predictor_cls=Predictor,
9699
)
97100

98101
with timeout_and_delete_endpoint(endpoint_name, sagemaker_session, minutes=45):
99-
predictor = hf_model.deploy(
102+
predictor = model.deploy(
100103
initial_instance_count=1,
101104
instance_type=instance_type,
102105
endpoint_name=endpoint_name,
103106
container_startup_health_check_timeout=1800,
107+
inference_ami_version="al2-ami-sagemaker-inference-gpu-3-1",
104108
)
105109

106110
predictor.serializer = JSONSerializer()
@@ -117,4 +121,4 @@ def _test_vllm_model(
117121
output = predictor.predict(data)
118122
LOGGER.info(f"Output: {json.dumps(output)}")
119123

120-
assert output is not None
124+
assert output is not None

0 commit comments

Comments
 (0)