Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions application/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
from bson.objectid import ObjectId


class BaseAgent(ABC):
Expand All @@ -23,7 +24,7 @@ def __init__(
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None,
attachments: Optional[List[Dict]]=None,
attachments: Optional[List[Dict]] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
Expand Down Expand Up @@ -58,6 +59,27 @@ def _gen_inner(
) -> Generator[Dict, None, None]:
pass

def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
agents_collection = db["agents"]
tools_collection = db["user_tools"]

agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []

tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}

return tools_by_id

def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
Expand Down Expand Up @@ -243,9 +265,11 @@ def _llm_handler(
tools_dict: Dict,
messages: List[Dict],
log_context: Optional[LogContext] = None,
attachments: Optional[List[Dict]] = None
attachments: Optional[List[Dict]] = None,
):
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments)
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages, attachments
)
if log_context:
data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data})
Expand Down
12 changes: 8 additions & 4 deletions application/agents/classic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,25 @@

from application.retriever.base import BaseRetriever
import logging

logger = logging.getLogger(__name__)


class ClassicAgent(BaseAgent):
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)

tools_dict = self._get_user_tools(self.user)
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)

messages = self._build_messages(self.prompt, query, retrieved_data)

resp = self._llm_gen(messages, log_context)

attachments = self.attachments

if isinstance(resp, str):
Expand All @@ -33,7 +37,7 @@ def _gen_inner(
yield {"answer": resp.message.content}
return

resp = self._llm_handler(resp, tools_dict, messages, log_context,attachments)
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)

if isinstance(resp, str):
yield {"answer": resp}
Expand Down
5 changes: 4 additions & 1 deletion application/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def _gen_inner(
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)

tools_dict = self._get_user_tools(self.user)
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)

docs_together = "\n".join([doc["text"] for doc in retrieved_data])
Expand Down
99 changes: 72 additions & 27 deletions application/api/answer/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
agents_collection = db["agents"]
user_logs_collection = db["user_logs"]
attachments_collection = db["attachments"]

Expand Down Expand Up @@ -86,19 +86,42 @@ def run_async_chain(chain, question, chat_history):
return result


def get_agent_key(agent_id, user_id):
if not agent_id:
return None

try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)

if agent.get("user") == user_id:
agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
)
return str(agent["key"])

raise Exception("Unauthorized access to the agent", 403)

except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}")
raise


def get_data_from_api_key(api_key):
data = api_key_collection.find_one({"key": api_key})
# # Raise custom exception if the API key is not found
if data is None:
raise Exception("Invalid API Key, please generate new key", 401)
data = agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)

if "source" in data and isinstance(data["source"], DBRef):
source_doc = db.dereference(data["source"])
source = data.get("source")
if isinstance(source, DBRef):
source_doc = db.dereference(source)
data["source"] = str(source_doc["_id"])
if "retriever" in source_doc:
data["retriever"] = source_doc["retriever"]
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
else:
data["source"] = {}

return data


Expand Down Expand Up @@ -128,7 +151,8 @@ def save_conversation(
llm,
decoded_token,
index=None,
api_key=None
api_key=None,
agent_id=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
Expand Down Expand Up @@ -202,7 +226,9 @@ def save_conversation(
],
}
if api_key:
api_key_doc = api_key_collection.find_one({"key": api_key})
if agent_id:
conversation_data["agent_id"] = agent_id
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
conversation_id = conversations_collection.insert_one(
Expand Down Expand Up @@ -234,14 +260,17 @@ def complete_stream(
index=None,
should_save_conversation=True,
attachments=None,
agent_id=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
attachment_ids = []

if attachments:
attachment_ids = [attachment["id"] for attachment in attachments]
logger.info(f"Processing request with {len(attachments)} attachments: {attachment_ids}")
logger.info(
f"Processing request with {len(attachments)} attachments: {attachment_ids}"
)

answer = agent.gen(query=question, retriever=retriever)

Expand Down Expand Up @@ -294,7 +323,8 @@ def complete_stream(
llm,
decoded_token,
index,
api_key=user_api_key
api_key=user_api_key,
agent_id=agent_id,
)
else:
conversation_id = None
Expand Down Expand Up @@ -366,7 +396,9 @@ class Stream(Resource):
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False, default=True, description="Whether to save the conversation"
required=False,
default=True,
description="Whether to save the conversation",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
Expand Down Expand Up @@ -400,6 +432,14 @@ def post(self):
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
agent_key = get_agent_key(agent_id, request.decoded_token.get("sub"))

if agent_key:
data.update({"api_key": agent_key})
else:
agent_id = None

if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
Expand All @@ -408,6 +448,7 @@ def post(self):
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}

elif "active_docs" in data:
Expand All @@ -423,8 +464,10 @@ def post(self):

if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)

attachments = get_attachments_content(attachment_ids, decoded_token.get("sub"))

attachments = get_attachments_content(
attachment_ids, decoded_token.get("sub")
)

logger.info(
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
Expand All @@ -436,7 +479,7 @@ def post(self):
chunks = 0

agent = AgentCreator.create_agent(
settings.AGENT_NAME,
agent_type,
endpoint="stream",
llm_name=settings.LLM_NAME,
gpt_model=gpt_model,
Expand Down Expand Up @@ -471,6 +514,7 @@ def post(self):
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
agent_id=agent_id,
),
mimetype="text/event-stream",
)
Expand Down Expand Up @@ -552,6 +596,7 @@ def post(self):
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_type = settings.AGENT_NAME

if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
Expand All @@ -560,6 +605,7 @@ def post(self):
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}

elif "active_docs" in data:
Expand All @@ -584,7 +630,7 @@ def post(self):
)

agent = AgentCreator.create_agent(
settings.AGENT_NAME,
agent_type,
endpoint="api/answer",
llm_name=settings.LLM_NAME,
gpt_model=gpt_model,
Expand Down Expand Up @@ -815,28 +861,27 @@ def post(self):
def get_attachments_content(attachment_ids, user):
"""
Retrieve content from attachment documents based on their IDs.

Args:
attachment_ids (list): List of attachment document IDs
user (str): User identifier to verify ownership

Returns:
list: List of dictionaries containing attachment content and metadata
"""
if not attachment_ids:
return []

attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = attachments_collection.find_one({
"_id": ObjectId(attachment_id),
"user": user
})

attachment_doc = attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user}
)

if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(f"Error retrieving attachment {attachment_id}: {e}")

return attachments
Loading