Skip to content

Commit 70ee027

Browse files
committed
feat(api): add support for attestation
1 parent e363e3f commit 70ee027

File tree

6 files changed

+188
-85
lines changed

6 files changed

+188
-85
lines changed

src/flare_ai_rag/api/routes/chat.py

Lines changed: 147 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from fastapi import APIRouter, HTTPException
33
from pydantic import BaseModel, Field
44

5+
from flare_ai_rag.ai import GeminiProvider
6+
from flare_ai_rag.attestation import Vtpm, VtpmAttestationError
7+
from flare_ai_rag.prompts import PromptService, SemanticRouterResponse
58
from flare_ai_rag.responder import GeminiResponder
69
from flare_ai_rag.retriever import QdrantRetriever
710
from flare_ai_rag.router import GeminiRouter
@@ -29,26 +32,37 @@ class ChatRouter:
2932
generation components to handle a conversation in a single endpoint.
3033
"""
3134

32-
def __init__(
35+
def __init__( # noqa: PLR0913
3336
self,
3437
router: APIRouter,
38+
ai: GeminiProvider,
3539
query_router: GeminiRouter,
3640
retriever: QdrantRetriever,
3741
responder: GeminiResponder,
42+
attestation: Vtpm,
43+
prompts: PromptService,
3844
) -> None:
3945
"""
4046
Initialize the ChatRouter.
4147
4248
Args:
4349
router (APIRouter): FastAPI router to attach endpoints.
44-
query_router: Component that classifies the query.
45-
retriever: Component that retrieves relevant documents.
46-
responder: Component that generates a response.
50+
ai (GeminiProvider): AI client used by a simple semantic router
51+
to determine if an attestation was requested or if RAG
52+
pipeline should be used.
53+
query_router: RAG Component that classifies the query.
54+
retriever: RAG Component that retrieves relevant documents.
55+
responder: RAG Component that generates a response.
56+
attestation (Vtpm): Provider for attestation services
57+
prompts (PromptService): Service for managing prompts
4758
"""
4859
self._router = router
60+
self.ai = ai
4961
self.query_router = query_router
5062
self.retriever = retriever
5163
self.responder = responder
64+
self.attestation = attestation
65+
self.prompts = prompts
5266
self.logger = logger.bind(router="chat")
5367
self._setup_routes()
5468

@@ -65,35 +79,18 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor
6579
"""
6680
try:
6781
self.logger.debug("Received chat message", message=message.message)
68-
# Classify the query.
69-
classification = self.query_router.route_query(message.message)
70-
self.logger.info("Query classified", classification=classification)
71-
72-
if classification == "ANSWER":
73-
# Retrieve relevant documents.
74-
retrieved_docs = self.retriever.semantic_search(
75-
message.message, top_k=5
76-
)
77-
self.logger.info("Documents retrieved")
78-
79-
# Generate the final answer using retrieved context.
80-
answer = self.responder.generate_response(
81-
message.message, retrieved_docs
82-
)
83-
self.logger.info("Response generated", answer=answer)
84-
return {"classification": classification, "response": answer}
85-
86-
# Map static responses for CLARIFY and REJECT.
87-
static_responses = {
88-
"CLARIFY": "Please provide additional context.",
89-
"REJECT": "The query is out of scope.",
90-
}
91-
92-
if classification in static_responses:
93-
return {
94-
"classification": classification,
95-
"response": static_responses[classification],
96-
}
82+
83+
# If attestation has previously been requested:
84+
if self.attestation.attestation_requested:
85+
try:
86+
resp = self.attestation.get_token([message.message])
87+
except VtpmAttestationError as e:
88+
resp = f"The attestation failed with error:\n{e.args[0]}"
89+
self.attestation.attestation_requested = False
90+
return {"response": resp}
91+
92+
route = await self.get_semantic_route(message.message)
93+
return await self.route_message(route, message.message)
9794

