Skip to content

Commit 72bbe3b

Browse files
2 parents 8568243 + 632cba8 commit 72bbe3b

17 files changed

+723
-360
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@
4848
- [x] Add tools (Jan 2025)
4949
- [x] Manually updating chunks in the app UI (Feb 2025)
5050
- [x] Devcontainer for easy development (Feb 2025)
51+
- [x] ReACT agent (March 2025)
5152
- [ ] Anthropic Tool compatibility
53+
- [ ] New input box in the conversation menu
5254
- [ ] Add triggerable actions / tools (webhook)
5355
- [ ] Add OAuth 2.0 authentication for tools and sources
54-
- [ ] Chatbots menu re-design to handle tools, scheduling, and more
56+
- [ ] Chatbots menu re-design to handle tools, agent types, and more
57+
- [ ] Agent scheduling
5558

5659
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
5760

application/agents/__init__.py

Whitespace-only changes.

application/agents/agent_creator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from application.agents.classic_agent import ClassicAgent
2+
from application.agents.react_agent import ReActAgent
23

34

45
class AgentCreator:
56
agents = {
67
"classic": ClassicAgent,
8+
"react": ReActAgent,
79
}
810

911
@classmethod

application/agents/base.py

Lines changed: 114 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,62 @@
1-
from typing import Dict, Generator
1+
import uuid
2+
from abc import ABC, abstractmethod
3+
from typing import Dict, Generator, List, Optional
24

35
from application.agents.llm_handler import get_llm_handler
46
from application.agents.tools.tool_action_parser import ToolActionParser
57
from application.agents.tools.tool_manager import ToolManager
68

79
from application.core.mongo_db import MongoDB
810
from application.llm.llm_creator import LLMCreator
11+
from application.logging import build_stack_data, log_activity, LogContext
12+
from application.retriever.base import BaseRetriever
913

1014

11-
class BaseAgent:
15+
class BaseAgent(ABC):
1216
def __init__(
1317
self,
14-
endpoint,
15-
llm_name,
16-
gpt_model,
17-
api_key,
18-
user_api_key=None,
19-
decoded_token=None,
20-
attachments=None,
18+
endpoint: str,
19+
llm_name: str,
20+
gpt_model: str,
21+
api_key: str,
22+
user_api_key: Optional[str] = None,
23+
prompt: str = "",
24+
chat_history: Optional[List[Dict]] = None,
25+
decoded_token: Optional[Dict] = None,
26+
attachments: Optional[str]=None,
2127
):
2228
self.endpoint = endpoint
29+
self.llm_name = llm_name
30+
self.gpt_model = gpt_model
31+
self.api_key = api_key
32+
self.user_api_key = user_api_key
33+
self.prompt = prompt
34+
self.decoded_token = decoded_token or {}
35+
self.user: str = decoded_token.get("sub")
36+
self.tool_config: Dict = {}
37+
self.tools: List[Dict] = []
38+
self.tool_calls: List[Dict] = []
39+
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
2340
self.llm = LLMCreator.create_llm(
2441
llm_name,
2542
api_key=api_key,
2643
user_api_key=user_api_key,
2744
decoded_token=decoded_token,
2845
)
2946
self.llm_handler = get_llm_handler(llm_name)
30-
self.gpt_model = gpt_model
31-
self.tools = []
32-
self.tool_config = {}
33-
self.tool_calls = []
34-
self.attachments = attachments or []
47+
set.attachments = attachments or []
3548

36-
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
37-
raise NotImplementedError('Method "gen" must be implemented in the child class')
49+
@log_activity()
50+
def gen(
51+
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
52+
) -> Generator[Dict, None, None]:
53+
yield from self._gen_inner(query, retriever, log_context)
54+
55+
@abstractmethod
56+
def _gen_inner(
57+
self, query: str, retriever: BaseRetriever, log_context: LogContext
58+
) -> Generator[Dict, None, None]:
59+
pass
3860

