22from fastapi import APIRouter , HTTPException
33from pydantic import BaseModel , Field
44
5+ from flare_ai_rag .ai import GeminiProvider
6+ from flare_ai_rag .attestation import Vtpm , VtpmAttestationError
7+ from flare_ai_rag .prompts import PromptService , SemanticRouterResponse
58from flare_ai_rag .responder import GeminiResponder
69from flare_ai_rag .retriever import QdrantRetriever
710from flare_ai_rag .router import GeminiRouter
@@ -29,26 +32,37 @@ class ChatRouter:
2932 generation components to handle a conversation in a single endpoint.
3033 """
3134
def __init__(  # noqa: PLR0913
    self,
    router: APIRouter,
    ai: GeminiProvider,
    query_router: GeminiRouter,
    retriever: QdrantRetriever,
    responder: GeminiResponder,
    attestation: Vtpm,
    prompts: PromptService,
) -> None:
    """
    Store the chat collaborators and register the endpoints.

    Args:
        router (APIRouter): FastAPI router to attach endpoints.
        ai (GeminiProvider): AI client used by a simple semantic router
            to determine if an attestation was requested or if RAG
            pipeline should be used.
        query_router: RAG Component that classifies the query.
        retriever: RAG Component that retrieves relevant documents.
        responder: RAG Component that generates a response.
        attestation (Vtpm): Provider for attestation services
        prompts (PromptService): Service for managing prompts
    """
    # FastAPI router the endpoints are attached to.
    self._router = router

    # Semantic routing / conversation side.
    self.ai = ai
    self.prompts = prompts
    self.attestation = attestation

    # RAG pipeline stages, in the order they run.
    self.query_router = query_router
    self.retriever = retriever
    self.responder = responder

    self.logger = logger.bind(router="chat")

    # Must come last: route registration happens once every attribute is set.
    self._setup_routes()
5468
@@ -65,35 +79,18 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor
6579 """
6680 try :
6781 self .logger .debug ("Received chat message" , message = message .message )
68- # Classify the query.
69- classification = self .query_router .route_query (message .message )
70- self .logger .info ("Query classified" , classification = classification )
71-
72- if classification == "ANSWER" :
73- # Retrieve relevant documents.
74- retrieved_docs = self .retriever .semantic_search (
75- message .message , top_k = 5
76- )
77- self .logger .info ("Documents retrieved" )
78-
79- # Generate the final answer using retrieved context.
80- answer = self .responder .generate_response (
81- message .message , retrieved_docs
82- )
83- self .logger .info ("Response generated" , answer = answer )
84- return {"classification" : classification , "response" : answer }
85-
86- # Map static responses for CLARIFY and REJECT.
87- static_responses = {
88- "CLARIFY" : "Please provide additional context." ,
89- "REJECT" : "The query is out of scope." ,
90- }
91-
92- if classification in static_responses :
93- return {
94- "classification" : classification ,
95- "response" : static_responses [classification ],
96- }
82+
83+ # If attestation has previously been requested:
84+ if self .attestation .attestation_requested :
85+ try :
86+ resp = self .attestation .get_token ([message .message ])
87+ except VtpmAttestationError as e :
88+ resp = f"The attestation failed with error:\n { e .args [0 ]} "
89+ self .attestation .attestation_requested = False
90+ return {"response" : resp }
91+
92+ route = await self .get_semantic_route (message .message )
93+ return await self .route_message (route , message .message )
9794
9895 except Exception as e :
9996 self .logger .exception ("Chat processing failed" , error = str (e ))
@@ -103,3 +100,120 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor
103100 def router (self ) -> APIRouter :
104101 """Return the underlying FastAPI router with registered endpoints."""
105102 return self ._router
103+
async def get_semantic_route(self, message: str) -> SemanticRouterResponse:
    """
    Ask the AI provider which semantic route a message belongs to.

    Args:
        message: Message to route

    Returns:
        SemanticRouterResponse: Route built from the provider's reply text,
            or ``SemanticRouterResponse.CONVERSATIONAL`` if any step of
            the lookup raises (the failure is logged, not propagated).
    """
    try:
        formatted = self.prompts.get_formatted_prompt(
            "semantic_router", user_input=message
        )
        prompt, mime_type, schema = formatted
        reply = self.ai.generate(
            prompt=prompt,
            response_mime_type=mime_type,
            response_schema=schema,
        )
        chosen = SemanticRouterResponse(reply.text)
    except Exception as e:
        # Best-effort routing: fall back to plain conversation on any error.
        self.logger.exception("routing_failed", error=str(e))
        return SemanticRouterResponse.CONVERSATIONAL
    else:
        return chosen