9895
except Exception as e:
9996
self.logger.exception("Chat processing failed", error=str(e))
@@ -103,3 +100,120 @@ async def chat(message: ChatMessage) -> dict[str, str] | None: # pyright: ignor
103100
def router(self) -> APIRouter:
104101
"""Return the underlying FastAPI router with registered endpoints."""
105102
return self._router
103+
104+
async def get_semantic_route(self, message: str) -> SemanticRouterResponse:
105+
"""
106+
Determine the semantic route for a message using AI provider.
107+
108+
Args:
109+
message: Message to route
110+
111+
Returns:
112+
SemanticRouterResponse: Determined route for the message
113+
"""
114+
try:
115+
prompt, mime_type, schema = self.prompts.get_formatted_prompt(
116+
"semantic_router", user_input=message
117+
)
118+
route_response = self.ai.generate(
119+
prompt=prompt, response_mime_type=mime_type, response_schema=schema
120+
)
121+
return SemanticRouterResponse(route_response.text)
122+
except Exception as e:
123+
self.logger.exception("routing_failed", error=str(e))
124+
return SemanticRouterResponse.CONVERSATIONAL
125+
126+
async def route_message(
127+
self, route: SemanticRouterResponse, message: str
128+
) -> dict[str, str]:
129+
"""
130+
Route a message to the appropriate handler based on semantic route.
131+
132+
Args:
133+
route: Determined semantic route
134+
message: Original message to handle
135+
136+
Returns:
137+
dict[str, str]: Response from the appropriate handler
138+
"""
139+
handlers = {
140+
SemanticRouterResponse.RAG_ROUTER: self.handle_rag_pipeline,
141+
SemanticRouterResponse.REQUEST_ATTESTATION: self.handle_attestation,
142+
SemanticRouterResponse.CONVERSATIONAL: self.handle_conversation,
143+
}
144+
145+
handler = handlers.get(route)
146+
if not handler:
147+
return {"response": "Unsupported route"}
148+
149+
return await handler(message)
150+
151+
async def handle_rag_pipeline(self, _: str) -> dict[str, str]:
152+
"""
153+
Handle queries via the RAG pipeline: classify the query, retrieve relevant documents, and generate an answer.
154+
155+
Args:
156+
_: User message passed to the retriever and responder
157+
158+
Returns:
159+
dict[str, str]: Response containing the query classification and the generated (or static) answer
160+
"""
161+
# Step 1. Classify the user query.
162+
prompt, mime_type, schema = self.prompts.get_formatted_prompt("rag_router")
163+
classification = self.query_router.route_query(
164+
prompt=prompt, response_mime_type=mime_type, response_schema=schema
165+
)
166+
self.logger.info("Query classified", classification=classification)
167+
168+
if classification == "ANSWER":
169+
# Step 2. Retrieve relevant documents.
170+
retrieved_docs = self.retriever.semantic_search(_, top_k=5)
171+
self.logger.info("Documents retrieved")
172+
173+
# Step 3. Generate the final answer.
174+
answer = self.responder.generate_response(_, retrieved_docs)
175+
self.logger.info("Response generated", answer=answer)
176+
return {"classification": classification, "response": answer}
177+
178+
# Map static responses for CLARIFY and REJECT.
179+
static_responses = {
180+
"CLARIFY": "Please provide additional context.",
181+
"REJECT": "The query is out of scope.",
182+
}
183+
184+
if classification in static_responses:
185+
return {
186+
"classification": classification,
187+
"response": static_responses[classification],
188+
}
189+
190+
self.logger.exception("RAG Routing failed")
191+
raise ValueError(classification)
192+
193+
async def handle_attestation(self, _: str) -> dict[str, str]:
194+
"""
195+
Handle attestation requests.
196+
197+
Args:
198+
_: Unused message parameter
199+
200+
Returns:
201+
dict[str, str]: Response containing attestation request
202+
"""
203+
prompt = self.prompts.get_formatted_prompt("request_attestation")[0]
204+
request_attestation_response = self.ai.generate(prompt=prompt)
205+
self.attestation.attestation_requested = True
206+
return {"response": request_attestation_response.text}
207+
208+
async def handle_conversation(self, message: str) -> dict[str, str]:
209+
"""
210+
Handle general conversation messages.
211+
212+
Args:
213+
message: Message to process
214+
215+
Returns:
216+
dict[str, str]: Response from AI provider
217+
"""
218+
response = self.ai.send_message(message)
219+
return {"response": response.text}

src/flare_ai_rag/main.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
from flare_ai_rag.ai import GeminiEmbedding, GeminiProvider
1717
from flare_ai_rag.api import ChatRouter
18+
from flare_ai_rag.attestation import Vtpm
19+
from flare_ai_rag.prompts import PromptService
1820
from flare_ai_rag.responder import GeminiResponder, ResponderConfig
1921
from flare_ai_rag.retriever import QdrantRetriever, RetrieverConfig, generate_collection
2022
from flare_ai_rag.router import GeminiRouter, RouterConfig
@@ -24,20 +26,20 @@
2426
logger = structlog.get_logger(__name__)
2527

2628

