Commit b55d2d6

combiner scripts
1 parent fd2b4b5 commit b55d2d6

5 files changed: +213 −158 lines changed


vllm-samples/sagemaker/README.md

Lines changed: 34 additions & 49 deletions
@@ -4,8 +4,8 @@ Deploy and run inference on vLLM models using AWS SageMaker and vLLM DLC.
 
 ## Files
 
-- `endpoint.py` - Deploy vLLM model to SageMaker endpoint
-- `inference.py` - Run inference against deployed endpoint
+- `deploy_and_test_sm_endpoint.py` - Complete workflow: deploy, inference, and cleanup
+- `testNixlConnector.sh` - Multi-GPU NixlConnector test script
 
 ## Prerequisites
 
@@ -17,24 +17,8 @@ Deploy and run inference on vLLM models using AWS SageMaker and vLLM DLC.
 ### Create IAM Role
 
 ```bash
-# Create trust policy
-cat > trust-policy.json << EOF
-{
-  "Version": "2012-10-17",
-  "Statement": [
-    {
-      "Effect": "Allow",
-      "Principal": {
-        "Service": "sagemaker.amazonaws.com"
-      },
-      "Action": "sts:AssumeRole"
-    }
-  ]
-}
-EOF
-
 # Create role
-aws iam create-role --role-name SageMakerExecutionRole --assume-role-policy-document file://trust-policy.json
+aws iam create-role --role-name SageMakerExecutionRole
 
 # Attach policies
 aws iam attach-role-policy --role-name SageMakerExecutionRole --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
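
Note that `aws iam create-role` requires `--assume-role-policy-document`, which the removed heredoc used to supply, so the simplified command above will be rejected by the CLI. A minimal boto3 sketch that creates the role with the same SageMaker trust policy (role name taken from the commands above):

```python
import json

import boto3

iam = boto3.client("iam")

# Trust policy allowing SageMaker to assume the role (same document
# the removed trust-policy.json heredoc contained).
trust_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {"Service": "sagemaker.amazonaws.com"},
            "Action": "sts:AssumeRole",
        }
    ],
}

iam.create_role(
    RoleName="SageMakerExecutionRole",
    AssumeRolePolicyDocument=json.dumps(trust_policy),
)
iam.attach_role_policy(
    RoleName="SageMakerExecutionRole",
    PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
)
```
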
@@ -43,39 +27,42 @@ aws iam attach-role-policy --role-name SageMakerExecutionRole --policy-arn arn:a
 
 ## Quick Start
 
-### 1. Get Latest Image URI
+### 1. Set Environment Variables
 
 ```bash
 # Check available images: https://gallery.ecr.aws/deep-learning-containers/vllm
-# Get latest vLLM DLC image URI
 export CONTAINER_URI="public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.1"
+export IAM_ROLE="SageMakerExecutionRole"
+export HF_TOKEN="your-huggingface-token"
 ```
 
-### 2. Deploy Endpoint
-
-```bash
-# update variables in endpoint.py and run
-python endpoint.py
-```
-
-### 3. Run Inference
+### 2. Run Complete Workflow
 
 ```bash
-# update endpoint_name in inference.py and run
-python inference.py
+# Deploy, run inference, and clean up automatically
+python deploy_and_test_sm_endpoint.py --endpoint-name vllm-test-$(date +%s) --prompt "Write a Python function to calculate fibonacci numbers"
+
+# Alternatively, with custom parameters
+python deploy_and_test_sm_endpoint.py \
+  --endpoint-name my-vllm-endpoint \
+  --model-id deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+  --instance-type ml.g5.12xlarge \
+  --prompt "Explain machine learning" \
+  --max-tokens 1000 \
+  --temperature 0.7
 ```
 
-## Configuration
-
-### Model Parameters
-- `SM_VLLM_MODEL` - HuggingFace model ID
-- `SM_VLLM_HF_TOKEN` - HuggingFace access token
+## Command Line Options
 
-### Inference Parameters
-- `max_tokens` - Maximum response length
-- `temperature` - Sampling randomness (0-1)
-- `top_p` - Nucleus sampling threshold
-- `top_k` - Top-k sampling limit
+- `--endpoint-name` - SageMaker endpoint name (required)
+- `--container-uri` - DLC image URI (default from env)
+- `--iam-role` - IAM role ARN (default from env)
+- `--instance-type` - Instance type (default: ml.g5.12xlarge)
+- `--model-id` - HuggingFace model ID (default: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
+- `--hf-token` - HuggingFace token (default from env)
+- `--prompt` - Inference prompt (default: code generation example)
+- `--max-tokens` - Maximum response length (default: 2400)
+- `--temperature` - Sampling randomness 0-1 (default: 0.01)
 
 ## Instance Types
 
@@ -90,22 +77,20 @@ Test NixlConnector locally - [NixlConnector Documentation](https://docs.vllm.ai/
 
 ```bash
 # Pull latest vLLM DLC for EC2
-docker pull public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.1
+docker pull public.ecr.aws/deep-learning-containers/vllm:0.11-gpu-py312
 
 # Run container with GPU access
 docker run -it --entrypoint=/bin/bash --gpus=all \
   -v $(pwd):/workspace \
-  public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.1
+  public.ecr.aws/deep-learning-containers/vllm:0.11-gpu-py312
 
 # Inside container, run the NixlConnector test
 export HF_TOKEN="<TOKEN>"
 ./testNixlConnector.sh
 ```
 
-## Cleanup
+## Notes
 
-```python
-import boto3
-sagemaker = boto3.client('sagemaker')
-sagemaker.delete_endpoint(EndpointName='<endpoint-name>')
-```
+- The script automatically cleans up resources after inference to avoid ongoing costs
+- Deployment waits for the endpoint to be ready before running inference
+- `--container-uri`, `--iam-role`, and `--hf-token` fall back to the CONTAINER_URI, IAM_ROLE, and HF_TOKEN environment variables
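
If the script exits before its cleanup step (for example, an interrupted deployment), the endpoint keeps accruing charges. A minimal boto3 sketch for manual cleanup, mirroring the script's `cleanup_endpoint`:

```python
import boto3

sm = boto3.client("sagemaker")
name = "my-vllm-endpoint"  # hypothetical endpoint name from the example above

# The script names the model and endpoint config after the endpoint,
# so all three resources can be deleted by the same name.
sm.delete_endpoint(EndpointName=name)
sm.delete_endpoint_config(EndpointConfigName=name)
sm.delete_model(ModelName=name)
```
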
vllm-samples/sagemaker/deploy_and_test_sm_endpoint.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
import argparse
import json
import os

import boto3
from sagemaker.model import Model
from sagemaker import serializers
from sagemaker.predictor import Predictor


def deploy_endpoint(
    endpoint_name, container_uri, iam_role, instance_type, model_id, hf_token
):
    """Deploy vLLM model to SageMaker endpoint"""
    try:
        print(f"Starting deployment of endpoint: {endpoint_name}")
        print(f"Using image: {container_uri}")
        print(f"Instance type: {instance_type}")

        print("Creating SageMaker model...")
        model = Model(
            name=endpoint_name,
            image_uri=container_uri,
            role=iam_role,
            env={
                "SM_VLLM_MODEL": model_id,  # Model to load
                "SM_VLLM_HF_TOKEN": hf_token,  # HuggingFace token for model access
            },
        )
        print("Model created successfully")
        print("Starting endpoint deployment (this may take 10-15 minutes)...")

        model.deploy(
            instance_type=instance_type,
            initial_instance_count=1,
            endpoint_name=endpoint_name,
            wait=True,  # Wait for deployment to complete
        )
        print(f"Endpoint {endpoint_name} deployed successfully")
        return True
    except Exception as e:
        print(f"Deployment failed: {e}")
        return False


def cleanup_endpoint(endpoint_name):
    """Delete SageMaker endpoint, endpoint config, and model"""
    try:
        sagemaker_client = boto3.client("sagemaker")

        print(f"Cleaning up endpoint: {endpoint_name}")
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
        sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
        sagemaker_client.delete_model(ModelName=endpoint_name)
        print(f"Endpoint {endpoint_name} cleaned up successfully")
        return True
    except Exception as e:
        print(f"Cleanup failed: {e}")
        return False


def invoke_endpoint(endpoint_name, prompt, max_tokens=2400, temperature=0.01):
    """Invoke SageMaker endpoint with vLLM model for text generation"""
    try:
        predictor = Predictor(
            endpoint_name=endpoint_name,
            serializer=serializers.JSONSerializer(),
        )

        payload = {
            "messages": [{"role": "user", "content": prompt}],  # Chat format
            "max_tokens": max_tokens,  # Response length limit
            "temperature": temperature,  # Randomness (0=deterministic, 1=creative)
            "top_p": 0.9,  # Nucleus sampling
            "top_k": 50,  # Top-k sampling
        }

        response = predictor.predict(payload)

        # Handle different response formats
        if isinstance(response, bytes):
            response = response.decode("utf-8")

        if isinstance(response, str):
            try:
                response = json.loads(response)
            except json.JSONDecodeError:
                print("Warning: Response is not valid JSON. Returning as string.")

        return response

    except Exception as e:
        print(f"Inference failed: {e}")
        return None


def main():
    parser = argparse.ArgumentParser(description="SageMaker vLLM Inference")
    parser.add_argument(
        "--endpoint-name", required=True, help="SageMaker endpoint name"
    )
    parser.add_argument(
        "--container-uri",
        help="DLC image URI",
        default=os.getenv(
            "CONTAINER_URI",
            "public.ecr.aws/deep-learning-containers/vllm:0.11.0-gpu-py312",
        ),
    )
    parser.add_argument(
        "--iam-role", help="IAM role ARN", default=os.getenv("IAM_ROLE")
    )
    parser.add_argument(
        "--instance-type", default="ml.g5.12xlarge", help="Instance type"
    )
    parser.add_argument(
        "--model-id",
        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--hf-token", help="HuggingFace token", default=os.getenv("HF_TOKEN", "")
    )
    parser.add_argument(
        "--prompt",
        default="Write a python code to generate n prime numbers",
        help="Inference prompt",
    )
    parser.add_argument("--max-tokens", type=int, default=2400, help="Maximum tokens")
    parser.add_argument(
        "--temperature", type=float, default=0.01, help="Sampling temperature"
    )

    args = parser.parse_args()

    if not args.iam_role:
        print("Error: IAM role required")
        return

    # Deploy endpoint
    if not deploy_endpoint(
        args.endpoint_name,
        args.container_uri,
        args.iam_role,
        args.instance_type,
        args.model_id,
        args.hf_token,
    ):
        return

    # Run inference
    print("\nSending request to endpoint...")
    response = invoke_endpoint(
        endpoint_name=args.endpoint_name,
        prompt=args.prompt,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
    )

    if response:
        print("\nResponse from endpoint:")
        if isinstance(response, (dict, list)):
            print(json.dumps(response, indent=2))
        else:
            print(response)
    else:
        print("No response received from the endpoint.")

    # Cleanup
    print("\nCleaning up resources...")
    cleanup_endpoint(args.endpoint_name)


if __name__ == "__main__":
    main()
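
For reference, vLLM exposes an OpenAI-compatible chat API, so the parsed response from `invoke_endpoint` is expected to follow the chat-completion schema. This hypothetical helper (field names are an assumption about that schema, not something the script confirms) pulls out the generated text:

```python
def extract_text(response):
    """Extract generated text from an assumed OpenAI-style chat completion dict."""
    try:
        return response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        return None  # Unexpected schema; inspect the raw response instead
```
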

vllm-samples/sagemaker/endpoint.py

Lines changed: 0 additions & 36 deletions
This file was deleted.
