-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Copy pathopenai_api_server.py
227 lines (208 loc) · 7.59 KB
/
openai_api_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import argparse
import os
from fastapi import FastAPI
import uvicorn
# Command-line configuration. This runs before torch is imported further down,
# so the CUDA_VISIBLE_DEVICES value set here is in place before torch loads.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--base_model',
    default=None,
    type=str,
    required=True,
)
parser.add_argument(
    '--lora_model',
    default=None,
    type=str,
    help="If None, perform inference on the base model",
)
parser.add_argument('--tokenizer_path', default=None, type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--load_in_8bit', action='store_true', help='use 8 bit model')
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
args = parser.parse_args()

load_in_8bit = args.load_in_8bit
# CPU-only mode hides every GPU by exposing an empty device list.
args.gpus = "" if args.only_cpu else args.gpus
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from peft import PeftModel
from openai_api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatMessage,
ChatCompletionResponseChoice,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
EmbeddingsRequest,
EmbeddingsResponse,
)
# Module-level generation defaults.
# NOTE(review): nothing in this file reads this dict; predict() builds its own
# GenerationConfig from the per-request parameters instead.
generation_config = {
    "temperature": 0.2,
    "top_k": 40,
    "top_p": 0.9,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.1,
    "max_new_tokens": 400,
}
# Weights are loaded in half precision. Inference runs on the first visible
# CUDA device when one is available, and on the CPU otherwise.
load_type = torch.float16
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
# Tokenizer source: an explicit --tokenizer_path always wins. Otherwise use the
# LoRA directory, presumably because it may carry an extended tokenizer (the
# resize below supports a tokenizer larger than the base vocab). Fall back to
# the base model when no LoRA is given.
if args.tokenizer_path is None:
    if args.lora_model is not None:
        args.tokenizer_path = args.lora_model
    else:
        args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path)

base_model = LlamaForCausalLM.from_pretrained(
    args.base_model,
    load_in_8bit=load_in_8bit,
    torch_dtype=load_type,
    low_cpu_mem_usage=True,
    # Automatic device placement only when GPUs may be used.
    device_map='auto' if not args.only_cpu else None,
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
if model_vocab_size != tokenizer_vocab_size:
    # Only enlarging the embedding matrix is meaningful here. Shrinking it would
    # drop rows for valid token ids. Raise explicitly rather than `assert`,
    # which is stripped under `python -O`.
    if tokenizer_vocab_size < model_vocab_size:
        raise ValueError(
            f"Tokenizer vocab ({tokenizer_vocab_size}) is smaller than the base "
            f"model vocab ({model_vocab_size}); only enlarging the embeddings "
            f"is supported."
        )
    print("Resize model embeddings to fit tokenizer")
    base_model.resize_token_embeddings(tokenizer_vocab_size)

if args.lora_model is not None:
    print("loading peft model")
    model = PeftModel.from_pretrained(
        base_model,
        args.lora_model,
        torch_dtype=load_type,
        device_map='auto',
    )
else:
    model = base_model

# Weights were loaded as float16 (load_type); upcast to float32 for CPU inference.
if device == torch.device('cpu'):
    model.float()
model.eval()
def generate_completion_prompt(instruction: str):
    """Generate prompt for completion.

    Wraps a single instruction in the Alpaca-style template. It uses the same
    default preamble and blank-line separators as the chat prompt builder, so
    /v1/completions and a one-message /v1/chat/completions send the model the
    same text.

    Args:
        instruction: the user's instruction text, inserted verbatim.

    Returns:
        The prompt string, ending in an open ``### Response: `` for the model
        to complete.
    """
    # Explicit "\n" escapes keep the blank-line separators visible and immune
    # to whitespace stripping.
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n"
        "### Response: "
    )
def generate_chat_prompt(messages: list):
    """Generate prompt for chat completion.

    Builds an Alpaca-style prompt from a chat history:

    * the text of the last ``system`` message, if any, replaces the default
      preamble;
    * ``user`` turns become ``### Instruction:`` sections and ``assistant``
      turns become ``### Response:`` sections, in order;
    * messages with any other role are ignored;
    * the prompt ends with an open ``### Response: `` for the model to complete.

    Args:
        messages: objects exposing ``role`` and ``content`` attributes
            (the handlers in this file pass ``ChatMessage(role=..., content=...)``).

    Returns:
        The prompt string.
    """
    system_msg = '''Below is an instruction that describes a task. Write a response that appropriately completes the request.'''
    for msg in messages:
        if msg.role == 'system':
            # Message text lives in ``content``; ``msg.message`` does not
            # exist on these objects and raised AttributeError.
            system_msg = msg.content

    parts = [f"{system_msg}\n\n"]
    for msg in messages:
        if msg.role == 'assistant':
            parts.append(f"### Response: {msg.content}\n\n")
        elif msg.role == 'user':
            parts.append(f"### Instruction:\n{msg.content}\n\n")
    parts.append("### Response: ")
    return "".join(parts)
