Commit fd2b4b5

Adding example for sagemaker testing with vLLM DLC

1 parent c9bcf6d commit fd2b4b5

4 files changed, +287 -0 lines changed

vllm-samples/sagemaker/README.md

Lines changed: 111 additions & 0 deletions
# AWS SageMaker vLLM Inference

Deploy and run inference on vLLM models using AWS SageMaker and the vLLM Deep Learning Container (DLC).

## Files

- `endpoint.py` - Deploy a vLLM model to a SageMaker endpoint
- `inference.py` - Run inference against the deployed endpoint
- `testNixlConnector.sh` - Local NixlConnector test script (see the Test NixlConnector section below)

## Prerequisites

- AWS CLI configured with appropriate permissions (a quick credential check is sketched below)
- HuggingFace token for model access (if required)
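A quick way to confirm credentials are in place before deploying (a minimal sketch; boto3 and the SageMaker SDK read the same credential chain as the AWS CLI):

```python
import boto3

# Prints the account and IAM identity that API calls will run as;
# raises an error if no credentials are configured.
identity = boto3.client("sts").get_caller_identity()
print(identity["Account"], identity["Arn"])
```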
## Setup

### Create IAM Role

```bash
# Create trust policy
cat > trust-policy.json << EOF
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Principal": {
        "Service": "sagemaker.amazonaws.com"
      },
      "Action": "sts:AssumeRole"
    }
  ]
}
EOF

# Create role
aws iam create-role --role-name SageMakerExecutionRole --assume-role-policy-document file://trust-policy.json

# Attach policies
aws iam attach-role-policy --role-name SageMakerExecutionRole --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
aws iam attach-role-policy --role-name SageMakerExecutionRole --policy-arn arn:aws:iam::aws:policy/AmazonElasticContainerRegistryPublicFullAccess
```
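If you need the role ARN for `endpoint.py`, one way to look it up (a minimal boto3 sketch, assuming the `SageMakerExecutionRole` created above):

```python
import boto3

# Fetch the ARN of the execution role created above
role_arn = boto3.client("iam").get_role(RoleName="SageMakerExecutionRole")["Role"]["Arn"]
print(role_arn)
```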
## Quick Start

### 1. Get Latest Image URI

```bash
# Check available images: https://gallery.ecr.aws/deep-learning-containers/vllm
# Get latest vLLM DLC image URI
export CONTAINER_URI="public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.1"
```

### 2. Deploy Endpoint

```bash
# Update the variables in endpoint.py and run
python endpoint.py
```
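`endpoint.py` deploys with `wait=False`, so the script returns before the endpoint is ready. A minimal sketch for checking deployment status with boto3 (use the endpoint name you set in `endpoint.py`):

```python
import boto3

sm = boto3.client("sagemaker")

# One-off status check: Creating -> InService (or Failed)
print(sm.describe_endpoint(EndpointName="<NAME>")["EndpointStatus"])

# Or block until the endpoint is in service
sm.get_waiter("endpoint_in_service").wait(EndpointName="<NAME>")
```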
### 3. Run Inference

```bash
# Update endpoint_name in inference.py and run
python inference.py
```
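If you prefer not to use the SageMaker Python SDK, the endpoint can also be invoked directly through the runtime API. A minimal sketch, assuming the container returns an OpenAI-style chat completion object (replace `<NAME>` with your endpoint name):

```python
import json

import boto3

runtime = boto3.client("sagemaker-runtime")

payload = {
    "messages": [{"role": "user", "content": "Write a haiku about GPUs"}],
    "max_tokens": 128,
}

response = runtime.invoke_endpoint(
    EndpointName="<NAME>",
    ContentType="application/json",
    Body=json.dumps(payload),
)

body = json.loads(response["Body"].read())
# In an OpenAI-style response the generated text is under choices[0].message.content
print(body["choices"][0]["message"]["content"])
```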
## Configuration

### Model Parameters

- `SM_VLLM_MODEL` - HuggingFace model ID
- `SM_VLLM_HF_TOKEN` - HuggingFace access token

### Inference Parameters

- `max_tokens` - Maximum response length
- `temperature` - Sampling randomness (0-1)
- `top_p` - Nucleus sampling threshold
- `top_k` - Top-k sampling limit

See the sketch below for where each group of parameters is applied.
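A minimal illustration using the same names as `endpoint.py` and `inference.py` (values are placeholders):

```python
# Container-level settings: passed as environment variables when the model is created
env = {
    "SM_VLLM_MODEL": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "SM_VLLM_HF_TOKEN": "<HF-TOKEN>",
}

# Request-level settings: included in each invocation payload
payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 256,   # Maximum response length
    "temperature": 0.7,  # Sampling randomness (0-1)
    "top_p": 0.9,        # Nucleus sampling threshold
    "top_k": 50,         # Top-k sampling limit
}
```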
## Instance Types

Recommended GPU instances:

- `ml.g5.12xlarge` - 4 A10G GPUs, 48 vCPUs, 192 GB RAM
- `ml.g5.24xlarge` - 4 A10G GPUs, 96 vCPUs, 384 GB RAM
- `ml.p4d.24xlarge` - 8 A100 GPUs, 96 vCPUs, 1152 GB RAM
## Test NixlConnector

Test NixlConnector locally; see the [NixlConnector documentation](https://docs.vllm.ai/en/latest/features/nixl_connector_usage.html#transport-configuration) for transport configuration details.

```bash
# Pull the latest vLLM DLC
docker pull public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.1

# Run container with GPU access
docker run -it --entrypoint=/bin/bash --gpus=all \
  -v $(pwd):/workspace \
  public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.1

# Inside the container, run the NixlConnector test
export HF_TOKEN="<TOKEN>"
./testNixlConnector.sh
```
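Once `testNixlConnector.sh` has the proxy running, you can send it a quick smoke test. A minimal sketch, assuming the proxy forwards vLLM's OpenAI-compatible `/v1/completions` route on port 8192 (the port used in the script):

```python
import requests

# Simple completion request routed through the prefiller/decoder proxy
resp = requests.post(
    "http://localhost:8192/v1/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "prompt": "The capital of France is",
        "max_tokens": 32,
    },
    timeout=60,
)
print(resp.json())
```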
## Cleanup

```python
import boto3

sagemaker = boto3.client('sagemaker')
sagemaker.delete_endpoint(EndpointName='<endpoint-name>')
```
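Deleting the endpoint alone leaves the endpoint configuration and model behind. A fuller cleanup sketch (when deployed via `endpoint.py`, the configuration and model names typically match the endpoint name, but verify in the SageMaker console):

```python
import boto3

sm = boto3.client('sagemaker')
name = '<endpoint-name>'

sm.delete_endpoint(EndpointName=name)
sm.delete_endpoint_config(EndpointConfigName=name)  # usually created with the endpoint's name
sm.delete_model(ModelName=name)                     # endpoint.py sets Model(name=endpoint_name)
```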

vllm-samples/sagemaker/endpoint.py

Lines changed: 36 additions & 0 deletions

import boto3
import sagemaker
from sagemaker.model import Model

# Configuration - replace placeholders with actual values
aws_region = "<REGION>"
instance_type = "ml.g5.12xlarge"  # GPU instance for vLLM
iam_role = "<IAM-ROLE>"  # Execution role ARN (e.g. the SageMakerExecutionRole from the README)
endpoint_name = "<NAME>"
container_uri = "<IMAGE_URI>"  # DLC image with vLLM

# Session pinned to the configured region
session = sagemaker.Session(boto_session=boto3.Session(region_name=aws_region))

try:
    print(f"Starting deployment of endpoint: {endpoint_name}")
    print(f"Using image: {container_uri}")
    print(f"Instance type: {instance_type}")

    print("Creating SageMaker model...")

    model = Model(
        name=endpoint_name,
        image_uri=container_uri,
        role=iam_role,
        sagemaker_session=session,
        env={
            "SM_VLLM_MODEL": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",  # Model to load
            "SM_VLLM_HF_TOKEN": "<HF-TOKEN>",  # HuggingFace token for model access
        },
    )
    print("Model created successfully")
    print("Starting endpoint deployment (this may take 10-15 minutes)...")

    model.deploy(
        instance_type=instance_type,
        initial_instance_count=1,
        endpoint_name=endpoint_name,
        wait=False,  # Deploy asynchronously
    )
except Exception as e:
    print(f"Deployment failed: {str(e)}")
vllm-samples/sagemaker/inference.py

Lines changed: 67 additions & 0 deletions
import json

from sagemaker import serializers
from sagemaker.predictor import Predictor


def invoke_endpoint(endpoint_name, prompt, max_tokens=2400, temperature=0.01):
    """Invoke SageMaker endpoint with vLLM model for text generation"""
    try:
        predictor = Predictor(
            endpoint_name=endpoint_name,
            serializer=serializers.JSONSerializer(),
        )

        payload = {
            "messages": [{"role": "user", "content": prompt}],  # Chat format
            "max_tokens": max_tokens,  # Response length limit
            "temperature": temperature,  # Randomness (0=deterministic, 1=creative)
            "top_p": 0.9,  # Nucleus sampling
            "top_k": 50,  # Top-k sampling
        }

        response = predictor.predict(payload)

        # Handle different response formats
        if isinstance(response, bytes):
            response = response.decode("utf-8")

        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                print("Warning: Response is not valid JSON. Returning as string.")

        return response

    except Exception as e:
        print(f"Inference failed: {str(e)}")
        return None


def main():
    endpoint_name = "<NAME>"  # Replace with actual endpoint name

    # Sample prompt for testing
    test_prompt = "Write a python code to generate n prime numbers"

    print("Sending request to endpoint...")
    response = invoke_endpoint(
        endpoint_name=endpoint_name,
        prompt=test_prompt,
        max_tokens=2400,  # Adjust based on expected response length
        temperature=0.01,  # Low temperature for consistent code generation
    )

    if response:
        print("\nResponse from endpoint:")
        if isinstance(response, (dict, list)):
            print(json.dumps(response, indent=2))
        else:
            print(response)
    else:
        print("No response received from the endpoint.")


if __name__ == "__main__":
    main()
vllm-samples/sagemaker/testNixlConnector.sh

Lines changed: 73 additions & 0 deletions
#!/bin/bash

# Function to wait for server to be ready
wait_for_server() {
    local host=$1
    local port=$2
    local timeout=120
    local count=0

    echo "Waiting for server at $host:$port to be ready..."
    while ! curl -s http://$host:$port/health > /dev/null 2>&1; do
        sleep 5
        count=$((count + 5))
        if [ $count -ge $timeout ]; then
            echo "Timeout waiting for server at $host:$port"
            return 1
        fi
    done
    echo "Server at $host:$port is ready"
}

# Start first GPU (prefiller)
echo "Starting prefiller on GPU 0..."
CUDA_VISIBLE_DEVICES=0 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --port 8100 \
    --max-model-len 6000 \
    --enforce-eager \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
    > vllm_gpu0.log 2>&1 &

# Start second GPU (decoder)
echo "Starting decoder on GPU 1..."
CUDA_VISIBLE_DEVICES=1 \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --port 8200 \
    --max-model-len 6000 \
    --enforce-eager \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
    > vllm_gpu1.log 2>&1 &


# Wait for GPU servers
wait_for_server localhost 8100
wait_for_server localhost 8200

# Start proxy server
echo "Starting proxy server..."
python3 proxy.py \
    --host 0.0.0.0 \
    --port 8192 \
    --prefiller-hosts localhost \
    --prefiller-ports 8100 \
    --decoder-hosts localhost \
    --decoder-ports 8200 \
    > proxy_server.log 2>&1 &

# Wait for proxy server
wait_for_server localhost 8192

# The benchmark below expects the ShareGPT dataset; download it first if needed:
# wget -q https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
vllm bench serve \
    --host 0.0.0.0 \
    --port 8192 \
    --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 30
