1717
1818import pytest
1919import sagemaker
20- from sagemaker .huggingface import HuggingFaceModel
20+ from sagemaker .model import Model
21+ from sagemaker .predictor import Predictor
2122from sagemaker .serializers import JSONSerializer
2223from sagemaker .deserializers import JSONDeserializer
2324
3334@pytest .mark .gpu_test
3435@pytest .mark .team ("sagemaker-1p-algorithms" )
3536def test_vllm_bloom (framework_version , ecr_image , instance_type , sagemaker_regions ):
36- instance_type = "ml.g6.12xlarge"
3737 invoke_sm_endpoint_helper_function (
3838 ecr_image = ecr_image ,
3939 sagemaker_regions = sagemaker_regions ,
@@ -50,7 +50,6 @@ def test_vllm_bloom(framework_version, ecr_image, instance_type, sagemaker_regio
5050@pytest .mark .gpu_test
5151@pytest .mark .team ("sagemaker-1p-algorithms" )
5252def test_vllm_qwen (framework_version , ecr_image , instance_type , sagemaker_regions ):
53- instance_type = "ml.g6.12xlarge"
5453 invoke_sm_endpoint_helper_function (
5554 ecr_image = ecr_image ,
5655 sagemaker_regions = sagemaker_regions ,
@@ -72,13 +71,15 @@ def _test_vllm_model(
7271):
7372 """Test vLLM model deployment and inference using OpenAI-compatible API format
7473
75- Note: Parameters must match what invoke_sm_endpoint_helper_function passes:
76- - image_uri: ECR image URI
77- - sagemaker_session: SageMaker session
78- - instance_type: ML instance type (passed via **test_function_args)
79- - model_id: HuggingFace model ID (passed via **test_function_args)
80- - framework_version: Optional version info (passed via **test_function_args)
81- - **kwargs: Additional args from helper (boto_session, sagemaker_client, etc.)
74+ Uses sagemaker.model.Model for SDK v3 compatibility instead of HuggingFaceModel.
75+
76+ Args:
77+ image_uri: ECR image URI
78+ sagemaker_session: SageMaker session
79+ instance_type: ML instance type
80+ model_id: HuggingFace model ID
81+ framework_version: Optional version info
82+ **kwargs: Additional args from helper (boto_session, sagemaker_client, etc.)
8283 """
8384 endpoint_name = sagemaker .utils .unique_name_from_base ("sagemaker-hf-vllm-serving" )
8485
@@ -88,19 +89,22 @@ def _test_vllm_model(
8889 "SM_VLLM_HOST" : "0.0.0.0" ,
8990 }
9091
91- hf_model = HuggingFaceModel (
92- env = env ,
93- role = "SageMakerRole" ,
92+ model = Model (
93+ name = endpoint_name ,
9494 image_uri = image_uri ,
95+ role = "SageMakerRole" ,
96+ env = env ,
9597 sagemaker_session = sagemaker_session ,
98+ predictor_cls = Predictor ,
9699 )
97100
98101 with timeout_and_delete_endpoint (endpoint_name , sagemaker_session , minutes = 45 ):
99- predictor = hf_model .deploy (
102+ predictor = model .deploy (
100103 initial_instance_count = 1 ,
101104 instance_type = instance_type ,
102105 endpoint_name = endpoint_name ,
103106 container_startup_health_check_timeout = 1800 ,
107+ inference_ami_version = "al2-ami-sagemaker-inference-gpu-3-1" ,
104108 )
105109
106110 predictor .serializer = JSONSerializer ()
@@ -117,4 +121,4 @@ def _test_vllm_model(
117121 output = predictor .predict (data )
118122 LOGGER .info (f"Output: { json .dumps (output )} " )
119123
120- assert output is not None
124+ assert output is not None
0 commit comments