12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | from __future__ import absolute_import |
14 | 14 |
15 | | -import json |
16 | | -import logging |
17 | 15 | from contextlib import contextmanager |
18 | 16 |
19 | 17 | import pytest |
22 | 20 | from sagemaker.serializers import JSONSerializer |
23 | 21 | from sagemaker.deserializers import JSONDeserializer |
24 | 22 |
25 | | -from ...integration import ROLE |
26 | | - |
27 | | -LOGGER = logging.getLogger(__name__) |
| 23 | +from ...integration import ROLE, model_data_path |
| 24 | +from ...utils import local_mode_utils |
28 | 25 |
29 | 26 |
30 | 27 | @contextmanager |
31 | | -def _predictor(image, sagemaker_local_session, instance_type, model_id): |
32 | | - """Context manager for vLLM model deployment and cleanup.""" |
| 28 | +def _predictor(image, sagemaker_local_session, instance_type): |
| 29 | + """Context manager for vLLM model deployment and cleanup. |
| 30 | + |
| 31 | + Model is extracted to /opt/ml/model by SageMaker from model_data tar.gz. |
| 32 | + vLLM loads the model from this local path. |
| 33 | + """ |
33 | 34 | env = { |
34 | | - "SM_VLLM_MODEL": model_id, |
| 35 | + "SM_VLLM_MODEL": "/opt/ml/model", |
35 | 36 | "SM_VLLM_MAX_MODEL_LEN": "512", |
36 | 37 | "SM_VLLM_HOST": "0.0.0.0", |
37 | 38 | } |
38 | 39 |
39 | 40 | model = Model( |
| 41 | + model_data=f"file://{model_data_path}", |
40 | 42 | role=ROLE, |
41 | 43 | image_uri=image, |
42 | 44 | env=env, |
43 | 45 | sagemaker_session=sagemaker_local_session, |
44 | 46 | predictor_cls=Predictor, |
45 | 47 | ) |
46 | | - |
47 | | - predictor = None |
48 | | - try: |
49 | | - predictor = model.deploy(1, instance_type) |
50 | | - yield predictor |
51 | | - finally: |
52 | | - if predictor is not None: |
53 | | - predictor.delete_endpoint() |
| 48 | + with local_mode_utils.lock(): |
| 49 | + predictor = None |
| 50 | + try: |
| 51 | + predictor = model.deploy(1, instance_type) |
| 52 | + yield predictor |
| 53 | + finally: |
| 54 | + if predictor is not None: |
| 55 | + predictor.delete_endpoint() |
54 | 56 |
55 | 57 |
56 | 58 | def _assert_vllm_prediction(predictor): |
57 | | - """Test vLLM inference using OpenAI-compatible API format.""" |
| 59 | + """Test vLLM inference using OpenAI-compatible completions API.""" |
58 | 60 | predictor.serializer = JSONSerializer() |
59 | 61 | predictor.deserializer = JSONDeserializer() |
60 | 62 |
61 | | - # vLLM uses OpenAI-compatible API format |
62 | 63 | data = { |
63 | 64 | "prompt": "What is Deep Learning?", |
64 | 65 | "max_tokens": 50, |
65 | 66 | "temperature": 0.7, |
66 | 67 | } |
67 | | - |
68 | | - LOGGER.info(f"Running inference with data: {data}") |
69 | 68 | output = predictor.predict(data) |
70 | | - LOGGER.info(f"Output: {json.dumps(output)}") |
71 | 69 |
72 | 70 | assert output is not None |
73 | | - # vLLM returns OpenAI-compatible response with 'choices' field |
74 | | - assert "choices" in output or "text" in output |
| 71 | + assert "choices" in output |
75 | 72 |
76 | 73 |
77 | 74 | def _assert_vllm_chat_prediction(predictor): |
78 | | - """Test vLLM inference using chat completions format.""" |
| 75 | + """Test vLLM inference using OpenAI-compatible chat completions API.""" |
79 | 76 | predictor.serializer = JSONSerializer() |
80 | 77 | predictor.deserializer = JSONDeserializer() |
81 | 78 |
82 | | - # vLLM chat completions format |
83 | 79 | data = { |
84 | | - "messages": [ |
85 | | - {"role": "user", "content": "What is Deep Learning?"} |
86 | | - ], |
| 80 | + "messages": [{"role": "user", "content": "What is Deep Learning?"}], |
87 | 81 | "max_tokens": 50, |
88 | 82 | "temperature": 0.7, |
89 | 83 | } |
90 | | - |
91 | | - LOGGER.info(f"Running chat inference with data: {data}") |
92 | 84 | output = predictor.predict(data) |
93 | | - LOGGER.info(f"Output: {json.dumps(output)}") |
94 | 85 |
95 | 86 | assert output is not None |
96 | 87 | assert "choices" in output |
97 | 88 |
98 | 89 |
99 | | -@pytest.mark.model("qwen3-0.6b") |
100 | | -@pytest.mark.processor("gpu") |
101 | | -@pytest.mark.gpu_test |
| 90 | +@pytest.mark.model("tiny-random-qwen3") |
102 | 91 | @pytest.mark.team("sagemaker-1p-algorithms") |
103 | 92 | def test_vllm_local_completions(ecr_image, sagemaker_local_session, instance_type): |
104 | 93 | """Test vLLM local deployment with completions API.""" |
105 | | - instance_type = instance_type if instance_type != "local" else "local_gpu" |
106 | | - with _predictor( |
107 | | - ecr_image, sagemaker_local_session, instance_type, "Qwen/Qwen3-0.6B" |
108 | | - ) as predictor: |
| 94 | + with _predictor(ecr_image, sagemaker_local_session, instance_type) as predictor: |
109 | 95 | _assert_vllm_prediction(predictor) |
110 | 96 |
111 | 97 |
112 | | -@pytest.mark.model("qwen3-0.6b") |
113 | | -@pytest.mark.processor("gpu") |
114 | | -@pytest.mark.gpu_test |
| 98 | +@pytest.mark.model("tiny-random-qwen3") |
115 | 99 | @pytest.mark.team("sagemaker-1p-algorithms") |
116 | 100 | def test_vllm_local_chat(ecr_image, sagemaker_local_session, instance_type): |
117 | 101 | """Test vLLM local deployment with chat completions API.""" |
118 | | - instance_type = instance_type if instance_type != "local" else "local_gpu" |
119 | | - with _predictor( |
120 | | - ecr_image, sagemaker_local_session, instance_type, "Qwen/Qwen3-0.6B" |
121 | | - ) as predictor: |
| 102 | + with _predictor(ecr_image, sagemaker_local_session, instance_type) as predictor: |
122 | 103 | _assert_vllm_chat_prediction(predictor) |
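
For readers reproducing this locally, here is a minimal sketch (not part of the diff) of how the `model_data_path` tarball consumed by `Model(model_data=f"file://{model_data_path}", ...)` might be produced. The repository id, output paths, and helper name are assumptions for illustration; SageMaker local mode simply extracts the archive contents into /opt/ml/model inside the container, which is where `SM_VLLM_MODEL` points.

```python
# Hypothetical helper: package local model weights as model.tar.gz for local mode.
# The repo id and paths below are placeholders, not values from this PR.
import tarfile
from pathlib import Path

from huggingface_hub import snapshot_download


def build_model_data(repo_id="tiny-random/qwen3", output_dir="resources/vllm"):
    """Download a small model and bundle it so SageMaker extracts it to /opt/ml/model."""
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    model_dir = Path(snapshot_download(repo_id, local_dir=out / "model"))
    tar_path = out / "model.tar.gz"
    with tarfile.open(tar_path, "w:gz") as tar:
        # Add files at the archive root so they land directly under /opt/ml/model.
        for item in model_dir.iterdir():
            tar.add(item, arcname=item.name)
    return tar_path
```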
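To make the assertions concrete, the shapes below illustrate the OpenAI-compatible responses vLLM returns for the two endpoints. Field values are illustrative and metadata varies by vLLM version, but both payloads carry the `choices` list that `_assert_vllm_prediction` and `_assert_vllm_chat_prediction` check for.

```python
# Illustrative response shapes (not captured from a real run).
completion_output = {
    "object": "text_completion",
    "choices": [{"index": 0, "text": " Deep Learning is ...", "finish_reason": "length"}],
}

chat_output = {
    "object": "chat.completion",
    "choices": [
        {"index": 0, "message": {"role": "assistant", "content": "Deep Learning is ..."}}
    ],
}

# Both satisfy the tests' shared check.
assert "choices" in completion_output and "choices" in chat_output
```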