Skip to content

Commit fe02bf9

Browse files
authored
Merge pull request #1693 from siiddhantt/fix/response-and-sources
feat: agent use in answer and enhance search
2 parents 46d32b4 + faa5838 commit fe02bf9

File tree

2 files changed

+88
-34
lines changed

2 files changed

+88
-34
lines changed

application/agents/classic_agent.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ def _gen_inner(
107107
if isinstance(line, str):
108108
yield {"answer": line}
109109

110+
yield {"sources": retrieved_data}
110111
yield {"tool_calls": self.tool_calls.copy()}
111112

112113
def _retriever_search(self, retriever, query, log_context):

application/api/answer/routes.py

Lines changed: 87 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,14 @@ def is_azure_configured():
116116

117117

118118
def save_conversation(
119-
conversation_id, question, response, source_log_docs, tool_calls, llm, index=None, api_key=None
119+
conversation_id,
120+
question,
121+
response,
122+
source_log_docs,
123+
tool_calls,
124+
llm,
125+
index=None,
126+
api_key=None,
120127
):
121128
current_time = datetime.datetime.now(datetime.timezone.utc)
122129
if conversation_id is not None and index is not None:
@@ -128,7 +135,7 @@ def save_conversation(
128135
f"queries.{index}.response": response,
129136
f"queries.{index}.sources": source_log_docs,
130137
f"queries.{index}.tool_calls": tool_calls,
131-
f"queries.{index}.timestamp": current_time
138+
f"queries.{index}.timestamp": current_time,
132139
}
133140
},
134141
)
@@ -147,7 +154,7 @@ def save_conversation(
147154
"response": response,
148155
"sources": source_log_docs,
149156
"tool_calls": tool_calls,
150-
"timestamp": current_time
157+
"timestamp": current_time,
151158
}
152159
}
153160
},
@@ -182,15 +189,17 @@ def save_conversation(
182189
"response": response,
183190
"sources": source_log_docs,
184191
"tool_calls": tool_calls,
185-
"timestamp": current_time
192+
"timestamp": current_time,
186193
}
187194
],
188195
}
189196
if api_key:
190197
api_key_doc = api_key_collection.find_one({"key": api_key})
191198
if api_key_doc:
192199
conversation_data["api_key"] = api_key_doc["key"]
193-
conversation_id = conversations_collection.insert_one(conversation_data).inserted_id
200+
conversation_id = conversations_collection.insert_one(
201+
conversation_data
202+
).inserted_id
194203
return conversation_id
195204

196205

@@ -205,36 +214,42 @@ def get_prompt(prompt_id):
205214
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
206215
return prompt
207216

217+
208218
def complete_stream(
209-
question,
219+
question,
210220
agent,
211-
retriever,
212-
conversation_id,
213-
user_api_key,
214-
isNoneDoc=False,
221+
retriever,
222+
conversation_id,
223+
user_api_key,
224+
isNoneDoc=False,
215225
index=None,
216-
should_save_conversation=True
226+
should_save_conversation=True,
217227
):
218228
try:
219229
response_full = ""
220230
source_log_docs = []
221231
tool_calls = []
232+
222233
answer = agent.gen(query=question, retriever=retriever)
223-
sources = retriever.search(question)
224-
for source in sources:
225-
if "text" in source:
226-
source["text"] = source["text"][:100].strip() + "..."
227-
if len(sources) > 0:
228-
data = json.dumps({"type": "source", "source": sources})
229-
yield f"data: {data}\n\n"
230234