125+
async def route_message(
    self, route: SemanticRouterResponse, message: str
) -> dict[str, str]:
    """
    Dispatch a message to the handler registered for its semantic route.

    Args:
        route: Determined semantic route
        message: Original message to handle

    Returns:
        dict[str, str]: Whatever the selected handler returns, or
            ``{"response": "Unsupported route"}`` when no handler is
            registered for ``route``.
    """
    dispatch = {
        SemanticRouterResponse.RAG_ROUTER: self.handle_rag_pipeline,
        SemanticRouterResponse.REQUEST_ATTESTATION: self.handle_attestation,
        SemanticRouterResponse.CONVERSATIONAL: self.handle_conversation,
    }

    selected = dispatch.get(route)
    if selected:
        return await selected(message)

    return {"response": "Unsupported route"}
150+
async def handle_rag_pipeline(self, _: str) -> dict[str, str]:
    """
    Run the RAG pipeline: classify the query, then retrieve and respond.

    Args:
        _: The user's message. The placeholder name is kept for interface
            compatibility, but the value IS used: it is the search query
            for the retriever and the question given to the responder.

    Returns:
        dict[str, str]: ``classification`` and ``response`` keys. For
            "ANSWER" the response is generated from the retrieved
            documents; for "CLARIFY" and "REJECT" it is a fixed string.

    Raises:
        ValueError: If the classifier returns a label other than
            "ANSWER", "CLARIFY" or "REJECT".
    """
    query = _

    # Step 1. Classify the user query.
    # NOTE(review): the "rag_router" prompt is formatted without the user
    # message, unlike "semantic_router" in get_semantic_route -- confirm the
    # classifier is meant to run without the query text.
    prompt, mime_type, schema = self.prompts.get_formatted_prompt("rag_router")
    classification = self.query_router.route_query(
        prompt=prompt, response_mime_type=mime_type, response_schema=schema
    )
    self.logger.info("Query classified", classification=classification)

    if classification == "ANSWER":
        # Step 2. Retrieve relevant documents.
        retrieved_docs = self.retriever.semantic_search(query, top_k=5)
        self.logger.info("Documents retrieved")

        # Step 3. Generate the final answer.
        answer = self.responder.generate_response(query, retrieved_docs)
        self.logger.info("Response generated", answer=answer)
        return {"classification": classification, "response": answer}

    # Map static responses for CLARIFY and REJECT.
    static_responses = {
        "CLARIFY": "Please provide additional context.",
        "REJECT": "The query is out of scope.",
    }

    if classification in static_responses:
        return {
            "classification": classification,
            "response": static_responses[classification],
        }

    # No exception is active here, so log at error level (not .exception,
    # which is meant for use inside an exception handler) and include the
    # offending label for diagnosis.
    self.logger.error("RAG Routing failed", classification=classification)
    raise ValueError(classification)
192+
async def handle_attestation(self, _: str) -> dict[str, str]:
    """
    Reply to an attestation request and flag that one is pending.

    The "request_attestation" prompt is sent to the AI provider, and
    ``attestation_requested`` is set on the attestation provider.

    Args:
        _: Unused message parameter

    Returns:
        dict[str, str]: The AI provider's reply text under ``response``.
    """
    formatted = self.prompts.get_formatted_prompt("request_attestation")
    # Only the prompt text (first element) is needed here.
    reply = self.ai.generate(prompt=formatted[0])
    self.attestation.attestation_requested = True
    return {"response": reply.text}
207+
async def handle_conversation(self, message: str) -> dict[str, str]:
    """
    Forward a general conversation message to the AI provider.

    Args:
        message: Message to process

    Returns:
        dict[str, str]: The provider's reply text under ``response``.
    """
    return {"response": self.ai.send_message(message).text}
0 commit comments