Skip to content

Commit

Permalink
Add routes tests (#1971)
Browse files Browse the repository at this point in the history
* Add routes tests

* Recomment rate limit on users.create

* Add response models for completions and embeddings
  • Loading branch information
NolanTrem authored Feb 12, 2025
1 parent e44a257 commit 01b6627
Show file tree
Hide file tree
Showing 13 changed files with 234 additions and 29 deletions.
4 changes: 4 additions & 0 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
WrappedAgentResponse,
WrappedCompletionResponse,
WrappedDocumentSearchResponse,
WrappedEmbeddingResponse,
WrappedLLMChatCompletion,
WrappedRAGResponse,
WrappedSearchResponse,
WrappedVectorSearchResponse,
Expand Down Expand Up @@ -167,4 +169,6 @@
"WrappedCompletionResponse",
"WrappedRAGResponse",
"WrappedAgentResponse",
"WrappedLLMChatCompletion",
"WrappedEmbeddingResponse",
]
1 change: 1 addition & 0 deletions py/core/main/api/v3/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def wrapper(*args, **kwargs):
},
) from e

wrapper._is_base_endpoint = True
return wrapper

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ async def update_graph(
None, description="An optional description of the graph"
),
auth_user=Depends(self.providers.auth.auth_wrapper()),
):
) -> WrappedGraphResponse:
"""
Update an existing graphs's configuration.
Expand Down Expand Up @@ -1748,7 +1748,7 @@ async def delete_community(
description="The ID of the community to delete.",
),
auth_user=Depends(self.providers.auth.auth_wrapper()),
):
) -> WrappedBooleanResponse:
if (
not auth_user.is_superuser
and collection_id not in auth_user.graph_ids
Expand Down
3 changes: 2 additions & 1 deletion py/core/main/api/v3/logs_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import Depends, WebSocket
from fastapi.requests import Request
from fastapi.templating import Jinja2Templates
from starlette.templating import _TemplateResponse

from ...abstractions import R2RProviders, R2RServices
from ...config import R2RConfig
Expand Down Expand Up @@ -102,7 +103,7 @@ async def stream_logs(
"/logs/viewer",
dependencies=[Depends(self.rate_limit_dependency)],
)
async def get_log_viewer(request: Request):
async def get_log_viewer(request: Request) -> _TemplateResponse:
return self.templates.TemplateResponse(
"log_viewer.html", {"request": request}
)
12 changes: 6 additions & 6 deletions py/core/main/api/v3/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from core.base.api.models import (
WrappedAgentResponse,
WrappedCompletionResponse,
WrappedEmbeddingResponse,
WrappedLLMChatCompletion,
WrappedRAGResponse,
WrappedSearchResponse,
)
Expand All @@ -42,11 +44,11 @@ def merge_search_settings(
return SearchSettings(**base_dict)


class RetrievalRouterV3(BaseRouterV3):
class RetrievalRouter(BaseRouterV3):
def __init__(
self, providers: R2RProviders, services: R2RServices, config: R2RConfig
):
logging.info("Initializing RetrievalRouterV3")
logging.info("Initializing RetrievalRouter")
super().__init__(providers, services, config)

def _register_workflows(self):
Expand Down Expand Up @@ -996,8 +998,7 @@ async def completion(
),
auth_user=Depends(self.providers.auth.auth_wrapper()),
response_model=WrappedCompletionResponse,
):
# FIXME: Needs a proper return type
) -> WrappedLLMChatCompletion:
"""
Generate completions for a list of messages.
Expand Down Expand Up @@ -1075,8 +1076,7 @@ async def embedding(
description="Text to generate embeddings for",
),
auth_user=Depends(self.providers.auth.auth_wrapper()),
):
# FIXME: Needs a proper return type
) -> WrappedEmbeddingResponse:
"""
Generate embeddings for the provided text using the specified model.
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/api/v3/users_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ async def change_password(
current_password: str = Body(..., description="Current password"),
new_password: str = Body(..., description="New password"),
auth_user=Depends(self.providers.auth.auth_wrapper()),
) -> GenericMessageResponse:
) -> WrappedGenericMessageResponse:
"""Change the authenticated user's password."""
result = await self.services.auth.change_password(
auth_user, current_password, new_password
Expand Down
8 changes: 4 additions & 4 deletions py/core/main/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .api.v3.indices_router import IndicesRouter
from .api.v3.logs_router import LogsRouter
from .api.v3.prompts_router import PromptsRouter
from .api.v3.retrieval_router import RetrievalRouterV3
from .api.v3.retrieval_router import RetrievalRouter
from .api.v3.system_router import SystemRouter
from .api.v3.users_router import UsersRouter
from .config import R2RConfig
Expand All @@ -41,7 +41,7 @@ def __init__(
indices_router: IndicesRouter,
logs_router: LogsRouter,
prompts_router: PromptsRouter,
retrieval_router_v3: RetrievalRouterV3,
retrieval_router: RetrievalRouter,
system_router: SystemRouter,
users_router: UsersRouter,
):
Expand All @@ -58,7 +58,7 @@ def __init__(
self.logs_router = logs_router
self.orchestration_provider = orchestration_provider
self.prompts_router = prompts_router
self.retrieval_router_v3 = retrieval_router_v3
self.retrieval_router = retrieval_router
self.system_router = system_router
self.users_router = users_router

Expand Down Expand Up @@ -86,7 +86,7 @@ def _setup_routes(self):
self.app.include_router(self.indices_router, prefix="/v3")
self.app.include_router(self.logs_router, prefix="/v3")
self.app.include_router(self.prompts_router, prefix="/v3")
self.app.include_router(self.retrieval_router_v3, prefix="/v3")
self.app.include_router(self.retrieval_router, prefix="/v3")
self.app.include_router(self.system_router, prefix="/v3")
self.app.include_router(self.users_router, prefix="/v3")

Expand Down
4 changes: 2 additions & 2 deletions py/core/main/assembly/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..api.v3.indices_router import IndicesRouter
from ..api.v3.logs_router import LogsRouter
from ..api.v3.prompts_router import PromptsRouter
from ..api.v3.retrieval_router import RetrievalRouterV3
from ..api.v3.retrieval_router import RetrievalRouter
from ..api.v3.system_router import SystemRouter
from ..api.v3.users_router import UsersRouter
from ..app import R2RApp
Expand Down Expand Up @@ -90,7 +90,7 @@ async def build(self, *args, **kwargs) -> R2RApp:
services=services,
config=self.config,
).get_router(),
"retrieval_router_v3": RetrievalRouterV3(
"retrieval_router": RetrievalRouter(
providers=providers,
services=services,
config=self.config,
Expand Down
16 changes: 10 additions & 6 deletions py/sdk/asnyc_methods/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from shared.api.models import (
WrappedAgentResponse,
WrappedEmbeddingResponse,
WrappedLLMChatCompletion,
WrappedRAGResponse,
WrappedSearchResponse,
)
Expand Down Expand Up @@ -65,8 +67,7 @@ async def completion(
self,
messages: list[dict | Message],
generation_config: Optional[dict | GenerationConfig] = None,
):
# FIXME: Needs a proper return type
) -> WrappedLLMChatCompletion:
cast_messages: list[Message] = [
Message(**msg) if isinstance(msg, dict) else msg
for msg in messages
Expand All @@ -79,29 +80,32 @@ async def completion(
"messages": [msg.model_dump() for msg in cast_messages],
"generation_config": generation_config,
}
return await self.client._make_request(
response_dict = await self.client._make_request(
"POST",
"retrieval/completion",
json=data,
version="v3",
)

return WrappedLLMChatCompletion(**response_dict)

async def embedding(
self,
text: str,
):
# FIXME: Needs a proper return type
) -> WrappedEmbeddingResponse:
data: dict[str, Any] = {
"text": text,
}

return await self.client._make_request(
response_dict = await self.client._make_request(
"POST",
"retrieval/embedding",
data=data,
version="v3",
)

return WrappedEmbeddingResponse(**response_dict)

async def rag(
self,
query: str,
Expand Down
16 changes: 10 additions & 6 deletions py/sdk/sync_methods/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from shared.api.models import (
WrappedAgentResponse,
WrappedEmbeddingResponse,
WrappedLLMChatCompletion,
WrappedRAGResponse,
WrappedSearchResponse,
)
Expand Down Expand Up @@ -66,8 +68,7 @@ def completion(
self,
messages: list[dict | Message],
generation_config: Optional[dict | GenerationConfig] = None,
):
# FIXME: Needs a proper return type
) -> WrappedLLMChatCompletion:
cast_messages: list[Message] = [
Message(**msg) if isinstance(msg, dict) else msg
for msg in messages
Expand All @@ -80,29 +81,32 @@ def completion(
"messages": [msg.model_dump() for msg in cast_messages],
"generation_config": generation_config,
}
return self.client._make_request(
response_dict = self.client._make_request(
"POST",
"retrieval/completion",
json=data,
version="v3",
)

return WrappedLLMChatCompletion(**response_dict)

def embedding(
self,
text: str,
):
# FIXME: Needs a proper return type
) -> WrappedEmbeddingResponse:
data: dict[str, Any] = {
"text": text,
}

return self.client._make_request(
response_dict = self.client._make_request(
"POST",
"retrieval/embedding",
data=data,
version="v3",
)

return WrappedEmbeddingResponse(**response_dict)

def rag(
self,
query: str,
Expand Down
4 changes: 4 additions & 0 deletions py/shared/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
RAGResponse,
WrappedAgentResponse,
WrappedDocumentSearchResponse,
WrappedEmbeddingResponse,
WrappedLLMChatCompletion,
WrappedRAGResponse,
WrappedSearchResponse,
WrappedVectorSearchResponse,
Expand Down Expand Up @@ -150,4 +152,6 @@
"WrappedDocumentSearchResponse",
"WrappedVectorSearchResponse",
"WrappedAgentResponse",
"WrappedLLMChatCompletion",
"WrappedEmbeddingResponse",
]
4 changes: 3 additions & 1 deletion py/shared/api/models/retrieval/responses.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Optional

from deprecated import deprecated
from pydantic import BaseModel, Field

from shared.abstractions import (
AggregateSearchResult,
ChunkSearchResult,
LLMChatCompletion,
Message,
)
from shared.abstractions.llm import LLMChatCompletion
Expand Down Expand Up @@ -227,3 +227,5 @@ class DocumentSearchResult(BaseModel):
WrappedDocumentSearchResponse = R2RResults[list[DocumentResponse]]
WrappedRAGResponse = R2RResults[RAGResponse]
WrappedAgentResponse = R2RResults[AgentResponse]
WrappedLLMChatCompletion = R2RResults[LLMChatCompletion]
WrappedEmbeddingResponse = R2RResults[list[float]]
Loading

0 comments on commit 01b6627

Please sign in to comment.