# main.py
import logging
import os
import time

import torch
from fastapi import FastAPI
from huggingface_hub import login
from transformers import LlamaTokenizer, LlamaForCausalLM
from llama_recipes.inference.model_utils import load_peft_model

from api import (
    ProcessRequest,
    ProcessResponse,
    TokenizeRequest,
    TokenizeResponse,
    Token,
)

# Allow TF32 for float32 matmuls on supported GPUs (faster, minimal precision loss).
torch.set_float32_matmul_precision("high")
app = FastAPI()

logger = logging.getLogger(__name__)
# Configure the logging module
logging.basicConfig(level=logging.INFO)

# Authenticate with the Hugging Face Hub to access the gated Llama 2 weights.
login(token=os.environ["HUGGINGFACE_TOKEN"])

# Load the base Llama 2 7B model in half precision directly onto the GPU.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cuda",
)
# Apply the fine-tuned PEFT adapter from the configured Hub repo on top of the base model.
model = load_peft_model(model, os.environ["HUGGINGFACE_REPO"])
model.eval()

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

LLAMA2_CONTEXT_LENGTH = 4096
@app.post("/process")
async def process_request(input_data: ProcessRequest) -> ProcessResponse:
if input_data.seed is not None:
torch.manual_seed(input_data.seed)
encoded = tokenizer(input_data.prompt, return_tensors="pt")
prompt_length = encoded["input_ids"][0].size(0)
max_returned_tokens = prompt_length + input_data.max_new_tokens
assert max_returned_tokens <= LLAMA2_CONTEXT_LENGTH, (
max_returned_tokens,
LLAMA2_CONTEXT_LENGTH,
)
t0 = time.perf_counter()
encoded = {k: v.to("cuda") for k, v in encoded.items()}
with torch.no_grad():
outputs = model.generate(
**encoded,
max_new_tokens=input_data.max_new_tokens,
do_sample=True,
temperature=input_data.temperature,
top_k=input_data.top_k,
return_dict_in_generate=True,
output_scores=True,
)
t = time.perf_counter() - t0
if not input_data.echo_prompt:
output = tokenizer.decode(outputs.sequences[0][prompt_length:], skip_special_tokens=True)
else:
output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
tokens_generated = outputs.sequences[0].size(0) - prompt_length
logger.info(
f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec"
)
logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
generated_tokens = []
log_probs = torch.log(torch.stack(outputs.scores, dim=1).softmax(-1))
gen_sequences = outputs.sequences[:, encoded["input_ids"].shape[-1]:]
gen_logprobs = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1)
top_indices = torch.argmax(log_probs, dim=-1)
top_logprobs = torch.gather(log_probs, 2, top_indices[:,:,None]).squeeze(-1)
top_indices = top_indices.tolist()[0]
top_logprobs = top_logprobs.tolist()[0]
for t, lp, tlp in zip(gen_sequences.tolist()[0], gen_logprobs.tolist()[0], zip(top_indices, top_logprobs)):
idx, val = tlp
tok_str = tokenizer.decode(idx)
token_tlp = {tok_str: val}
generated_tokens.append(
Token(text=tokenizer.decode(t), logprob=lp, top_logprob=token_tlp)
)
logprob_sum = gen_logprobs.sum().item()
return ProcessResponse(
text=output, tokens=generated_tokens, logprob=logprob_sum, request_time=t
)
@app.post("/tokenize")
async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse:
t0 = time.perf_counter()
encoded = tokenizer(
input_data.text
)
t = time.perf_counter() - t0
tokens = encoded["input_ids"]
return TokenizeResponse(tokens=tokens, request_time=t)
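
# Usage sketch (not part of the server itself): one way to run this app locally
# and exercise both endpoints. The host, port, and exact JSON fields are
# assumptions inferred from the ProcessRequest/TokenizeRequest models imported
# above; adjust them to your deployment.
#
#   uvicorn main:app --host 0.0.0.0 --port 8080
#
#   curl -X POST http://localhost:8080/process \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "The capital of France is", "max_new_tokens": 16,
#          "temperature": 0.8, "top_k": 50, "seed": 0, "echo_prompt": false}'
#
#   curl -X POST http://localhost:8080/tokenize \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello, world!"}'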