44import asyncio
55import logging
66import aiofiles
7- from pathlib import Path
87from threading import Thread
98from datetime import datetime
109from typing import Dict , List , Optional , Self , TypeAlias , Union , Type , Any
1110
1211from pydantic import BaseModel , Field
1312
1413from cognitrix .tasks import Task
15- from cognitrix .llms .base import LLM
14+ from cognitrix .llms .base import LLM , LLMResponse
1615from cognitrix .tools .base import Tool
1716from cognitrix .utils import extract_json , json_return_format
1817from cognitrix .agents .templates import AUTONOMOUSE_AGENT_2
@@ -184,27 +183,12 @@ def get_sub_agent_by_name(self, name: str) -> Optional['Agent']:
184183 def get_tool_by_name (self , name : str ) -> Optional [Tool ]:
185184 return next ((tool for tool in self .tools if tool .name .lower () == name .lower ()), None )
186185
187- async def process_response (self , response : str | dict ) -> Union [dict , str ]:
188- # print(response)
189- # response = response.replace("'", '"')
190- response_data = response
191- if isinstance (response , str ):
192- response = response .replace ('\\ n' , '' )
193- # response = response.replace("'", "\""
194- # response = response.replace('"', '\\"')
195- response_data = extract_json (response )
196-
197-
186+ async def call_tools (self , tool_calls : list ) -> Union [dict , str ]:
187+
198188 try :
199- if isinstance (response_data , dict ):
200- # final_result_keys = ['final_answer', 'tool_calls_result', 'response']
201-
189+ if tool_calls :
202190 tool_calls_result = []
203-
204- if response_data ['type' ].replace ('\\ ' , '' ) != 'tool_calls' :
205- return response_data ['result' ]
206-
207- for t in response_data ['tool_calls' ]:
191+ for t in tool_calls :
208192 tool = self .get_tool_by_name (t ['name' ])
209193
210194 if not tool :
@@ -235,8 +219,8 @@ async def process_response(self, response: str|dict) -> Union[dict, str]:
235219 else :
236220 raise Exception ('Not a json object' )
237221 except Exception as e :
238- # logger.exception(e)
239- return response_data
222+ logger .exception (e )
223+ return str ( e )
240224
241225 def add_tool (self , tool : Tool ):
242226 self .tools .append (tool )
@@ -282,18 +266,22 @@ def initialize(self, session_id: Optional[str] = None):
282266 response : Any = self .llm (full_prompt )
283267
284268 self .llm .chat_history .append (full_prompt )
285- self .llm .chat_history .append ({'role' : self .name , 'type' : 'text' , 'message' : response })
286-
287- if self .verbose :
288- print (response )
289-
290- result : dict [Any , Any ] | str = asyncio .run (self .process_response (response ))
269+
270+ if response .text :
271+ self .llm .chat_history .append ({'role' : self .name , 'type' : 'text' , 'message' : response .text })
272+ print (f"\n { self .name } : { response .text } " )
273+
274+ if response .tool_calls :
275+ result : dict [Any , Any ] | str = asyncio .run (self .call_tools (response .tool_calls ))
291276
292- if isinstance (result , dict ) and result ['type' ] == 'tool_calls_result' :
293- query = result
277+ if isinstance (result , dict ) and result ['type' ] == 'tool_calls_result' :
278+ query = result
279+ else :
280+ print (result )
294281 else :
295- print (f"\n { self .name } : { result } " )
296282 query = input ("\n User (q to quit): " )
283+
284+ # query = input("\nUser (q to quit): ")
297285
298286 self .save_session (session )
299287
@@ -322,7 +310,7 @@ def handle_transcription(self, sentence: str, transcriber: Transcriber):
322310 if self .verbose :
323311 print (response )
324312
325- processsed_response : dict [Any , Any ] | str = asyncio .run (self .process_response (response ))
313+ processsed_response : dict [Any , Any ] | str = asyncio .run (self .call_tools (response ))
326314
327315 if isinstance (processsed_response , dict ) and processsed_response ['type' ] == 'tool_calls_result' :
328316 query = processsed_response
@@ -361,7 +349,7 @@ def run_task(self, parent: Self):
361349 if parent .verbose :
362350 print (response )
363351
364- agent_result = asyncio .run (self .process_response (response ))
352+ agent_result = asyncio .run (self .call_tools (response ))
365353 if isinstance (agent_result , dict ) and agent_result ['type' ] == 'tool_calls_result' :
366354 query = agent_result
367355 else :
@@ -371,7 +359,7 @@ def run_task(self, parent: Self):
371359 parent_response : Any = parent .llm (parent_prompt )
372360 parent .llm .chat_history .append (parent_prompt )
373361 parent .llm .chat_history .append ({'role' : 'assistant' , 'type' : 'text' , 'message' : parent_response })
374- parent_result = asyncio .run (parent .process_response (parent_response ))
362+ parent_result = asyncio .run (parent .call_tools (parent_response ))
375363 print (f"\n \n { parent .name } : { parent_result } " )
376364 query = ""
377365
0 commit comments