@@ -116,7 +116,14 @@ def is_azure_configured():
116116
117117
118118def save_conversation (
119- conversation_id , question , response , source_log_docs , tool_calls , llm , index = None , api_key = None
119+ conversation_id ,
120+ question ,
121+ response ,
122+ source_log_docs ,
123+ tool_calls ,
124+ llm ,
125+ index = None ,
126+ api_key = None ,
120127):
121128 current_time = datetime .datetime .now (datetime .timezone .utc )
122129 if conversation_id is not None and index is not None :
@@ -128,7 +135,7 @@ def save_conversation(
128135 f"queries.{ index } .response" : response ,
129136 f"queries.{ index } .sources" : source_log_docs ,
130137 f"queries.{ index } .tool_calls" : tool_calls ,
131- f"queries.{ index } .timestamp" : current_time
138+ f"queries.{ index } .timestamp" : current_time ,
132139 }
133140 },
134141 )
@@ -147,7 +154,7 @@ def save_conversation(
147154 "response" : response ,
148155 "sources" : source_log_docs ,
149156 "tool_calls" : tool_calls ,
150- "timestamp" : current_time
157+ "timestamp" : current_time ,
151158 }
152159 }
153160 },
@@ -182,15 +189,17 @@ def save_conversation(
182189 "response" : response ,
183190 "sources" : source_log_docs ,
184191 "tool_calls" : tool_calls ,
185- "timestamp" : current_time
192+ "timestamp" : current_time ,
186193 }
187194 ],
188195 }
189196 if api_key :
190197 api_key_doc = api_key_collection .find_one ({"key" : api_key })
191198 if api_key_doc :
192199 conversation_data ["api_key" ] = api_key_doc ["key" ]
193- conversation_id = conversations_collection .insert_one (conversation_data ).inserted_id
200+ conversation_id = conversations_collection .insert_one (
201+ conversation_data
202+ ).inserted_id
194203 return conversation_id
195204
196205
@@ -205,36 +214,42 @@ def get_prompt(prompt_id):
205214 prompt = prompts_collection .find_one ({"_id" : ObjectId (prompt_id )})["content" ]
206215 return prompt
207216
217+
208218def complete_stream (
209- question ,
219+ question ,
210220 agent ,
211- retriever ,
212- conversation_id ,
213- user_api_key ,
214- isNoneDoc = False ,
221+ retriever ,
222+ conversation_id ,
223+ user_api_key ,
224+ isNoneDoc = False ,
215225 index = None ,
216- should_save_conversation = True
226+ should_save_conversation = True ,
217227):
218228 try :
219229 response_full = ""
220230 source_log_docs = []
221231 tool_calls = []
232+
222233 answer = agent .gen (query = question , retriever = retriever )
223- sources = retriever .search (question )
224- for source in sources :
225- if "text" in source :
226- source ["text" ] = source ["text" ][:100 ].strip () + "..."
227- if len (sources ) > 0 :
228- data = json .dumps ({"type" : "source" , "source" : sources })
229- yield f"data: { data } \n \n "
230234
231235 for line in answer :
232236 if "answer" in line :
233237 response_full += str (line ["answer" ])
234- data = json .dumps (line )
238+ data = json .dumps ({ "type" : "answer" , "answer" : line [ "answer" ]} )
235239 yield f"data: { data } \n \n "
236- elif "source" in line :
237- source_log_docs .append (line ["source" ])
240+ elif "sources" in line :
241+ truncated_sources = []
242+ source_log_docs = line ["sources" ]
243+ for source in line ["sources" ]:
244+ truncated_source = source .copy ()
245+ if "text" in truncated_source :
246+ truncated_source ["text" ] = (
247+ truncated_source ["text" ][:100 ].strip () + "..."
248+ )
249+ truncated_sources .append (truncated_source )
250+ if len (truncated_sources ) > 0 :
251+ data = json .dumps ({"type" : "source" , "source" : truncated_sources })
252+ yield f"data: { data } \n \n "
238253 elif "tool_calls" in line :
239254 tool_calls = line ["tool_calls" ]
240255 data = json .dumps ({"type" : "tool_calls" , "tool_calls" : tool_calls })
@@ -245,11 +260,9 @@ def complete_stream(
245260 doc ["source" ] = "None"
246261
247262 llm = LLMCreator .create_llm (
248- settings .LLM_NAME ,
249- api_key = settings .API_KEY ,
250- user_api_key = user_api_key
263+ settings .LLM_NAME , api_key = settings .API_KEY , user_api_key = user_api_key
251264 )
252-
265+
253266 if should_save_conversation :
254267 conversation_id = save_conversation (
255268 conversation_id ,
@@ -259,7 +272,7 @@ def complete_stream(
259272 tool_calls ,
260273 llm ,
261274 index ,
262- api_key = user_api_key
275+ api_key = user_api_key ,
263276 )
264277 else :
265278 conversation_id = None
@@ -523,9 +536,19 @@ def post(self):
523536 extra = {"data" : json .dumps ({"request_data" : data , "source" : source })},
524537 )
525538
539+ agent = AgentCreator .create_agent (
540+ settings .AGENT_NAME ,
541+ endpoint = "api/answer" ,
542+ llm_name = settings .LLM_NAME ,
543+ gpt_model = gpt_model ,
544+ api_key = settings .API_KEY ,
545+ user_api_key = user_api_key ,
546+ prompt = prompt ,
547+ chat_history = history ,
548+ )
549+
526550 retriever = RetrieverCreator .create_retriever (
527551 retriever_name ,
528- question = question ,
529552 source = source ,
530553 chat_history = history ,
531554 prompt = prompt ,
@@ -538,13 +561,41 @@ def post(self):
538561 response_full = ""
539562 source_log_docs = []
540563 tool_calls = []
541- for line in retriever .gen ():
542- if "source" in line :
543- source_log_docs .append (line ["source" ])
544- elif "answer" in line :
545- response_full += line ["answer" ]
546- elif "tool_calls" in line :
547- tool_calls .append (line ["tool_calls" ])
564+ stream_ended = False
565+
566+ for line in complete_stream (
567+ question = question ,
568+ agent = agent ,
569+ retriever = retriever ,
570+ conversation_id = conversation_id ,
571+ user_api_key = user_api_key ,
572+ isNoneDoc = data .get ("isNoneDoc" ),
573+ index = None ,
574+ should_save_conversation = False ,
575+ ):
576+ try :
577+ event_data = line .replace ("data: " , "" ).strip ()
578+ event = json .loads (event_data )
579+
580+ if event ["type" ] == "answer" :
581+ response_full += event ["answer" ]
582+ elif event ["type" ] == "source" :
583+ source_log_docs = event ["source" ]
584+ elif event ["type" ] == "tool_calls" :
585+ tool_calls = event ["tool_calls" ]
586+ elif event ["type" ] == "error" :
587+ logger .error (f"Error from stream: { event ['error' ]} " )
588+ return bad_request (500 , event ["error" ])
589+ elif event ["type" ] == "end" :
590+ stream_ended = True
591+
592+ except (json .JSONDecodeError , KeyError ) as e :
593+ logger .warning (f"Error parsing stream event: { e } , line: { line } " )
594+ continue
595+
596+ if not stream_ended :
597+ logger .error ("Stream ended unexpectedly without an 'end' event." )
598+ return bad_request (500 , "Stream ended unexpectedly." )
548599
549600 if data .get ("isNoneDoc" ):
550601 for doc in source_log_docs :
@@ -563,8 +614,10 @@ def post(self):
563614 source_log_docs ,
564615 tool_calls ,
565616 llm ,
617+ api_key = user_api_key ,
566618 )
567619 )
620+
568621 retriever_params = retriever .get_params ()
569622 user_logs_collection .insert_one (
570623 {
0 commit comments