Skip to content

Commit f3ea9b5

Browse files
committed
refactor agent response handling
- Introduce LLMResponse class to separate text and tool calls from LLM responses
- Rename Agent.process_response() to Agent.call_tools() and have it handle only tool calls
- Update LLM classes to return LLMResponse
- Remove redundant code and improve error handling
1 parent dce437d commit f3ea9b5

File tree

12 files changed

+88
-80
lines changed

12 files changed

+88
-80
lines changed

cognitrix/agents/base.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import asyncio
55
import logging
66
import aiofiles
7-
from pathlib import Path
87
from threading import Thread
98
from datetime import datetime
109
from typing import Dict, List, Optional, Self, TypeAlias, Union, Type, Any
1110

1211
from pydantic import BaseModel, Field
1312

1413
from cognitrix.tasks import Task
15-
from cognitrix.llms.base import LLM
14+
from cognitrix.llms.base import LLM, LLMResponse
1615
from cognitrix.tools.base import Tool
1716
from cognitrix.utils import extract_json, json_return_format
1817
from cognitrix.agents.templates import AUTONOMOUSE_AGENT_2
@@ -184,27 +183,12 @@ def get_sub_agent_by_name(self, name: str) -> Optional['Agent']:
184183
def get_tool_by_name(self, name: str) -> Optional[Tool]:
185184
return next((tool for tool in self.tools if tool.name.lower() == name.lower()), None)
186185

187-
async def process_response(self, response: str|dict) -> Union[dict, str]:
188-
# print(response)
189-
# response = response.replace("'", '"')
190-
response_data = response
191-
if isinstance(response, str):
192-
response = response.replace('\\n', '')
193-
# response = response.replace("'", "\""
194-
# response = response.replace('"', '\\"')
195-
response_data = extract_json(response)
196-
197-
186+
async def call_tools(self, tool_calls: list) -> Union[dict, str]:
187+
198188
try:
199-
if isinstance(response_data, dict):
200-
# final_result_keys = ['final_answer', 'tool_calls_result', 'response']
201-
189+
if tool_calls:
202190
tool_calls_result = []
203-
204-
if response_data['type'].replace('\\', '') != 'tool_calls':
205-
return response_data['result']
206-
207-
for t in response_data['tool_calls']:
191+
for t in tool_calls:
208192
tool = self.get_tool_by_name(t['name'])
209193

210194
if not tool:
@@ -235,8 +219,8 @@ async def process_response(self, response: str|dict) -> Union[dict, str]:
235219
else:
236220
raise Exception('Not a json object')
237221
except Exception as e:
238-
# logger.exception(e)
239-
return response_data
222+
logger.exception(e)
223+
return str(e)
240224

241225
def add_tool(self, tool: Tool):
242226
self.tools.append(tool)
@@ -282,18 +266,22 @@ def initialize(self, session_id: Optional[str] = None):
282266
response: Any = self.llm(full_prompt)
283267

284268
self.llm.chat_history.append(full_prompt)
285-
self.llm.chat_history.append({'role': self.name, 'type': 'text', 'message': response})
286-
287-
if self.verbose:
288-
print(response)
289-
290-
result: dict[Any, Any] | str = asyncio.run(self.process_response(response))
269+
270+
if response.text:
271+
self.llm.chat_history.append({'role': self.name, 'type': 'text', 'message': response.text})
272+
print(f"\n{self.name}: {response.text}")
273+
274+
if response.tool_calls:
275+
result: dict[Any, Any] | str = asyncio.run(self.call_tools(response.tool_calls))
291276

292-
if isinstance(result, dict) and result['type'] == 'tool_calls_result':
293-
query = result
277+
if isinstance(result, dict) and result['type'] == 'tool_calls_result':
278+
query = result
279+
else:
280+
print(result)
294281
else:
295-
print(f"\n{self.name}: {result}")
296282
query = input("\nUser (q to quit): ")
283+
284+
# query = input("\nUser (q to quit): ")
297285

298286
self.save_session(session)
299287

@@ -322,7 +310,7 @@ def handle_transcription(self, sentence: str, transcriber: Transcriber):
322310
if self.verbose:
323311
print(response)
324312

325-
processsed_response: dict[Any, Any] | str = asyncio.run(self.process_response(response))
313+
processsed_response: dict[Any, Any] | str = asyncio.run(self.call_tools(response))
326314

327315
if isinstance(processsed_response, dict) and processsed_response['type'] == 'tool_calls_result':
328316
query = processsed_response
@@ -361,7 +349,7 @@ def run_task(self, parent: Self):
361349
if parent.verbose:
362350
print(response)
363351