3961
def _get_user_tools(self, user="local"):
4062
mongo = MongoDB.get_client()
@@ -111,14 +133,12 @@ def _execute_tool_action(self, tools_dict, call):
111133
for param, details in action_data[param_type]["properties"].items():
112134
if param not in call_args and "value" in details:
113135
target_dict[param] = details["value"]
114-
115136
for param, value in call_args.items():
116137
for param_type, target_dict in param_types.items():
117138
if param_type in action_data and param in action_data[param_type].get(
118139
"properties", {}
119140
):
120141
target_dict[param] = value
121-
122142
tm = ToolManager(config={})
123143
tool = tm.load_tool(
124144
tool_data["name"],
@@ -153,3 +173,79 @@ def _execute_tool_action(self, tools_dict, call):
153173
self.tool_calls.append(tool_call_data)
154174

155175
return result, call_id
176+
177+
def _build_messages(
178+
self,
179+
system_prompt: str,
180+
query: str,
181+
retrieved_data: List[Dict],
182+
) -> List[Dict]:
183+
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
184+
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
185+
messages_combine = [{"role": "system", "content": p_chat_combine}]
186+
187+
for i in self.chat_history:
188+
if "prompt" in i and "response" in i:
189+
messages_combine.append({"role": "user", "content": i["prompt"]})
190+
messages_combine.append({"role": "assistant", "content": i["response"]})
191+
if "tool_calls" in i:
192+
for tool_call in i["tool_calls"]:
193+
call_id = tool_call.get("call_id") or str(uuid.uuid4())
194+
195+
function_call_dict = {
196+
"function_call": {
197+
"name": tool_call.get("action_name"),
198+
"args": tool_call.get("arguments"),
199+
"call_id": call_id,
200+
}
201+
}
202+
function_response_dict = {
203+
"function_response": {
204+
"name": tool_call.get("action_name"),
205+
"response": {"result": tool_call.get("result")},
206+
"call_id": call_id,
207+
}
208+
}
209+
210+
messages_combine.append(
211+
{"role": "assistant", "content": [function_call_dict]}
212+
)
213+
messages_combine.append(
214+
{"role": "tool", "content": [function_response_dict]}
215+
)
216+
messages_combine.append({"role": "user", "content": query})
217+
return messages_combine
218+
219+
def _retriever_search(
220+
self,
221+
retriever: BaseRetriever,
222+
query: str,
223+
log_context: Optional[LogContext] = None,
224+
) -> List[Dict]:
225+
retrieved_data = retriever.search(query)
226+
if log_context:
227+
data = build_stack_data(retriever, exclude_attributes=["llm"])
228+
log_context.stacks.append({"component": "retriever", "data": data})
229+
return retrieved_data
230+
231+
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
232+
resp = self.llm.gen_stream(
233+
model=self.gpt_model, messages=messages, tools=self.tools
234+
)
235+
if log_context:
236+
data = build_stack_data(self.llm)
237+
log_context.stacks.append({"component": "llm", "data": data})
238+
return resp
239+
240+
def _llm_handler(
241+
self,
242+
resp,
243+
tools_dict: Dict,
244+
messages: List[Dict],
245+
log_context: Optional[LogContext] = None,
246+
):
247+
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
248+
if log_context:
249+
data = build_stack_data(self.llm_handler)
250+
log_context.stacks.append({"component": "llm_handler", "data": data})
251+
return resp
Lines changed: 8 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,26 @@
1-
import uuid
21
from typing import Dict, Generator
32

43
from application.agents.base import BaseAgent
5-
from application.logging import build_stack_data, log_activity, LogContext
4+
from application.logging import LogContext
65

76
from application.retriever.base import BaseRetriever
87
import logging
98
logger = logging.getLogger(__name__)
109

1110
class ClassicAgent(BaseAgent):
12-
def __init__(
13-
self,
14-
endpoint,
15-
llm_name,
16-
gpt_model,
17-
api_key,
18-
user_api_key=None,
19-
prompt="",
20-
chat_history=None,
21-
decoded_token=None,
22-
attachments=None,
23-
):
24-
super().__init__(
25-
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token, attachments
26-
)
27-
self.user = decoded_token.get("sub")
28-
self.prompt = prompt
29-
self.chat_history = chat_history if chat_history is not None else []
30-
31-
@log_activity()
32-
def gen(
33-
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
34-
) -> Generator[Dict, None, None]:
35-
yield from self._gen_inner(query, retriever, log_context)
36-
3711
def _gen_inner(
3812
self, query: str, retriever: BaseRetriever, log_context: LogContext
3913
) -> Generator[Dict, None, None]:
4014
retrieved_data = self._retriever_search(retriever, query, log_context)
4115

42-
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
43-
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
44-
messages_combine = [{"role": "system", "content": p_chat_combine}]
45-
46-
if len(self.chat_history) > 0:
47-
for i in self.chat_history:
48-
if "prompt" in i and "response" in i:
49-
messages_combine.append({"role": "user", "content": i["prompt"]})
50-
messages_combine.append(
51-
{"role": "assistant", "content": i["response"]}
52-
)
53-
if "tool_calls" in i:
54-
for tool_call in i["tool_calls"]:
55-
call_id = tool_call.get("call_id")
56-
if call_id is None or call_id == "None":
57-
call_id = str(uuid.uuid4())
58-
59-
function_call_dict = {
60-
"function_call": {
61-
"name": tool_call.get("action_name"),
62-
"args": tool_call.get("arguments"),
63-
"call_id": call_id,
64-
}
65-
}
66-
function_response_dict = {
67-
"function_response": {
68-
"name": tool_call.get("action_name"),
69-
"response": {"result": tool_call.get("result")},
70-
"call_id": call_id,
71-
}
72-
}
73-
74-
messages_combine.append(
75-
{"role": "assistant", "content": [function_call_dict]}
76-
)
77-
messages_combine.append(
78-
{"role": "tool", "content": [function_response_dict]}
79-
)
80-
messages_combine.append({"role": "user", "content": query})
81-
8216
tools_dict = self._get_user_tools(self.user)
8317
self._prepare_tools(tools_dict)
8418

85-
resp = self._llm_gen(messages_combine, log_context)
19+
messages = self._build_messages(self.prompt, query, retrieved_data)
20+
21+
resp = self._llm_gen(messages, log_context)
22+
23+
attachments = self.attachments
8624

8725
if isinstance(resp, str):
8826
yield {"answer": resp}
@@ -95,7 +33,7 @@ def _gen_inner(
9533
yield {"answer": resp.message.content}
9634
return
9735

98-
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context, self.attachments)
36+
resp = self._llm_handler(resp, tools_dict, messages, log_context,attachments)
9937

10038
if isinstance(resp, str):
10139
yield {"answer": resp}
@@ -107,37 +45,11 @@ def _gen_inner(
10745
yield {"answer": resp.message.content}
10846
else:
10947
completion = self.llm.gen_stream(
110-
model=self.gpt_model, messages=messages_combine, tools=self.tools
48+
model=self.gpt_model, messages=messages, tools=self.tools
11149
)
11250
for line in completion:
11351
if isinstance(line, str):
11452
yield {"answer": line}
11553

11654
yield {"sources": retrieved_data}
11755
yield {"tool_calls": self.tool_calls.copy()}
118-
119-
def _retriever_search(self, retriever, query, log_context):
120-
retrieved_data = retriever.search(query)
121-
if log_context:
122-
data = build_stack_data(retriever, exclude_attributes=["llm"])
123-
log_context.stacks.append({"component": "retriever", "data": data})
124-
return retrieved_data
125-
126-
def _llm_gen(self, messages_combine, log_context):
127-
resp = self.llm.gen_stream(
128-
model=self.gpt_model, messages=messages_combine, tools=self.tools
129-
)
130-
if log_context:
131-
data = build_stack_data(self.llm)
132-
log_context.stacks.append({"component": "llm", "data": data})
133-
return resp
134-
135-
def _llm_handler(self, resp, tools_dict, messages_combine, log_context, attachments=None):
136-
logger.info(f"Handling LLM response with {len(attachments) if attachments else 0} attachments")
137-
resp = self.llm_handler.handle_response(
138-
self, resp, tools_dict, messages_combine, attachments=attachments
139-
)
140-
if log_context:
141-
data = build_stack_data(self.llm_handler)
142-
log_context.stacks.append({"component": "llm_handler", "data": data})
143-
return resp

0 commit comments

Comments
 (0)