Skip to content

Commit ab0da1a

Browse files
authored
Merge pull request #1721 from siiddhantt/feat/react-agent
feat: ReActAgent and agent refactor
2 parents 0e31329 + 7f31ac7 commit ab0da1a

File tree

15 files changed

+691
-325
lines changed

15 files changed

+691
-325
lines changed

application/agents/__init__.py

Whitespace-only changes.

application/agents/agent_creator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from application.agents.classic_agent import ClassicAgent
2+
from application.agents.react_agent import ReActAgent
23

34

45
class AgentCreator:
56
agents = {
67
"classic": ClassicAgent,
8+
"react": ReActAgent,
79
}
810

911
@classmethod

application/agents/base.py

Lines changed: 112 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,60 @@
1-
from typing import Dict, Generator
1+
import uuid
2+
from abc import ABC, abstractmethod
3+
from typing import Dict, Generator, List, Optional
24

35
from application.agents.llm_handler import get_llm_handler
46
from application.agents.tools.tool_action_parser import ToolActionParser
57
from application.agents.tools.tool_manager import ToolManager
68

79
from application.core.mongo_db import MongoDB
810
from application.llm.llm_creator import LLMCreator
11+
from application.logging import build_stack_data, log_activity, LogContext
12+
from application.retriever.base import BaseRetriever
913

1014

11-
class BaseAgent:
15+
class BaseAgent(ABC):
1216
def __init__(
1317
self,
14-
endpoint,
15-
llm_name,
16-
gpt_model,
17-
api_key,
18-
user_api_key=None,
19-
decoded_token=None,
18+
endpoint: str,
19+
llm_name: str,
20+
gpt_model: str,
21+
api_key: str,
22+
user_api_key: Optional[str] = None,
23+
prompt: str = "",
24+
chat_history: Optional[List[Dict]] = None,
25+
decoded_token: Optional[Dict] = None,
2026
):
2127
self.endpoint = endpoint
28+
self.llm_name = llm_name
29+
self.gpt_model = gpt_model
30+
self.api_key = api_key
31+
self.user_api_key = user_api_key
32+
self.prompt = prompt
33+
self.decoded_token = decoded_token or {}
34+
self.user: str = decoded_token.get("sub")
35+
self.tool_config: Dict = {}
36+
self.tools: List[Dict] = []
37+
self.tool_calls: List[Dict] = []
38+
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
2239
self.llm = LLMCreator.create_llm(
2340
llm_name,
2441
api_key=api_key,
2542
user_api_key=user_api_key,
2643
decoded_token=decoded_token,
2744
)
2845
self.llm_handler = get_llm_handler(llm_name)
29-
self.gpt_model = gpt_model
30-
self.tools = []
31-
self.tool_config = {}
32-
self.tool_calls = []
3346

34-
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
35-
raise NotImplementedError('Method "gen" must be implemented in the child class')
47+
@log_activity()
48+
def gen(
49+
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
50+
) -> Generator[Dict, None, None]:
51+
yield from self._gen_inner(query, retriever, log_context)
52+
53+
@abstractmethod
54+
def _gen_inner(
55+
self, query: str, retriever: BaseRetriever, log_context: LogContext
56+
) -> Generator[Dict, None, None]:
57+
pass
3658