27-
def setup_router(input_config: dict) -> GeminiRouter:
28-
"""Initialize the Gemini Provider and the Gemini Router."""
29+
def setup_router(input_config: dict) -> tuple[GeminiProvider, GeminiRouter]:
30+
"""Initialize a Gemini Provider for routing."""
2931
# Setup router config
3032
router_model_config = input_config["router_model"]
3133
router_config = RouterConfig.load(router_model_config)
3234

3335
# Setup Gemini client based on Router config
36+
# Older version used a system_instruction
3437
gemini_provider = GeminiProvider(
35-
api_key=settings.gemini_api_key,
36-
model=router_config.model.model_id,
37-
system_instruction=router_config.system_prompt,
38+
api_key=settings.gemini_api_key, model=router_config.model.model_id
3839
)
40+
gemini_router = GeminiRouter(client=gemini_provider, config=router_config)
3941

40-
return GeminiRouter(client=gemini_provider, config=router_config)
42+
return gemini_provider, gemini_router
4143

4244

4345
def setup_retriever(
@@ -128,8 +130,8 @@ def create_app() -> FastAPI:
128130
df_docs = pd.read_csv(settings.data_path / "docs.csv", delimiter=",")
129131
logger.info("Loaded CSV Data.", num_rows=len(df_docs))
130132

131-
# Set up the RAG components: 1. Gemini Router
132-
router_component = setup_router(input_config)
133+
# Set up the RAG components: 1. Gemini Provider
134+
base_ai, router_component = setup_router(input_config)
133135

134136
# 2a. Set up Qdrant client.
135137
qdrant_client = setup_qdrant(input_config)
@@ -143,9 +145,12 @@ def create_app() -> FastAPI:
143145
# Create an APIRouter for chat endpoints and initialize ChatRouter.
144146
chat_router = ChatRouter(
145147
router=APIRouter(),
148+
ai=base_ai,
146149
query_router=router_component,
147150
retriever=retriever_component,
148151
responder=responder_component,
152+
attestation=Vtpm(simulate=settings.simulate_attestation),
153+
prompts=PromptService(),
149154
)
150155
app.include_router(chat_router.router, prefix="/api/routes/chat", tags=["chat"])
151156

src/flare_ai_rag/prompts/service.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
11
"""
2-
Prompt Service Module for Flare AI DeFAI
2+
Prompt Service Module for Flare AI RAG
33
44
This module provides a service layer for managing and formatting AI prompts.
55
It acts as a wrapper around the PromptLibrary, adding error handling and
66
logging capabilities. The service is responsible for retrieving prompts,
77
formatting them with provided parameters, and returning the formatted prompts
88
along with their associated metadata.
9-
10-
Example:
11-
```python
12-
service = PromptService()
13-
prompt, mime_type, schema = service.get_formatted_prompt(
14-
"token_send", amount="100", address="0x123..."
15-
)
16-
```
179
"""
1810

1911
from typing import Any
@@ -40,17 +32,6 @@ class to provide additional functionality and safety checks.
4032
library (PromptLibrary): Instance of the prompt library containing all
4133
prompt templates
4234
logger (BoundLogger): Structured logger bound with service context
43-
44-
Example:
45-
```python
46-
service = PromptService()
47-
try:
48-
prompt, mime_type, schema = service.get_formatted_prompt(
49-
"token_send", to_address="0x123...", amount=100
50-
)
51-
except Exception as e:
52-
print(f"Failed to format prompt: {e}")
53-
```
5435
"""
5536

5637
def __init__(self) -> None:
@@ -89,19 +70,6 @@ def get_formatted_prompt(
8970
ValueError: If required format parameters are missing
9071
Exception: For other formatting or processing errors
9172
92-
Example:
93-
```python
94-
service = PromptService()
95-
try:
96-
prompt, mime_type, schema = service.get_formatted_prompt(
97-
"token_swap", from_token="ETH", to_token="USDC", amount=1.5
98-
)
99-
except KeyError:
100-
print("Prompt template not found")
101-
except ValueError:
102-
print("Missing required parameters")
103-
```
104-
10573
Logs:
10674
- Exceptions during prompt formatting with prompt name and error details
10775
"""

src/flare_ai_rag/router/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
23

34

45
class BaseQueryRouter(ABC):
@@ -7,7 +8,12 @@ class BaseQueryRouter(ABC):
78
"""
89

910
@abstractmethod
10-
def route_query(self, query: str) -> str:
11+
def route_query(
12+
self,
13+
prompt: str,
14+
response_mime_type: str | None = None,
15+
response_schema: Any | None = None,
16+
) -> str:
1117
"""
1218
Determine the type of the query: ANSWER, CLARIFY, or REJECT.
1319
"""

0 commit comments

Comments
 (0)