|
1 | | -from typing import Dict, Generator |
| 1 | +import uuid |
| 2 | +from abc import ABC, abstractmethod |
| 3 | +from typing import Dict, Generator, List, Optional |
2 | 4 |
|
3 | 5 | from application.agents.llm_handler import get_llm_handler |
4 | 6 | from application.agents.tools.tool_action_parser import ToolActionParser |
5 | 7 | from application.agents.tools.tool_manager import ToolManager |
6 | 8 |
|
7 | 9 | from application.core.mongo_db import MongoDB |
8 | 10 | from application.llm.llm_creator import LLMCreator |
| 11 | +from application.logging import build_stack_data, log_activity, LogContext |
| 12 | +from application.retriever.base import BaseRetriever |
9 | 13 |
|
10 | 14 |
|
11 | | -class BaseAgent: |
| 15 | +class BaseAgent(ABC): |
12 | 16 | def __init__( |
13 | 17 | self, |
14 | | - endpoint, |
15 | | - llm_name, |
16 | | - gpt_model, |
17 | | - api_key, |
18 | | - user_api_key=None, |
19 | | - decoded_token=None, |
| 18 | + endpoint: str, |
| 19 | + llm_name: str, |
| 20 | + gpt_model: str, |
| 21 | + api_key: str, |
| 22 | + user_api_key: Optional[str] = None, |
| 23 | + prompt: str = "", |
| 24 | + chat_history: Optional[List[Dict]] = None, |
| 25 | + decoded_token: Optional[Dict] = None, |
20 | 26 | ): |
21 | 27 | self.endpoint = endpoint |
| 28 | + self.llm_name = llm_name |
| 29 | + self.gpt_model = gpt_model |
| 30 | + self.api_key = api_key |
| 31 | + self.user_api_key = user_api_key |
| 32 | + self.prompt = prompt |
| 33 | + self.decoded_token = decoded_token or {} |
| 34 | + self.user: str = decoded_token.get("sub") |
| 35 | + self.tool_config: Dict = {} |
| 36 | + self.tools: List[Dict] = [] |
| 37 | + self.tool_calls: List[Dict] = [] |
| 38 | + self.chat_history: List[Dict] = chat_history if chat_history is not None else [] |
22 | 39 | self.llm = LLMCreator.create_llm( |
23 | 40 | llm_name, |
24 | 41 | api_key=api_key, |
25 | 42 | user_api_key=user_api_key, |
26 | 43 | decoded_token=decoded_token, |
27 | 44 | ) |
28 | 45 | self.llm_handler = get_llm_handler(llm_name) |
29 | | - self.gpt_model = gpt_model |
30 | | - self.tools = [] |
31 | | - self.tool_config = {} |
32 | | - self.tool_calls = [] |
33 | 46 |
|
34 | | - def gen(self, *args, **kwargs) -> Generator[Dict, None, None]: |
35 | | - raise NotImplementedError('Method "gen" must be implemented in the child class') |
| 47 | + @log_activity() |
| 48 | + def gen( |
| 49 | + self, query: str, retriever: BaseRetriever, log_context: LogContext = None |
| 50 | + ) -> Generator[Dict, None, None]: |
| 51 | + yield from self._gen_inner(query, retriever, log_context) |
| 52 | + |
| 53 | + @abstractmethod |
| 54 | + def _gen_inner( |
| 55 | + self, query: str, retriever: BaseRetriever, log_context: LogContext |
| 56 | + ) -> Generator[Dict, None, None]: |
| 57 | + pass |
36 | 58 |
|
37 | 59 | def _get_user_tools(self, user="local"): |
38 | 60 | mongo = MongoDB.get_client() |
@@ -109,14 +131,12 @@ def _execute_tool_action(self, tools_dict, call): |
109 | 131 | for param, details in action_data[param_type]["properties"].items(): |
110 | 132 | if param not in call_args and "value" in details: |
111 | 133 | target_dict[param] = details["value"] |
112 | | - |
113 | 134 | for param, value in call_args.items(): |
114 | 135 | for param_type, target_dict in param_types.items(): |
115 | 136 | if param_type in action_data and param in action_data[param_type].get( |
116 | 137 | "properties", {} |
117 | 138 | ): |
118 | 139 | target_dict[param] = value |
119 | | - |
120 | 140 | tm = ToolManager(config={}) |
121 | 141 | tool = tm.load_tool( |
122 | 142 | tool_data["name"], |
@@ -151,3 +171,79 @@ def _execute_tool_action(self, tools_dict, call): |
151 | 171 | self.tool_calls.append(tool_call_data) |
152 | 172 |
|
153 | 173 | return result, call_id |
| 174 | + |
| 175 | + def _build_messages( |
| 176 | + self, |
| 177 | + system_prompt: str, |
| 178 | + query: str, |
| 179 | + retrieved_data: List[Dict], |
| 180 | + ) -> List[Dict]: |
| 181 | + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) |
| 182 | + p_chat_combine = system_prompt.replace("{summaries}", docs_together) |
| 183 | + messages_combine = [{"role": "system", "content": p_chat_combine}] |
| 184 | + |
| 185 | + for i in self.chat_history: |
| 186 | + if "prompt" in i and "response" in i: |
| 187 | + messages_combine.append({"role": "user", "content": i["prompt"]}) |
| 188 | + messages_combine.append({"role": "assistant", "content": i["response"]}) |
| 189 | + if "tool_calls" in i: |
| 190 | + for tool_call in i["tool_calls"]: |
| 191 | + call_id = tool_call.get("call_id") or str(uuid.uuid4()) |
| 192 | + |
| 193 | + function_call_dict = { |
| 194 | + "function_call": { |
| 195 | + "name": tool_call.get("action_name"), |
| 196 | + "args": tool_call.get("arguments"), |
| 197 | + "call_id": call_id, |
| 198 | + } |
| 199 | + } |
| 200 | + function_response_dict = { |
| 201 | + "function_response": { |
| 202 | + "name": tool_call.get("action_name"), |
| 203 | + "response": {"result": tool_call.get("result")}, |
| 204 | + "call_id": call_id, |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + messages_combine.append( |
| 209 | + {"role": "assistant", "content": [function_call_dict]} |
| 210 | + ) |
| 211 | + messages_combine.append( |
| 212 | + {"role": "tool", "content": [function_response_dict]} |
| 213 | + ) |
| 214 | + messages_combine.append({"role": "user", "content": query}) |
| 215 | + return messages_combine |
| 216 | + |
| 217 | + def _retriever_search( |
| 218 | + self, |
| 219 | + retriever: BaseRetriever, |
| 220 | + query: str, |
| 221 | + log_context: Optional[LogContext] = None, |
| 222 | + ) -> List[Dict]: |
| 223 | + retrieved_data = retriever.search(query) |
| 224 | + if log_context: |
| 225 | + data = build_stack_data(retriever, exclude_attributes=["llm"]) |
| 226 | + log_context.stacks.append({"component": "retriever", "data": data}) |
| 227 | + return retrieved_data |
| 228 | + |
| 229 | + def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None): |
| 230 | + resp = self.llm.gen_stream( |
| 231 | + model=self.gpt_model, messages=messages, tools=self.tools |
| 232 | + ) |
| 233 | + if log_context: |
| 234 | + data = build_stack_data(self.llm) |
| 235 | + log_context.stacks.append({"component": "llm", "data": data}) |
| 236 | + return resp |
| 237 | + |
| 238 | + def _llm_handler( |
| 239 | + self, |
| 240 | + resp, |
| 241 | + tools_dict: Dict, |
| 242 | + messages: List[Dict], |
| 243 | + log_context: Optional[LogContext] = None, |
| 244 | + ): |
| 245 | + resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) |
| 246 | + if log_context: |
| 247 | + data = build_stack_data(self.llm_handler) |
| 248 | + log_context.stacks.append({"component": "llm_handler", "data": data}) |
| 249 | + return resp |
0 commit comments