231235
for line in answer:
232236
if "answer" in line:
233237
response_full += str(line["answer"])
234-
data = json.dumps(line)
238+
data = json.dumps({"type": "answer", "answer": line["answer"]})
235239
yield f"data: {data}\n\n"
236-
elif "source" in line:
237-
source_log_docs.append(line["source"])
240+
elif "sources" in line:
241+
truncated_sources = []
242+
source_log_docs = line["sources"]
243+
for source in line["sources"]:
244+
truncated_source = source.copy()
245+
if "text" in truncated_source:
246+
truncated_source["text"] = (
247+
truncated_source["text"][:100].strip() + "..."
248+
)
249+
truncated_sources.append(truncated_source)
250+
if len(truncated_sources) > 0:
251+
data = json.dumps({"type": "source", "source": truncated_sources})
252+
yield f"data: {data}\n\n"
238253
elif "tool_calls" in line:
239254
tool_calls = line["tool_calls"]
240255
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
@@ -245,11 +260,9 @@ def complete_stream(
245260
doc["source"] = "None"
246261

247262
llm = LLMCreator.create_llm(
248-
settings.LLM_NAME,
249-
api_key=settings.API_KEY,
250-
user_api_key=user_api_key
263+
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key
251264
)
252-
265+
253266
if should_save_conversation:
254267
conversation_id = save_conversation(
255268
conversation_id,
@@ -259,7 +272,7 @@ def complete_stream(
259272
tool_calls,
260273
llm,
261274
index,
262-
api_key=user_api_key
275+
api_key=user_api_key,
263276
)
264277
else:
265278
conversation_id = None
@@ -523,9 +536,19 @@ def post(self):
523536
extra={"data": json.dumps({"request_data": data, "source": source})},
524537
)
525538

539+
agent = AgentCreator.create_agent(
540+
settings.AGENT_NAME,
541+
endpoint="api/answer",
542+
llm_name=settings.LLM_NAME,
543+
gpt_model=gpt_model,
544+
api_key=settings.API_KEY,
545+
user_api_key=user_api_key,
546+
prompt=prompt,
547+
chat_history=history,
548+
)
549+
526550
retriever = RetrieverCreator.create_retriever(
527551
retriever_name,
528-
question=question,
529552
source=source,
530553
chat_history=history,
531554
prompt=prompt,
@@ -538,13 +561,41 @@ def post(self):
538561
response_full = ""
539562
source_log_docs = []
540563
tool_calls = []
541-
for line in retriever.gen():
542-
if "source" in line:
543-
source_log_docs.append(line["source"])
544-
elif "answer" in line:
545-
response_full += line["answer"]
546-
elif "tool_calls" in line:
547-
tool_calls.append(line["tool_calls"])
564+
stream_ended = False
565+
566+
for line in complete_stream(
567+
question=question,
568+
agent=agent,
569+
retriever=retriever,
570+
conversation_id=conversation_id,
571+
user_api_key=user_api_key,
572+
isNoneDoc=data.get("isNoneDoc"),
573+
index=None,
574+
should_save_conversation=False,
575+
):
576+
try:
577+
event_data = line.replace("data: ", "").strip()
578+
event = json.loads(event_data)
579+
580+
if event["type"] == "answer":
581+
response_full += event["answer"]
582+
elif event["type"] == "source":
583+
source_log_docs = event["source"]
584+
elif event["type"] == "tool_calls":
585+
tool_calls = event["tool_calls"]
586+
elif event["type"] == "error":
587+
logger.error(f"Error from stream: {event['error']}")
588+
return bad_request(500, event["error"])
589+
elif event["type"] == "end":
590+
stream_ended = True
591+
592+
except (json.JSONDecodeError, KeyError) as e:
593+
logger.warning(f"Error parsing stream event: {e}, line: {line}")
594+
continue
595+
596+
if not stream_ended:
597+
logger.error("Stream ended unexpectedly without an 'end' event.")
598+
return bad_request(500, "Stream ended unexpectedly.")
548599

549600
if data.get("isNoneDoc"):
550601
for doc in source_log_docs:
@@ -563,8 +614,10 @@ def post(self):
563614
source_log_docs,
564615
tool_calls,
565616
llm,
617+
api_key=user_api_key,
566618
)
567619
)
620+
568621
retriever_params = retriever.get_params()
569622
user_logs_collection.insert_one(
570623
{

0 commit comments

Comments
 (0)