generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathbyokg_query_engine.py
More file actions
282 lines (237 loc) · 13.6 KB
/
byokg_query_engine.py
File metadata and controls
282 lines (237 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
from typing import List, Tuple, Optional, Set
from .utils import load_yaml, parse_response
class ByoKGQueryEngine:
    """
    A query engine that orchestrates the retrieval and generation pipeline for knowledge graph queries.

    This class handles the high-level flow of query processing while delegating LLM-specific tasks
    to the KGLinker (multi-strategy retrieval) and, when provided, to a Cypher KG linker
    (cypher-based retrieval).

    Attributes set by ``__init__``:
        graph_store, schema, llm_generator, entity_linker, direct_query_linking,
        triplet_retriever, path_retriever, graph_query_executor, kg_linker,
        kg_linker_prompts, kg_linker_prompts_iterative, cypher_kg_linker,
        cypher_kg_linker_prompts, cypher_kg_linker_prompts_iterative.
        The four ``*_prompts*`` attributes are ``None`` when the matching linker is absent.
    """

    # Feedback appended to the cypher context when a generated query yields no results.
    # The text is fed back to the LLM on the next iteration, so it is part of runtime behavior.
    _LINKING_ERROR_FEEDBACK = "The above cypher query for entity linking failed with an error. Please review the error message and fix the query syntax or schema references."
    _LINKING_EMPTY_FEEDBACK = "No executable results for the above cypher query for entity linking. Please improve cypher generation in the future for linking."
    _QUERY_ERROR_FEEDBACK = "The above cypher query failed with an error. Please review the error message and fix the query syntax or schema references."
    _QUERY_EMPTY_FEEDBACK = "No executable results for the above. Please improve cypher generation in the future by focusing more on the given schema and the relations between node types."

    def __init__(self,
                 graph_store,
                 entity_linker=None,
                 triplet_retriever=None,
                 path_retriever=None,
                 graph_query_executor=None,
                 llm_generator=None,
                 kg_linker=None,
                 cypher_kg_linker=None,
                 direct_query_linking=False):
        """
        Initialize the query engine.

        Any component left as None is replaced by a default built from ``graph_store``
        (and ``llm_generator`` where needed). The default imports are performed lazily
        so that they are only required when the default is actually used.

        Args:
            graph_store: Component that provides access to graph data
            entity_linker: Optional component for linking entities to graph nodes
            triplet_retriever: Optional component for retrieving triplets
            path_retriever: Optional component for retrieving paths
            graph_query_executor: Optional component for executing graph queries
            llm_generator: Optional language model for generating responses
            kg_linker: Optional KG linker for multi-strategy retrieval
            cypher_kg_linker: Optional Cypher KG linker for cypher-based retrieval
            direct_query_linking: Flag whether to use entity linker with query embedding directly

        Raises:
            AssertionError: If ``cypher_kg_linker`` is given but has no
                ``is_cypher_linker`` attribute.
        """
        self.graph_store = graph_store
        self.schema = graph_store.get_schema()

        if llm_generator is None:
            from .llm.bedrock_llms import BedrockGenerator
            llm_generator = BedrockGenerator(
                model_name='us.anthropic.claude-3-5-sonnet-20240620-v1:0',
                region_name='us-west-2')
        self.llm_generator = llm_generator

        if entity_linker is None:
            from .indexing import FuzzyStringIndex
            from .graph_retrievers import EntityLinker
            string_index = FuzzyStringIndex()
            string_index.add(self.graph_store.nodes())
            entity_retriever = string_index.as_entity_matcher()
            entity_linker = EntityLinker(entity_retriever)
        self.entity_linker = entity_linker
        self.direct_query_linking = direct_query_linking

        if triplet_retriever is None:
            from .graph_retrievers import AgenticRetriever
            from .graph_retrievers import GTraversal, TripletGVerbalizer
            graph_traversal = GTraversal(self.graph_store)
            graph_verbalizer = TripletGVerbalizer()
            triplet_retriever = AgenticRetriever(
                llm_generator=self.llm_generator,
                graph_traversal=graph_traversal,
                graph_verbalizer=graph_verbalizer)
        self.triplet_retriever = triplet_retriever

        if path_retriever is None:
            from .graph_retrievers import PathRetriever
            from .graph_retrievers import GTraversal, PathVerbalizer
            graph_traversal = GTraversal(self.graph_store)
            path_verbalizer = PathVerbalizer()
            path_retriever = PathRetriever(
                graph_traversal=graph_traversal,
                path_verbalizer=path_verbalizer)
        self.path_retriever = path_retriever

        # A default executor is only created when the store can execute queries;
        # otherwise graph_query_executor stays None.
        if graph_query_executor is None and hasattr(graph_store, 'execute_query'):
            from .graph_retrievers import GraphQueryRetriever
            graph_query_executor = GraphQueryRetriever(self.graph_store)
        self.graph_query_executor = graph_query_executor

        # Initialize KGLinker as default only when no linker of either kind was supplied.
        if kg_linker is None and cypher_kg_linker is None:
            from .graph_connectors import KGLinker
            kg_linker = KGLinker(
                llm_generator=self.llm_generator,
                graph_store=self.graph_store
            )
        self.kg_linker = kg_linker
        # Always define the prompt attributes so that reading them never raises AttributeError.
        self.kg_linker_prompts = None
        self.kg_linker_prompts_iterative = None
        if kg_linker is not None:
            self.kg_linker_prompts = kg_linker.task_prompts
            self.kg_linker_prompts_iterative = kg_linker.task_prompts_iterative

        # Explicit raise (rather than an assert statement) so the check survives `python -O`
        # while callers still observe the same exception type and message.
        if cypher_kg_linker is not None and not hasattr(cypher_kg_linker, "is_cypher_linker"):
            raise AssertionError("cypher_kg_linker must be an instance of CypherKGLinker")
        self.cypher_kg_linker = cypher_kg_linker
        self.cypher_kg_linker_prompts = None
        self.cypher_kg_linker_prompts_iterative = None
        if cypher_kg_linker is not None:
            self.cypher_kg_linker_prompts = cypher_kg_linker.task_prompts
            self.cypher_kg_linker_prompts_iterative = cypher_kg_linker.task_prompts_iterative

    def _add_to_context(self, context_list: List[str], new_items: List[str]) -> None:
        """
        Add new items to context list while maintaining order and avoiding duplicates.

        The list is modified in place. Items already present in ``context_list`` and
        repeated items within ``new_items`` are skipped.

        Args:
            context_list: The list to add items to
            new_items: New items to add

        Returns:
            None
        """
        seen = set(context_list)
        for item in new_items:
            if item not in seen:
                context_list.append(item)
                seen.add(item)

    def _run_cypher_with_feedback(self, cypher_query: str, error_feedback: str, empty_feedback: str) -> List[str]:
        """
        Execute a cypher query and return its context, plus a feedback line when it yields nothing.

        Args:
            cypher_query: The query text to execute
            error_feedback: Line appended when there are no results and the context
                contains an "Error executing query" entry
            empty_feedback: Line appended when there are no results and no such entry

        Returns:
            The context lines returned by the executor, followed by at most one feedback line
        """
        context, results = self.graph_query_executor.retrieve(cypher_query, return_answers=True)
        lines = list(context)
        if len(results) == 0:
            # Distinguish an execution error from a query that simply matched nothing.
            failed = any("Error executing query" in line for line in lines)
            lines.append(error_feedback if failed else empty_feedback)
        return lines

    def query(self, query: str, iterations: int = 2, cypher_iterations: int = 2, user_input: str = "") -> List[str]:
        """
        Process a query through the retrieval and generation pipeline.

        When a cypher KG linker is configured, a cypher-based phase runs first and collects
        query results together with feedback lines. When a KG linker is configured, a
        multi-strategy phase (entity linking, triplets, paths, graph queries) follows.
        Either phase stops early when the LLM response contains "FINISH" inside a
        <task-completion> tag.

        Args:
            query: The search query
            iterations: Number of retrieval iterations to perform
            cypher_iterations: Number of cypher generation retries
            user_input: Optional user input for additional instructions or context

        Returns:
            List containing the cypher context (with feedback) followed by the context
            retrieved by the multi-strategy phase

        Raises:
            AssertionError: If a cypher KG linker is configured but no graph query
                executor is available.
        """
        retrieved_context: List[str] = []
        # dict keys act as an insertion-ordered set: set iteration order for strings
        # changes between processes, which would make retrieval order non-reproducible.
        explored_entities: dict = {}
        cypher_context_with_feedback: List[str] = []

        if self.direct_query_linking:
            semantic_linked_entities = self.entity_linker.link([query], return_dict=False)
            explored_entities.update(dict.fromkeys(semantic_linked_entities))
        else:
            semantic_linked_entities = []

        # If cypher_kg_linker is provided, ByoKGQueryEngine tries to solve KGQA with cypher-based retrieval
        if self.cypher_kg_linker is not None:
            if self.graph_query_executor is None:
                raise AssertionError("graph_query_executor must be initialized")

            for iteration in range(cypher_iterations):
                # Use the iterative prompts after the first iteration
                if iteration == 0:
                    task_prompts = self.cypher_kg_linker.task_prompts
                else:
                    task_prompts = self.cypher_kg_linker.task_prompts_iterative

                response = self.cypher_kg_linker.generate_response(
                    question=query,
                    schema=self.schema,
                    graph_context="\n".join(cypher_context_with_feedback),
                    task_prompts=task_prompts,
                    user_input=user_input
                )
                artifacts = self.cypher_kg_linker.parse_response(response)

                # Check for task completion early to avoid unnecessary query execution
                task_completion = parse_response(response, r"<task-completion>(.*?)</task-completion>")
                if "FINISH" in " ".join(task_completion):
                    break

                if "opencypher-linking" in artifacts:
                    cypher_context_with_feedback += self._run_cypher_with_feedback(
                        " ".join(artifacts["opencypher-linking"]),
                        error_feedback=self._LINKING_ERROR_FEEDBACK,
                        empty_feedback=self._LINKING_EMPTY_FEEDBACK)

                if "opencypher" in artifacts:
                    cypher_context_with_feedback += self._run_cypher_with_feedback(
                        " ".join(artifacts["opencypher"]),
                        error_feedback=self._QUERY_ERROR_FEEDBACK,
                        empty_feedback=self._QUERY_EMPTY_FEEDBACK)

        # Without a KG linker there is no multi-strategy phase to run.
        if self.kg_linker is None:
            return cypher_context_with_feedback
        # TODO : Combine cypher linker with KG linker dynamically

        # If kg_linker is provided, ByoKGQueryEngine tries to solve KGQA with multi-strategy retrieval
        for iteration in range(iterations):
            if iteration == 0:
                task_prompts = self.kg_linker_prompts
            else:
                task_prompts = self.kg_linker_prompts_iterative

            response = self.kg_linker.generate_response(
                question=query,
                schema=self.schema,
                graph_context="\n".join(retrieved_context),
                task_prompts=task_prompts,
                user_input=user_input
            )
            artifacts = self.kg_linker.parse_response(response)

            # Process extracted entities
            linked_entities = []
            if "entity-extraction" in artifacts and artifacts["entity-extraction"] and "FINISH" not in artifacts["entity-extraction"][0]:
                linked_entities = self.entity_linker.link(artifacts["entity-extraction"], return_dict=False)
                explored_entities.update(dict.fromkeys(linked_entities))

            # Process answer entities
            linked_answers = []
            if "draft-answer-generation" in artifacts and artifacts["draft-answer-generation"]:
                linked_answers = self.entity_linker.link(artifacts["draft-answer-generation"], return_dict=False)

            # Retrieve triplets if we have source entities (de-duplicated, first occurrence wins)
            source_entities = list(dict.fromkeys(semantic_linked_entities + linked_entities + linked_answers))
            if source_entities and self.triplet_retriever:
                triplet_context = self.triplet_retriever.retrieve(query, source_entities)
                self._add_to_context(retrieved_context, triplet_context)

            # Process paths if available
            if "path-extraction" in artifacts and artifacts["path-extraction"] and explored_entities and self.path_retriever:
                metapaths = [[component.strip() for component in path.split("->")] for path in artifacts["path-extraction"]]
                path_context = self.path_retriever.retrieve(list(explored_entities), metapaths, linked_answers)
                self._add_to_context(retrieved_context, path_context)

            # Process graph queries
            for query_type in ("opencypher", "opencypher-neptune-rdf", "opencypher-neptune"):
                if query_type in artifacts and self.graph_query_executor:
                    graph_query = " ".join(artifacts[query_type])
                    context = self.graph_query_executor.retrieve(graph_query, return_answers=False)
                    self._add_to_context(retrieved_context, context)

            task_completion = parse_response(response, r"<task-completion>(.*?)</task-completion>")
            if "FINISH" in " ".join(task_completion):
                break

        return cypher_context_with_feedback + retrieved_context

    def generate_response(self, query: str, graph_context: str = "", task_prompt = None, user_input: str = "") -> Tuple[List[str], str]:
        """
        Generate a response using the LLM based on the query and graph context.

        Args:
            query: The search query
            graph_context: Retrieved graph context to use for generation
            task_prompt: Optional custom task prompt. Must be None for now; the default
                "generate-response-qa" prompt from prompts/generation_prompts.yaml is used
            user_input: Optional user input for additional instructions or context

        Returns:
            tuple: (list of answers parsed from the <answers> tag, full response text)

        Raises:
            NotImplementedError: If custom task_prompt is provided (not yet supported)
        """
        if task_prompt is not None:
            raise NotImplementedError("custom task_prompt is not yet supported")

        task_prompt = load_yaml("prompts/generation_prompts.yaml")["generate-response-qa"]
        user_prompt_formatted = task_prompt.format(
            question=query,
            graph_context=graph_context,
            user_input=user_input
        )
        response = self.llm_generator.generate(
            prompt=user_prompt_formatted,
        )
        answers = parse_response(response, r"<answers>(.*?)</answers>")
        return answers, response