-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag.py
More file actions
86 lines (67 loc) · 2.75 KB
/
rag.py
File metadata and controls
86 lines (67 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# RUNNING IN TENSORFLOW ENV
# coher, or reranking, long-context-reorder method to abound missing in the middle
# uvicorn rag:app
from langchain import PromptTemplate
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
from langchain.embeddings import SentenceTransformerEmbeddings
from fastapi import FastAPI, Request, Form, Response
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from qdrant_client import QdrantClient
from langchain.vectorstores import Qdrant
import os
import json
app = FastAPI()
templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
local_llm = "BioMistral-7B.Q4_K_M.gguf"
config = {
'max_new_tokens': 2048,
'context_length': 2048,
'repetition_penalty': 1.1,
'temperature': 0.2,
'top_k': 50,
'top_p': 1,
'stream': True,
'threads': int(os.cpu_count() / 2)
}
llm = CTransformers(
model=local_llm,
model_type="mistral",
lib="avx2",
**config
)
print("LLM Initialized....")
prompt_template = """
Use the following pieces of medical information to provide accurate responses to the user's questions. Please refrain from speculation or providing false information.
Context: {context}
Question: {question}
Provide a concise and accurate answer based on the medical context.
Helpful answer:
"""
embeddings = SentenceTransformerEmbeddings(model_name="NeuML/pubmedbert-base-embeddings")
url = "http://localhost:6333"
client = QdrantClient(
url=url, prefer_grpc=False
)
db = Qdrant(client=client, embeddings=embeddings, collection_name="medical_db")
prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])
retriever = db.as_retriever(search_kwargs={"k":1})
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/get_response")
async def get_response(query: str = Form(...)):
chain_type_kwargs = {"prompt": prompt}
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True, chain_type_kwargs=chain_type_kwargs, verbose=True)
response = qa(query)
print(response)
answer = response['result']
source_document = response['source_documents'][0].page_content
doc = response['source_documents'][0].metadata['source'] + " " + str(response['source_documents'][0].metadata['page']) + " " + response['source_documents'][0].metadata['_collection_name']
response_data = jsonable_encoder(json.dumps({"answer": answer, "source_document": source_document, "doc": doc}))
res = Response(response_data)
return res