3759
def _get_user_tools(self, user="local"):
3860
mongo = MongoDB.get_client()
@@ -109,14 +131,12 @@ def _execute_tool_action(self, tools_dict, call):
109131
for param, details in action_data[param_type]["properties"].items():
110132
if param not in call_args and "value" in details:
111133
target_dict[param] = details["value"]
112-
113134
for param, value in call_args.items():
114135
for param_type, target_dict in param_types.items():
115136
if param_type in action_data and param in action_data[param_type].get(
116137
"properties", {}
117138
):
118139
target_dict[param] = value
119-
120140
tm = ToolManager(config={})
121141
tool = tm.load_tool(
122142
tool_data["name"],
@@ -151,3 +171,79 @@ def _execute_tool_action(self, tools_dict, call):
151171
self.tool_calls.append(tool_call_data)
152172

153173
return result, call_id
174+
175+
def _build_messages(
176+
self,
177+
system_prompt: str,
178+
query: str,
179+
retrieved_data: List[Dict],
180+
) -> List[Dict]:
181+
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
182+
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
183+
messages_combine = [{"role": "system", "content": p_chat_combine}]
184+
185+
for i in self.chat_history:
186+
if "prompt" in i and "response" in i:
187+
messages_combine.append({"role": "user", "content": i["prompt"]})
188+
messages_combine.append({"role": "assistant", "content": i["response"]})
189+
if "tool_calls" in i:
190+
for tool_call in i["tool_calls"]:
191+
call_id = tool_call.get("call_id") or str(uuid.uuid4())
192+
193+
function_call_dict = {
194+
"function_call": {
195+
"name": tool_call.get("action_name"),
196+
"args": tool_call.get("arguments"),
197+
"call_id": call_id,
198+
}
199+
}
200+
function_response_dict = {
201+
"function_response": {
202+
"name": tool_call.get("action_name"),
203+
"response": {"result": tool_call.get("result")},
204+
"call_id": call_id,
205+
}
206+
}
207+
208+
messages_combine.append(
209+
{"role": "assistant", "content": [function_call_dict]}
210+
)
211+
messages_combine.append(
212+
{"role": "tool", "content": [function_response_dict]}
213+
)
214+
messages_combine.append({"role": "user", "content": query})
215+
return messages_combine
216+
217+
def _retriever_search(
218+
self,
219+
retriever: BaseRetriever,
220+
query: str,
221+
log_context: Optional[LogContext] = None,
222+
) -> List[Dict]:
223+
retrieved_data = retriever.search(query)
224+
if log_context:
225+
data = build_stack_data(retriever, exclude_attributes=["llm"])
226+
log_context.stacks.append({"component": "retriever", "data": data})
227+
return retrieved_data
228+
229+
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
230+
resp = self.llm.gen_stream(
231+
model=self.gpt_model, messages=messages, tools=self.tools
232+
)
233+
if log_context:
234+
data = build_stack_data(self.llm)
235+
log_context.stacks.append({"component": "llm", "data": data})
236+
return resp
237+
238+
def _llm_handler(
239+
self,
240+
resp,
241+
tools_dict: Dict,
242+
messages: List[Dict],
243+
log_context: Optional[LogContext] = None,
244+
):
245+
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
246+
if log_context:
247+
data = build_stack_data(self.llm_handler)
248+
log_context.stacks.append({"component": "llm_handler", "data": data})
249+
return resp
Lines changed: 6 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,23 @@
1-
import uuid
21
from typing import Dict, Generator
32

43
from application.agents.base import BaseAgent
5-
from application.logging import build_stack_data, log_activity, LogContext
4+
from application.logging import LogContext
65

76
from application.retriever.base import BaseRetriever
87

98

109
class ClassicAgent(BaseAgent):
11-
def __init__(
12-
self,
13-
endpoint,
14-
llm_name,
15-
gpt_model,
16-
api_key,
17-
user_api_key=None,
18-
prompt="",
19-
chat_history=None,
20-
decoded_token=None,
21-
):
22-
super().__init__(
23-
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
24-
)
25-
self.user = decoded_token.get("sub")
26-
self.prompt = prompt
27-
self.chat_history = chat_history if chat_history is not None else []
28-
29-
@log_activity()
30-
def gen(
31-
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
32-
) -> Generator[Dict, None, None]:
33-
yield from self._gen_inner(query, retriever, log_context)
34-
3510
def _gen_inner(
3611
self, query: str, retriever: BaseRetriever, log_context: LogContext
3712
) -> Generator[Dict, None, None]:
3813
retrieved_data = self._retriever_search(retriever, query, log_context)
3914

40-
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
41-
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
42-
messages_combine = [{"role": "system", "content": p_chat_combine}]
43-
44-
if len(self.chat_history) > 0:
45-
for i in self.chat_history:
46-
if "prompt" in i and "response" in i:
47-
messages_combine.append({"role": "user", "content": i["prompt"]})
48-
messages_combine.append(
49-
{"role": "assistant", "content": i["response"]}
50-
)
51-
if "tool_calls" in i:
52-
for tool_call in i["tool_calls"]:
53-
call_id = tool_call.get("call_id")
54-
if call_id is None or call_id == "None":
55-
call_id = str(uuid.uuid4())
56-
57-
function_call_dict = {
58-
"function_call": {
59-
"name": tool_call.get("action_name"),
60-
"args": tool_call.get("arguments"),
61-
"call_id": call_id,
62-
}
63-
}
64-
function_response_dict = {
65-
"function_response": {
66-
"name": tool_call.get("action_name"),
67-
"response": {"result": tool_call.get("result")},
68-
"call_id": call_id,
69-
}
70-
}
71-
72-
messages_combine.append(
73-
{"role": "assistant", "content": [function_call_dict]}
74-
)
75-
messages_combine.append(
76-
{"role": "tool", "content": [function_response_dict]}
77-
)
78-
messages_combine.append({"role": "user", "content": query})
79-
8015
tools_dict = self._get_user_tools(self.user)
8116
self._prepare_tools(tools_dict)
8217

83-
resp = self._llm_gen(messages_combine, log_context)
18+
messages = self._build_messages(self.prompt, query, retrieved_data)
19+
20+
resp = self._llm_gen(messages, log_context)
8421

8522
if isinstance(resp, str):
8623
yield {"answer": resp}
@@ -93,7 +30,7 @@ def _gen_inner(
9330
yield {"answer": resp.message.content}
9431
return
9532

96-
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
33+
resp = self._llm_handler(resp, tools_dict, messages, log_context)
9734

9835
if isinstance(resp, str):
9936
yield {"answer": resp}
@@ -105,36 +42,11 @@ def _gen_inner(
10542
yield {"answer": resp.message.content}
10643
else:
10744
completion = self.llm.gen_stream(
108-
model=self.gpt_model, messages=messages_combine, tools=self.tools
45+
model=self.gpt_model, messages=messages, tools=self.tools
10946
)
11047
for line in completion:
11148
if isinstance(line, str):
11249
yield {"answer": line}
11350

11451
yield {"sources": retrieved_data}
11552
yield {"tool_calls": self.tool_calls.copy()}
116-
117-
def _retriever_search(self, retriever, query, log_context):
118-
retrieved_data = retriever.search(query)
119-
if log_context:
120-
data = build_stack_data(retriever, exclude_attributes=["llm"])
121-
log_context.stacks.append({"component": "retriever", "data": data})
122-
return retrieved_data
123-
124-
def _llm_gen(self, messages_combine, log_context):
125-
resp = self.llm.gen_stream(
126-
model=self.gpt_model, messages=messages_combine, tools=self.tools
127-
)
128-
if log_context:
129-
data = build_stack_data(self.llm)
130-
log_context.stacks.append({"component": "llm", "data": data})
131-
return resp
132-
133-
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
134-
resp = self.llm_handler.handle_response(
135-
self, resp, tools_dict, messages_combine
136-
)
137-
if log_context:
138-
data = build_stack_data(self.llm_handler)
139-
log_context.stacks.append({"component": "llm_handler", "data": data})
140-
return resp

0 commit comments

Comments
 (0)