Skip to content

Commit 0d74b80

Browse files
committed
feat(api): add a ChatRouter for processing multiple messages
1 parent 4bc57e5 commit 0d74b80

File tree

6 files changed

+137
-0
lines changed

6 files changed

+137
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
88
"cryptography>=44.0.1",
9+
"fastapi>=0.115.8",
910
"google-generativeai>=0.8.4",
1011
"httpx>=0.28.1",
1112
"openrouter>=1.0",

src/flare_ai_rag/api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .routes.chat import ChatMessage, ChatRouter, router
2+
3+
__all__ = ["ChatMessage", "ChatRouter", "router"]

src/flare_ai_rag/api/middleware/__init__.py

Whitespace-only changes.

src/flare_ai_rag/api/routes/__init__.py

Whitespace-only changes.
src/flare_ai_rag/api/routes/chat.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import structlog
2+
from fastapi import APIRouter, HTTPException
3+
from pydantic import BaseModel, Field
4+
5+
from flare_ai_rag.responder import GeminiResponder
6+
from flare_ai_rag.retriever import QdrantRetriever
7+
from flare_ai_rag.router import GeminiRouter
8+
9+
logger = structlog.get_logger(__name__)
10+
router = APIRouter()
11+
12+
13+
class ChatMessage(BaseModel):
    """
    Request body accepted by the chat endpoint.

    Attributes:
        message (str): Text of the chat message. Required, and must contain
            at least one character.
    """

    # No default is supplied, so the field is required.
    message: str = Field(min_length=1)
22+
23+
24+
class ChatRouter:
    """
    Chat router that runs incoming messages through the RAG pipeline.

    It wraps the query classification, document retrieval, and response
    generation components and exposes them through a single POST endpoint
    registered on the supplied FastAPI router.
    """

    # Canned replies for classifications that need no retrieval step.
    _STATIC_RESPONSES = {
        "CLARIFY": "Please provide additional context.",
        "REJECT": "The query is out of scope.",
    }
    # Number of documents requested from the retriever for each query.
    _TOP_K = 5

    def __init__(
        self,
        router: APIRouter,
        query_router: GeminiRouter,
        retriever: QdrantRetriever,
        responder: GeminiResponder,
    ) -> None:
        """
        Initialize the ChatRouter and register its endpoint.

        Args:
            router (APIRouter): FastAPI router to attach endpoints to.
            query_router: Component that classifies the query.
            retriever: Component that retrieves relevant documents.
            responder: Component that generates a response.
        """
        self._router = router
        self.query_router = query_router
        self.retriever = retriever
        self.responder = responder
        self.logger = logger.bind(router="chat")
        self._setup_routes()

    def _process_message(self, text: str) -> dict[str, str]:
        """
        Classify ``text`` and build the reply for that classification.

        Args:
            text: The chat message content.

        Returns:
            A dict holding the ``classification`` label and the ``response``
            text.

        Raises:
            ValueError: If the query router returns a label other than
                ``ANSWER``, ``CLARIFY`` or ``REJECT``.
        """
        # Classify the query.
        classification = self.query_router.route_query(text)
        self.logger.info("Query classified", classification=classification)

        if classification == "ANSWER":
            # Retrieve relevant documents.
            retrieved_docs = self.retriever.semantic_search(
                text, top_k=self._TOP_K
            )
            self.logger.info("Documents retrieved")

            # Generate the final answer using retrieved context.
            answer = self.responder.generate_response(text, retrieved_docs)
            self.logger.info("Response generated", answer=answer)
            return {"classification": classification, "response": answer}

        if classification in self._STATIC_RESPONSES:
            return {
                "classification": classification,
                "response": self._STATIC_RESPONSES[classification],
            }

        # An unrecognised label used to fall through and make the endpoint
        # return None (HTTP 200 with a null body). Fail loudly instead so the
        # caller's error handling reports it.
        msg = f"Unexpected query classification: {classification!r}"
        raise ValueError(msg)

    def _setup_routes(self) -> None:
        """
        Set up FastAPI routes for the chat endpoint.
        """

        @self._router.post("/")
        async def chat(message: ChatMessage) -> dict[str, str]:  # pyright: ignore [reportUnusedFunction]
            """
            Process a chat message through the RAG pipeline.

            Returns a response containing the query classification and the
            answer. Any failure in the pipeline, including an unrecognised
            classification, is logged and reported as an HTTP 500.
            """
            try:
                self.logger.debug("Received chat message", message=message.message)
                # NOTE(review): the pipeline components are called
                # synchronously from this coroutine, so they run on the event
                # loop; confirm whether they should be offloaded to a thread.
                return self._process_message(message.message)
            except Exception as e:
                self.logger.exception("Chat processing failed", error=str(e))
                # NOTE(review): the raw exception text is sent to the client;
                # confirm this cannot expose internal details.
                raise HTTPException(status_code=500, detail=str(e)) from e

    @property
    def router(self) -> APIRouter:
        """Return the underlying FastAPI router with registered endpoints."""
        return self._router

uv.lock

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)