
Commit 04e0b37

lxning, chauhang, msaroufim, agunapal, and HamidShojanazeri authored
llama2 70b chat accelerate example (#2494)
* llam2 accelerate example
* add readme
* fmt
* fixing the padding and prompt
* update steps
* Updated readme with more details
* changed to inheriting from basehandler
* add model_path
* change to int8
* add download cmd
* update download path
* minor edit for model_path

---------

Co-authored-by: Geeta Chauhan <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Ankith Gunapal <[email protected]>
Co-authored-by: Hamid Shojanazeri <[email protected]>
1 parent 683608b commit 04e0b37

6 files changed: +224 -0 lines changed

@@ -0,0 +1,60 @@
# Loading meta-llama/Llama-2-70b-chat-hf on AWS EC2 g5.24xlarge using accelerate

This document describes how to serve large Hugging Face models with limited resources using accelerate. This option is activated with `low_cpu_mem_usage=True`: the model is first created on the meta device (with empty weights), and the state dict is then loaded into it (shard by shard in the case of a sharded checkpoint).
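As context, here is a minimal sketch of what accelerate does under the hood when loading a sharded checkpoint into meta-device (empty) weights. It is illustrative only; the `checkpoint_dir` value is a placeholder for the snapshot directory created by the download step below.

```python
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder: path to a downloaded (sharded) checkpoint, e.g. the snapshot
# directory produced by the download command in Step 1.
checkpoint_dir = "model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/<hash>"

# 1. Build the model skeleton on the meta device: no weight memory is allocated yet.
config = AutoConfig.from_pretrained(checkpoint_dir)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# 2. Stream the checkpoint shards into the skeleton, dispatching layers across
#    the available GPUs instead of materializing the full state dict in CPU RAM.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=checkpoint_dir,
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"],
    dtype=torch.float16,
)
```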
### Step 1: Download the model

Follow [these instructions](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to request access to the model.

Log in with your Hugging Face account:
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

Then download the model:

```bash
python ../Download_model.py --model_path model --model_name meta-llama/Llama-2-70b-chat-hf
```

The model will be saved under `model/models--meta-llama--Llama-2-70b-chat-hf`.
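The weights are stored in the Hugging Face cache layout, so the actual files sit under a `snapshots/<commit-hash>` subdirectory. A small sketch to print that directory (assuming the default layout shown above), so you can paste the full path into `model_path:` in `model-config.yaml` in the next step:

```python
from pathlib import Path

# Assumed cache layout produced by the download command above.
cache_dir = Path("model/models--meta-llama--Llama-2-70b-chat-hf/snapshots")

# Each entry under snapshots/ is named after a commit hash;
# a fresh download normally contains exactly one.
for snapshot in sorted(cache_dir.iterdir()):
    print(snapshot)
```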
### Step 2: Generate the MAR file

Add the downloaded path to `model_path:` in `model-config.yaml`, then run:

```bash
torch-model-archiver --model-name llama2-70b-chat --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
```

If you are using conda and run into issues with mpi4py, install openmpi-mpicc:

```
conda install -c conda-forge openmpi-mpicc
```
### Step 3: Add the MAR file to the model store

```bash
mkdir model_store
mv llama2-70b-chat model_store
mv model model_store/llama2-70b-chat
```
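Because `--archive-format no-archive` produces a folder rather than a packed `.mar` file, after the moves above the model store is expected to look roughly like this (a sketch only; the exact contents of the archiver output may differ):

```
model_store/
└── llama2-70b-chat/
    ├── MAR-INF/
    │   └── MANIFEST.json
    ├── custom_handler.py
    ├── model-config.yaml
    ├── requirements.txt
    └── model/
        └── models--meta-llama--Llama-2-70b-chat-hf/
```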
### Step 4: Start TorchServe

Update config.properties and start TorchServe:

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama2-70b-chat
```
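Optionally, before sending a request you can confirm that the model is registered and its worker has started via the management API on port 8081 (see config.properties). A minimal sketch using only the Python standard library:

```python
import json
import urllib.request

# Query TorchServe's management API for the registered model's status.
with urllib.request.urlopen("http://localhost:8081/models/llama2-70b-chat") as resp:
    print(json.dumps(json.loads(resp.read()), indent=2))
```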
### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/llama2-70b-chat" -T sample_text.txt
```

This results in output similar to:

```
Mayonnaise is a thick, creamy condiment made from a mixture of egg yolks, oil, vinegar or lemon juice, and seasonings
```
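The same request can also be issued from Python rather than curl; a minimal sketch using only the standard library:

```python
import urllib.request

# Read the prompt and POST it to the prediction endpoint, mirroring the curl call above.
with open("sample_text.txt", "rb") as f:
    req = urllib.request.Request(
        "http://localhost:8080/predictions/llama2-70b-chat",
        data=f.read(),
        method="POST",
    )
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))
```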
config.properties
@@ -0,0 +1,6 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
enable_envvars_config=true
install_py_dep_per_model=true
custom_handler.py
@@ -0,0 +1,139 @@
import logging
from abc import ABC

import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from accelerate import init_empty_weights
from accelerate import load_checkpoint_and_dispatch

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LlamaHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation with Llama 2.
    """

    def __init__(self):
        super(LlamaHandler, self).__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        """In this initialize function, the HF large model is loaded and
        partitioned across the available GPUs using accelerate.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        logger.info("Model %s loading tokenizer", ctx.model_name)
        # Load the model in 8-bit precision and let accelerate shard it across GPUs.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="balanced",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama 2 has no pad token by default; add one and grow the embedding table to match.
        self.tokenizer.add_special_tokens(
            {
                "pad_token": "<PAD>",
            }
        )
        self.model.resize_token_embeddings(self.model.config.vocab_size + 1)

        logger.info("Model %s loaded successfully", ctx.model_name)
        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
                containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
                attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.model.device)
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            padding=True,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Generates text for the received input using the loaded transformers checkpoint.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
                of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the predicted values for each input text in the batch.
        """
        input_ids_batch, attention_mask_batch = input_batch
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            max_new_tokens=self.max_new_tokens,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post-process function that converts the predicted response into a TorchServe-readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the predictions and explanations.
        """
        return inference_output
model-config.yaml
@@ -0,0 +1,13 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"

handler:
    model_name: "meta-llama/Llama-2-70b-chat-hf"
    model_path: "model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/9ff8b00464fc439a64bb374769dec3dd627be1c2"
    max_length: 50
    max_new_tokens: 50
    manual_seed: 40
requirements.txt
@@ -0,0 +1,5 @@
transformers==4.31.0
accelerate
bitsandbytes
scipy
mpi4py
sample_text.txt
@@ -0,0 +1 @@
what is the recipe of mayonnaise?
