Skip to content

Commit

Permalink
Fix agent endpoint (#1981)
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem authored Feb 17, 2025
1 parent 4575485 commit 6601162
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion py/core/main/api/v3/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ async def stream_generator():
return response

@self.router.post(
"/retrieval/rag_agent",
"/retrieval/agent",
dependencies=[Depends(self.rate_limit_dependency)],
summary="RAG-powered Conversational Agent",
openapi_extra={
Expand Down
19 changes: 12 additions & 7 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def convert_nonserializable_objects(obj):
new_obj = {}
for key, value in obj.items():
# Convert key to string if it is a UUID or not already a string.
new_key = str(key) if not isinstance(key, str) else key
new_key = key if isinstance(key, str) else str(key)
new_obj[new_key] = convert_nonserializable_objects(value)
return new_obj
elif isinstance(obj, list):
Expand Down Expand Up @@ -104,10 +104,8 @@ def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:

def tokens_count_for_message(message, encoding):
"""Return the number of tokens used by a single message."""
tokens_per_message = 3
num_tokens = 3

num_tokens = 0
num_tokens += tokens_per_message
if message.get("function_call"):
num_tokens += len(encoding.encode(message["function_call"]["name"]))
num_tokens += len(
Expand Down Expand Up @@ -1136,8 +1134,8 @@ def _parse_user_and_collection_filters(
filters: dict[str, Any],
):
### TODO - Come up with smarter way to extract owner / collection ids for non-admin
filter_starts_with_and = filters.get("$and", None)
filter_starts_with_or = filters.get("$or", None)
filter_starts_with_and = filters.get("$and")
filter_starts_with_or = filters.get("$or")
if filter_starts_with_and:
try:
filter_starts_with_and_then_or = filter_starts_with_and[0][
Expand Down Expand Up @@ -1262,8 +1260,15 @@ async def _build_aware_system_instruction(
else self.config.agent.agent_static_prompt
)

# TODO: This should just be enforced in the config
if model is None:
raise R2RException(
status_code=400,
message="Model not provided for system instruction",
)

if ("gemini" in model or "claude" in model) and reasoning_agent:
prompt_name = prompt_name + "_prompted_reasoning"
prompt_name = f"{prompt_name}_prompted_reasoning"

if use_system_context or reasoning_agent:
doc_context_str = await self._build_documents_context(
Expand Down

0 comments on commit 6601162

Please sign in to comment.