364-
agent_result = asyncio.run(self.process_response(response))
352+
agent_result = asyncio.run(self.call_tools(response))
365353
if isinstance(agent_result, dict) and agent_result['type'] == 'tool_calls_result':
366354
query = agent_result
367355
else:
@@ -371,7 +359,7 @@ def run_task(self, parent: Self):
371359
parent_response: Any = parent.llm(parent_prompt)
372360
parent.llm.chat_history.append(parent_prompt)
373361
parent.llm.chat_history.append({'role': 'assistant', 'type': 'text', 'message': parent_response})
374-
parent_result = asyncio.run(parent.process_response(parent_response))
362+
parent_result = asyncio.run(parent.call_tools(parent_response))
375363
print(f"\n\n{parent.name}: {parent_result}")
376364
query = ""
377365

cognitrix/llms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from cognitrix.llms.base import LLM
2+
from cognitrix.llms.base import LLMResponse
23
from cognitrix.llms.groq_llm import Groq
34
from cognitrix.llms.local_llm import Local
45
from cognitrix.llms.cohere_llm import Cohere

cognitrix/llms/base.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import json
22
from pydantic import BaseModel, Field
3-
from typing import Any, List, Dict
3+
from typing import Any, List, Dict, Optional, TypedDict
44
import logging
55
import inspect
66

7-
from cognitrix.utils import image_to_base64
7+
from cognitrix.utils import extract_json, image_to_base64
88
from cognitrix.tools.base import Tool
99

1010
logging.basicConfig(
@@ -14,6 +14,35 @@
1414
)
1515
logger = logging.getLogger('cognitrix.log')
1616

17+
class LLMResponse:
    """Container that splits a raw LLM response into text and tool calls.

    The raw response string is expected to (optionally) contain a JSON
    object with either a ``'result'`` key (a plain-text answer) or a
    ``'tool_calls'`` key (a list of tool-invocation dicts). Anything
    that does not match that shape is kept verbatim as text.
    """

    def __init__(self, llm_response: Optional[str] = None):
        # Raw, unparsed response as returned by the provider API.
        self.llm_response = llm_response
        # Plain-text portion of the response, if any.
        self.text: Optional[str] = None
        # Parsed tool-call dicts, if the response requested tool use.
        self.tool_calls: Optional[List[Dict[str, Any]]] = None
        self.parse_llm_response()

    def parse_llm_response(self) -> None:
        """Parse ``llm_response`` into ``text`` and ``tool_calls``.

        Uses explicit key checks instead of relying on a (logged)
        ``KeyError`` for control flow, as the original did when a dict
        contained neither ``'result'`` nor ``'tool_calls'``. The
        fallback outcome is unchanged: unrecognised data becomes text.
        """
        # Empty or None responses leave both fields as None.
        if not self.llm_response:
            return

        # extract_json is a project helper; presumably it returns a dict
        # when the response embeds a JSON object, else the raw input —
        # TODO confirm against cognitrix.utils.
        response_data = extract_json(self.llm_response)

        if isinstance(response_data, dict):
            if 'result' in response_data:
                self.text = response_data['result']
            elif 'tool_calls' in response_data:
                self.tool_calls = response_data['tool_calls']
            else:
                # Unrecognised object shape: keep it verbatim as text.
                self.text = str(response_data)
        else:
            self.text = str(response_data)
44+
45+
1746
class LLM(BaseModel):
1847
"""
1948
A class for representing a large language model.
@@ -162,4 +191,4 @@ def load_llm(model_name: str):
162191
return None
163192

164193
def __call__(*args, **kwargs):
165-
pass
194+
return LLMResponse()

cognitrix/llms/clarifai_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from clarifai.client.model import Model
2-
from cognitrix.llms.base import LLM
2+
from cognitrix.llms.base import LLM, LLMResponse
33
from typing import Any, Optional
44
from dotenv import load_dotenv
55
import logging
@@ -36,7 +36,7 @@ class Clarifai(LLM):
3636
api_key: str = os.getenv('CLARIFAI_API_KEY', '')
3737
"""Clarifai Personal Access Token"""
3838

39-
def __call__(self, query, **kwds: Any)->str:
39+
def __call__(self, query, **kwds: Any):
4040
"""Generates a response to a query using the Clarifai API.
4141
4242
Args:
@@ -54,5 +54,5 @@ def __call__(self, query, **kwds: Any)->str:
5454
query = f"{self.system_prompt}\n {json.dumps(formatted_messages)}"
5555
result = self.client.predict_by_bytes(query.encode(), input_type="text")
5656

57-
return result.outputs[0].data.text.raw
57+
return LLMResponse(result.outputs[0].data.text.raw)
5858

cognitrix/llms/cohere_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import cohere
2-
from cognitrix.llms.base import LLM
2+
from cognitrix.llms.base import LLM, LLMResponse
33
from typing import Any, Optional
44
from dotenv import load_dotenv
55
import logging
@@ -54,7 +54,7 @@ def format_tools(self, tools: list[dict[str, Any]]):
5454

5555
self.tools.append(f_tool)
5656

