-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllama_chatbot.py
72 lines (57 loc) · 2.24 KB
/
llama_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from torch import bfloat16
import torch
import transformers
from transformers import AutoTokenizer
import os
from huggingface_hub import login
# Hugging Face Hub repository id of the chat model to load.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Target device name: use the GPU when PyTorch reports CUDA support,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class LLAMA_Chat:
    """Single-turn chat wrapper around a quantized Hugging Face causal LM.

    The tokenizer and model named by ``MODEL_ID`` are loaded once in the
    constructor. :meth:`ask` then answers one user message at a time, using
    a fixed system prompt and no sampling. No conversation history is kept
    between calls.
    """

    def __init__(self):
        """Optionally log in to the Hub, then load the tokenizer and model.

        If the ``HF_ACCESS_TOKEN`` environment variable is set to a non-empty
        value, it is used to log in to the Hugging Face Hub before the
        checkpoint is downloaded.
        """
        token = os.environ.get("HF_ACCESS_TOKEN")
        if token:
            login(token=token)

        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # Some tokenizers define no pad token; reuse EOS so padding during
        # generation has a valid id.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        self.tokenizer = tokenizer

        # 8-bit bitsandbytes quantization. NF4 / double-quant / compute-dtype
        # settings exist only for 4-bit loading (``load_in_4bit=True`` with
        # ``bnb_4bit_*`` options); there are no ``bnb_8bit_*`` equivalents.
        bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
        # NOTE(review): bitsandbytes quantization generally expects a CUDA
        # GPU, so loading may fail when ``device`` falls back to "cpu" —
        # confirm before relying on the CPU path.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config,
            device_map=device,
        )

    def ask(self, input_data):
        """Answer a single user message and return the model's reply.

        Args:
            input_data: The user's message; converted to text with ``str()``.

        Returns:
            The newly generated reply, decoded with special tokens removed.
            The prompt (system and user turns) is not included.
        """
        formatted_chat = [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": str(input_data)},
        ]
        tokenized_prompt = self.tokenizer.apply_chat_template(
            formatted_chat,
            # End the prompt with the assistant header so the model replies.
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)  # follow the model instead of assuming CUDA

        # Drop token_type_ids for Llama checkpoints; presumably their
        # generate() rejects this input — kept from the original behaviour.
        if "llama" in MODEL_ID.lower() and "token_type_ids" in tokenized_prompt:
            del tokenized_prompt["token_type_ids"]

        prompt_length = tokenized_prompt["input_ids"].shape[-1]
        outputs = self.model.generate(
            **tokenized_prompt,
            max_new_tokens=1000,
            do_sample=False,  # no sampling: the same prompt gives the same reply
            # Pass the pad id chosen in __init__ explicitly rather than
            # relying on generate()'s own fallback.
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # The returned sequence is prompt + continuation; keep only the new
        # tokens so the reply does not echo the system/user turns.
        reply_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(reply_ids, skip_special_tokens=True)
def main():
    """Run an interactive question/answer loop against :class:`LLAMA_Chat`.

    Typing ``exit`` or ``quit`` (any letter case) ends the session, as does
    end-of-input (Ctrl-D / exhausted piped stdin) or Ctrl-C at the prompt.
    Blank lines are skipped instead of being sent to the model.
    """
    bot = LLAMA_Chat()
    print("\n LLaMA-3 Chat is ready! Type your question (type 'exit' to quit):\n")
    while True:
        try:
            user_input = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leave the dangling "User: " prompt line before saying goodbye.
            print()
            print("Goodbye!")
            break
        if not user_input:
            continue
        if user_input.lower() in ("exit", "quit"):
            print("Goodbye!")
            break
        answer = bot.ask(user_input)
        print("Bot:", answer)


if __name__ == "__main__":
    main()