-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathapp_graph.py
More file actions
154 lines (125 loc) · 6.06 KB
/
app_graph.py
File metadata and controls
154 lines (125 loc) · 6.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import logging
from typing import TypedDict, List

from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from pydantic import BaseModel

from ingestion import PDFIngestor
from router import route, RouteQuery
class GraphState(TypedDict):
    """Shared state that is passed between the nodes of the chat graph."""

    # The user's question; the "boost_question" node replaces it with a
    # rewritten version before any other node reads it.
    question: str
    # Final answer text written by one of the generate nodes.
    response: str
    # Material fetched for the question by the retrieve node.
    documents: List[str]
class PdfChat:
def __init__(self, api_key, retriever):
self.model = ChatOpenAI(api_key=api_key, model="gpt-4o-mini-2024-07-18", temperature=0) # minimum costing model
builder = StateGraph(GraphState)
builder.add_node("retrieve", self.retrieve_node)
builder.add_node("boost_question", self.boost_question)
builder.add_node("structer_document", self.structer_document)
builder.add_node("generate_with_rag", self.generate_with_doc)
builder.add_node("generate", self.generate_wo_doc)
builder.set_entry_point("boost_question")
builder.add_conditional_edges(
"boost_question",
self.decide_retrieve,
{
"retrieve": "retrieve",
"generate": "generate"
}
)
builder.add_edge("retrieve", "structer_document")
builder.add_edge("structer_document", "generate_with_rag")
builder.add_edge("generate_with_rag", END)
builder.add_edge("generate", END)
self.retriever = retriever
self.graph = builder.compile()
self.graph.get_graph().draw_mermaid_png(output_file_path="graph.png")
self.memory = ConversationBufferMemory()
def decide_retrieve(self, state: GraphState):
question = state["question"]
memory = self.memory.load_memory_variables({})
source: RouteQuery = route(self.model, memory).invoke({"question": question})
if source.datasource == "vectorstore":
return "retrieve"
else:
return "generate"
def boost_question(self, state: GraphState):
question = state["question"]
memory = self.memory.load_memory_variables({})
prompt = """You are an assistant in a question-answering tasks.
You have to boost the question to help search in vectorstore.
Don't make up random names.
Return a better structred question for vectorstore search, but don't make it longer
\n
Conversation history: {memory}
\n
Question: {question}
"""
prompt = PromptTemplate.from_template(prompt)
chain = prompt | self.model | StrOutputParser()
question = chain.invoke({"question": question, "memory": memory})
return {"question": question}
def retrieve_node(self, state: GraphState):
question = state["question"]
documents = self.retriever.invoke(question)
if not documents:
return "I couldn't find any relevant documents. Can you please rephrase your question?"
return {"documents": documents}
def structer_document(self, state: GraphState):
documents = state["documents"]
question = state["question"]
documents = [doc.page_content for doc in documents]
prompt = """You are an expert assistant for question-answering tasks.
You have to restructure the documents for the question.
Keep it short, only knowledge that is relevant to the question.
Don't make up random names.
Return a better structured document for better understanding.
\n
Documents: {documents}
\n
Question: {question}
"""
prompt = PromptTemplate.from_template(prompt)
chain = prompt | self.model | StrOutputParser()
document = chain.invoke({"question": question, "documents": documents})
return {"documents": document}
def generate_with_doc(self, state: GraphState):
documents = state["documents"]
question = state["question"]
memory = self.memory.load_memory_variables({})
prompt = """"You are an expert assistant for question-answering tasks.
Use the provided documents as context to extract and answer the question.
Don't be lazy, check every details in the Context.
If the answer is not mentioned in context, respond with 'I don't know.'
Keep your limited to three sentences.
\n
Conversation history: {memory}
\n
Context: {context}
\n
Question: {question}
\n
Answer:
"""
prompt = PromptTemplate.from_template(prompt)
chain = prompt | self.model | StrOutputParser()
response = chain.invoke({"memory": memory, "question": question, "context": documents})
self.memory.save_context(inputs={"input": question}, outputs={"output": response})
return {"response": response}
def generate_wo_doc(self, state: GraphState):
question = state["question"]
prompt = """You are an assistant for question-answering tasks.
If you don't know the answer, just say that you don't know.
Don't forget to check previous conversations for context.
Use three sentences maximum and keep the answer concise.
Conversation history: {memory}
Question: {question}
"""
memory = self.memory.load_memory_variables({})
prompt = PromptTemplate.from_template(prompt)
chain = prompt | self.model | StrOutputParser()
response = chain.invoke({"memory": memory, "question": question})
self.memory.save_context(inputs={"input": question}, outputs={"output": response})
return {"response": response}