def predict(
    input,
    max_new_tokens=128,
    top_p=0.75,
    temperature=0.1,
    top_k=40,
    num_beams=4,
    repetition_penalty=1.0,
    do_sample=None,
    **kwargs,
):
    """
    Main inference method
    type(input) == str -> /v1/completions
    type(input) == list -> /v1/chat/completions

    Args:
        input: a prompt string (wrapped by ``generate_completion_prompt``), or
            a list of chat messages (wrapped by ``generate_chat_prompt``).
        max_new_tokens: maximum number of tokens to generate.
        top_p, temperature, top_k: sampling parameters. They only take effect
            when sampling is enabled.
        num_beams: number of beams passed to ``GenerationConfig``.
        repetition_penalty: coerced to ``float`` before generation.
        do_sample: whether to sample. ``None`` (the default) samples when
            ``temperature`` is positive and decodes greedily / with beams
            otherwise. Without this, ``GenerationConfig`` defaults to
            non-sampling decoding and silently ignores the sampling parameters.
        **kwargs: extra fields forwarded to ``GenerationConfig``.

    Returns:
        The newly generated text (prompt excluded), whitespace-stripped.
    """
    if isinstance(input, str):
        prompt = generate_completion_prompt(input)
    else:
        prompt = generate_chat_prompt(input)

    if do_sample is None:
        # Sampling with temperature <= 0 is rejected by transformers, so fall
        # back to deterministic decoding in that case.
        do_sample = temperature is not None and temperature > 0

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=False,
            max_new_tokens=max_new_tokens,
            repetition_penalty=float(repetition_penalty),
        )
    # The returned sequence starts with the prompt tokens. Decode only the
    # continuation, so generated text that itself contains "### Response:" is
    # not mistaken for the prompt boundary.
    new_tokens = generation_output.sequences[0][input_ids.shape[-1]:]
    output = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return output.strip()
def get_embedding(input):
    """Get embedding main function.

    Computes an L2-normalised, attention-mask-weighted mean of the model's
    last-layer hidden states.

    Args:
        input: a string or a list of strings, passed straight to the tokenizer.

    Returns:
        ``squeeze(0).tolist()`` of the ``(batch, hidden)`` result. A batch of
        one (a plain string or a one-element list) gives a flat list of
        floats; larger batches give a list of such lists.
    """
    with torch.no_grad():
        if tokenizer.pad_token is None:
            # Reuse an existing special token for padding. Adding a brand-new
            # "[PAD]" token gives it an id past the end of the model's
            # embedding matrix (sized to the tokenizer at load time), so any
            # batch that actually needs padding would index out of range.
            # Padded positions are masked out of attention and pooling, so
            # the choice of token does not affect the result.
            tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token
        encoding = tokenizer(input, padding=True, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        model_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        data = model_output.hidden_states[-1]
        # Zero out padded positions, then average over real tokens only.
        mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
        masked_embeddings = data * mask
        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        seq_length = torch.sum(mask, dim=1)
        embedding = sum_embeddings / seq_length
        normalized_embeddings = F.normalize(embedding, p=2, dim=1)
        ret = normalized_embeddings.squeeze(0).tolist()
    return ret
app = FastAPI()


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Creates a completion for the chat message.

    ``request.messages`` may be a plain string (treated as one user turn) or
    a list of dicts with a ``role`` key and the text under either ``content``
    (the OpenAI key) or ``message`` (the key this server originally required).

    The response echoes every input message as its own choice
    (indices ``0..n-1``), followed by the generated assistant message at
    index ``n``.
    """
    msgs = request.messages
    if isinstance(msgs, str):
        msgs = [ChatMessage(role='user', content=msgs)]
    else:
        # Accept standard OpenAI clients ("content") as well as the legacy
        # "message" key.
        msgs = [
            ChatMessage(
                role=x['role'],
                content=x['content'] if 'content' in x else x['message'],
            )
            for x in msgs
        ]
    output = predict(
        input=msgs,
        max_new_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        num_beams=request.num_beams,
        repetition_penalty=request.repetition_penalty,
    )
    choices = [ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs)]
    choices.append(
        ChatCompletionResponseChoice(
            index=len(choices),
            message=ChatMessage(role='assistant', content=output),
        )
    )
    return ChatCompletionResponse(choices=choices)
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    """Creates a completion.

    Forwards ``request.prompt`` and the request's generation settings to
    ``predict`` and returns the generated text as the single choice at
    index 0.
    """
    generation_kwargs = dict(
        max_new_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        num_beams=request.num_beams,
        repetition_penalty=request.repetition_penalty,
    )
    generated = predict(input=request.prompt, **generation_kwargs)
    only_choice = CompletionResponseChoice(index=0, text=generated)
    return CompletionResponse(choices=[only_choice])
@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingsRequest):
    """Creates text embedding.

    Returns one ``data`` entry per input, with ``index`` giving the input's
    position. A single string input yields exactly one entry at index 0.
    """
    embedding = get_embedding(request.input)
    # get_embedding squeezes a batch of one down to a flat vector. Two or more
    # inputs come back as a list of vectors and must each get their own entry
    # rather than being packed into a single one.
    if embedding and isinstance(embedding[0], list):
        vectors = embedding
    else:
        vectors = [embedding]
    data = [
        {
            "object": "embedding",
            "embedding": vector,
            "index": i,
        }
        for i, vector in enumerate(vectors)
    ]
    return EmbeddingsResponse(data=data)
if __name__ == "__main__":
    # Start from uvicorn's stock logging config and prefix both the access and
    # the default formatter with a timestamp and level.
    log_config = uvicorn.config.LOGGING_CONFIG
    timestamped_fmt = "%(asctime)s - %(levelname)s - %(message)s"
    for formatter_name in ("access", "default"):
        log_config["formatters"][formatter_name]["fmt"] = timestamped_fmt
    # One worker serving the in-process `app` and the model loaded at import.
    uvicorn.run(app, host='0.0.0.0', port=19327, workers=1, log_config=log_config)