privateGPT.py
#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.llms import GPT4All, LlamaCpp, OpenAI
import os
import argparse
load_dotenv()
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX'))  # env vars are strings; cast to int for the context window size
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
from constants import CHROMA_SETTINGS
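# A minimal .env sketch for the variables read above (the values are illustrative
# assumptions, not prescribed by this script):
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=GPT4All
#   MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=1000
#   TARGET_SOURCE_CHUNKS=4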
prompt_template = """
You are a second-generation, successful entrepreneur and an expert in Lean Startups.
Use the following pieces of context to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Always answer from the perspective of an expert and experienced entrepreneur.
Never reply with "As an AI language model".
----------------
{context}
Question: {question}
Answer in English:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
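# At query time, the "stuff" chain fills {context} with the retrieved document
# chunks and {question} with the user's query. A quick sanity check (illustrative):
#   PROMPT.format(context="<retrieved chunks>", question="What is an MVP?")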
def main():
    # Parse the command line arguments
    args = parse_arguments()
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
    # Activate/deactivate the streaming StdOut callback for LLMs
    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
    # Prepare the LLM
    if model_type == "LlamaCpp":
        llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False, n_threads=8)
    elif model_type == "GPT4All":
        llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
    elif model_type == "GPT3":
        llm = OpenAI(temperature=0.0, max_tokens=3000)
    else:
        print(f"Model {model_type} not supported!")
        exit()
    chain_type_kwargs = {"prompt": PROMPT}
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever,
                                     return_source_documents=not args.hide_source, chain_type_kwargs=chain_type_kwargs)
    # Interactive questions and answers
    while True:
        query = input("\n🚀🪱 Enter a query: ")
        if query == "exit":
            break
        # Get the answer from the chain
        res = qa(query)
        answer, docs = res['result'], [] if args.hide_source else res['source_documents']
        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)
        # Print the relevant sources used for the answer
        for document in docs:
            print("\n> " + document.metadata["source"] + ":")
            print(document.page_content)
def parse_arguments():
    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
                                                 'using the power of LLMs.')
    parser.add_argument("--hide-source", "-S", action='store_true',
                        help='Use this flag to disable printing of source documents used for answers.')
    parser.add_argument("--mute-stream", "-M", action='store_true',
                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
    return parser.parse_args()
if __name__ == "__main__":
    main()
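# Example usage (a sketch, assuming a .env like the one shown above is in place):
#   python privateGPT.py                  # interactive Q&A with streamed answers and sources
#   python privateGPT.py --hide-source    # suppress the retrieved source documents
#   python privateGPT.py --mute-stream    # disable token-by-token streaming
# Type "exit" at the query prompt to quit.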