57-
def __call__(self, query, **kwds: Any)->str:
57+
def __call__(self, query, **kwds: Any):
5858
"""Generates a response to a query using the Cohere API.
5959
6060
Args:
@@ -79,7 +79,7 @@ def __call__(self, query, **kwds: Any)->str:
7979
connectors=[{"id": "web-search"}]
8080
)
8181

82-
return response.text
82+
return LLMResponse(response.text)
8383

8484
if __name__ == "__main__":
8585
try:

cognitrix/llms/google_llm.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cognitrix.llms.base import LLM
1+
from cognitrix.llms.base import LLM, LLMResponse
22
from typing import Any
33
import google.generativeai as genai
44
from dotenv import load_dotenv
@@ -72,7 +72,7 @@ def format_query(self, message: dict[str, str]) -> list:
7272

7373
return messages
7474

75-
def __call__(self, query, **kwds: Any)->str|None:
75+
def __call__(self, query, **kwds: Any):
7676
"""Generates a response to a query using the Gemini API.
7777
7878
Args:
@@ -104,15 +104,5 @@ def __call__(self, query, **kwds: Any)->str|None:
104104

105105
response.resolve()
106106

107-
return response.text
108-
109-
if __name__ == "__main__":
110-
try:
111-
assistant = Google()
112-
# assistant.add_tool(calculator)
113-
while True:
114-
message = input("\nEnter Query$ ")
115-
result = assistant(message)
116-
print(result)
117-
except KeyboardInterrupt:
118-
sys.exit(1)
107+
return LLMResponse(response.text)
108+

cognitrix/llms/groq_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from groq import Groq as GroqLLM
2-
from cognitrix.llms.base import LLM
2+
from cognitrix.llms.base import LLM, LLMResponse
33
from typing import Any, Optional
44
from dotenv import load_dotenv
55
import logging
@@ -74,7 +74,7 @@ def format_query(self, message: dict[str, str]) -> list:
7474
return messages
7575

7676

77-
def __call__(self, query, **kwds: Any)->str|None:
77+
def __call__(self, query, **kwds: Any):
7878
"""Generates a response to a query using the Groq API.
7979
8080
Args:
@@ -107,7 +107,7 @@ def __call__(self, query, **kwds: Any)->str|None:
107107
# tool_calls = response_message.tool_calls
108108
# print(tool_calls)
109109

110-
return response_message.content
110+
return LLMResponse(response_message.content)
111111

112112
if __name__ == "__main__":
113113
try:

cognitrix/llms/local_llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from openai import OpenAI as OpenAILLM
2-
from cognitrix.llms.base import LLM
2+
from cognitrix.llms.base import LLM, LLMResponse
33
from cognitrix.utils import image_to_base64
44
from typing import Any, Optional
55
from dotenv import load_dotenv
@@ -95,7 +95,7 @@ def format_query(self, message: dict[str, str]) -> list:
9595

9696
return messages
9797

98-
def __call__(self, query: dict, **kwds: dict)->Optional[str]:
98+
def __call__(self, query: dict, **kwds: dict):
9999
"""Generates a response to a query using the OpenAI API.
100100
101101
Args:
@@ -121,4 +121,4 @@ def __call__(self, query: dict, **kwds: dict)->Optional[str]:
121121
max_tokens=self.max_tokens
122122
)
123123

124-
return response.choices[0].message.content #type: ignore
124+
return LLMResponse(response.choices[0].message.content) #type: ignore

cognitrix/llms/mindsdb_llm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Optional
44
import os
55

6+
from cognitrix.llms.base import LLMResponse
7+
68
class MindsDB(OpenAI):
79
"""A class for interacting with the MindsDB API."""
810

@@ -12,7 +14,7 @@ class MindsDB(OpenAI):
1214

1315
base_url: str = 'https://llm.mdb.ai'
1416

15-
def __call__(self, query: dict, **kwds: dict)->Optional[str]:
17+
def __call__(self, query: dict, **kwds: dict):
1618
"""Generates a response to a query using the OpenAI API.
1719
1820
Args:
@@ -38,4 +40,4 @@ def __call__(self, query: dict, **kwds: dict)->Optional[str]:
3840
max_tokens=self.max_tokens
3941
)
4042

41-
return response.choices[0].message.content
43+
return LLMResponse(response.choices[0].message.content)

cognitrix/llms/ollama_llm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from cognitrix.llms.base import LLM
1+
from cognitrix.llms.base import LLM, LLMResponse
22
from cognitrix.utils import image_to_base64
33
from typing import Any, Optional
44
from dotenv import load_dotenv
@@ -88,7 +88,7 @@ def format_query(self, message: dict[str, str]) -> list:
8888

8989
return messages
9090

91-
def __call__(self, query: dict, **kwds: dict)->Optional[str]:
91+
def __call__(self, query: dict, **kwds: dict):
9292
"""Generates a response to a query using the OpenAI API.
9393
9494
Args:
@@ -112,5 +112,5 @@ def __call__(self, query: dict, **kwds: dict)->Optional[str]:
112112
],
113113
stream=False,
114114
)
115-
print(response)
116-
return response['message']['content'] #type: ignore
115+
116+
return LLMResponse(response['message']['content']) #type: ignore

0 commit comments

Comments
 (0)