Skip to content

Commit c6e1459

Browse files
committed
feat: add responder and create RAG routine
1 parent 7715cdf commit c6e1459

File tree

19 files changed

+428
-102
lines changed

19 files changed

+428
-102
lines changed

.env.dummy

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# OpenRouter base url and API key
2+
OPENROUTER_BASE_URL="https://openrouter.ai/api/v1"
3+
OPENROUTER_API_KEY=""
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"router_model": {
33
"id": "qwen/qwen-vl-plus:free",
4-
"max_tokens": 20,
4+
"max_tokens": 5,
55
"temperature": 0
66
},
77
"qdrant_config": {
@@ -10,5 +10,10 @@
1010
"vector_size": 384,
1111
"host": "localhost",
1212
"port": 6333
13+
},
14+
"responder_model": {
15+
"id": "deepseek/deepseek-chat:free",
16+
"max_tokens": 200,
17+
"temperature": 0
1318
}
1419
}

src/flare_ai_rag/main.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import pandas as pd
2+
import structlog
3+
from qdrant_client import QdrantClient
4+
5+
from flare_ai_rag.config import config
6+
from flare_ai_rag.openrouter.client import OpenRouterClient
7+
from flare_ai_rag.responder.config import ResponderConfig
8+
from flare_ai_rag.responder.responder import OpenRouterResponder
9+
from flare_ai_rag.retriever.config import QdrantConfig
10+
from flare_ai_rag.retriever.qdrant_collection import generate_collection
11+
from flare_ai_rag.retriever.qdrant_retriever import QdrantRetriever
12+
from flare_ai_rag.router.config import RouterConfig
13+
from flare_ai_rag.router.router import QueryRouter
14+
from flare_ai_rag.utils import loader
15+
16+
logger = structlog.get_logger(__name__)
17+
18+
19+
def setup_clients(input_config: dict) -> tuple[OpenRouterClient, QdrantClient]:
    """Create the OpenRouter and Qdrant client instances.

    :param input_config: Parsed input-parameters JSON (needs "qdrant_config").
    :return: A (OpenRouter client, Qdrant client) pair.
    """
    # OpenRouter credentials come from the environment-backed app config.
    open_router = OpenRouterClient(
        api_key=config.open_router_api_key,
        base_url=config.open_router_base_url,
    )

    # Qdrant connection parameters come from the JSON input config.
    qdrant_cfg = QdrantConfig.load(input_config["qdrant_config"])
    qdrant = QdrantClient(host=qdrant_cfg.host, port=qdrant_cfg.port)

    return open_router, qdrant
31+
32+
33+
def setup_router(
    openrouter_client: OpenRouterClient, input_config: dict
) -> QueryRouter:
    """Build the query router from the "router_model" config section.

    :param openrouter_client: Client used to call the OpenRouter API.
    :param input_config: Parsed input-parameters JSON (needs "router_model").
    """
    router_config = RouterConfig.load(input_config["router_model"])
    return QueryRouter(client=openrouter_client, config=router_config)
40+
41+
42+
def setup_responder(
    openrouter_client: OpenRouterClient, input_config: dict
) -> OpenRouterResponder:
    """Build the responder from the "responder_model" config section.

    :param openrouter_client: Client used to call the OpenRouter API.
    :param input_config: Parsed input-parameters JSON (needs "responder_model").
    :return: A configured OpenRouterResponder.
    """
    # Mirror setup_router's shape: load the model section straight into its
    # config dataclass instead of reusing one name for both the raw dict and
    # the loaded config.
    responder_config = ResponderConfig.load(input_config["responder_model"])
    return OpenRouterResponder(
        client=openrouter_client, responder_config=responder_config
    )
51+
52+
53+
def setup_retriever(
    qdrant_client: QdrantClient,
    input_config: dict,
    df_docs: pd.DataFrame,
    collection: str | None = None,
) -> QdrantRetriever:
    """Build the Qdrant retriever, optionally (re)generating a collection.

    :param qdrant_client: Connected Qdrant client.
    :param input_config: Parsed input-parameters JSON (needs "qdrant_config").
    :param df_docs: Documents to index when a collection name is given.
    :param collection: Collection name to (re)generate; skipped when falsy.
    """
    qdrant_config = QdrantConfig.load(input_config["qdrant_config"])

    if collection:
        # Rebuild the named collection from the provided documents.
        generate_collection(
            df_docs, qdrant_client, qdrant_config, collection_name=collection
        )

    return QdrantRetriever(client=qdrant_client, qdrant_config=qdrant_config)
69+
70+
71+
def main() -> None:
    """Run the RAG routine: classify the query, then retrieve and answer.

    Reads the query and parameters from ``config.input_path``, routes the
    query, and — for "ANSWER" classifications — performs retrieval and
    response generation.
    """
    # Load input configuration.
    input_config = loader.load_json(config.input_path / "input_parameters.json")

    # Set up API clients.
    openrouter_client, qdrant_client = setup_clients(input_config)

    # Set up the router and classify the user query before doing any
    # retrieval work.
    router = setup_router(openrouter_client, input_config)
    query = loader.load_txt(config.input_path / "query.txt")
    classification = router.route_query(query)
    # Fixed log-message typo ("Queried classified." -> "Query classified.").
    logger.info("Query classified.", classification=classification)

    if classification == "ANSWER":
        df_docs = pd.read_csv(config.data_path / "docs.csv", delimiter=",")
        logger.info("Loaded CSV Data.", num_rows=len(df_docs))

        # Retrieve the most relevant documents for the query.
        retriever = setup_retriever(
            qdrant_client, input_config, df_docs, collection="docs_collection"
        )
        retrieved_docs = retriever.semantic_search(query, top_k=5)

        # Generate the final answer from the retrieved context.
        responder = setup_responder(openrouter_client, input_config)
        answer = responder.generate_response(query, retrieved_docs)
        logger.info("Answer retrieved.", answer=answer)
    elif classification == "CLARIFY":
        logger.info("Your query needs clarification. Please provide more details.")
    elif classification == "REJECT":
        logger.info("Your query has been rejected as it is out of scope.")
    else:
        logger.info("Unexpected classification.", classification=classification)


if __name__ == "__main__":
    main()

src/flare_ai_rag/query.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
What is the block time for the Flare blockchain?
2+
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class BaseResponder(ABC):
    """Interface for components that turn a query plus context into an answer."""

    @abstractmethod
    def generate_response(self, query: str, retrieved_documents: list[dict]) -> str:
        """
        Produce a final answer from the query and the retrieved documents.
        """
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from dataclasses import dataclass
2+
3+
from flare_ai_rag.config import config
4+
from flare_ai_rag.openrouter.model import Model
5+
from flare_ai_rag.utils import loader
6+
7+
# Load the base prompt template once, at import time.
# NOTE(review): this performs file I/O as a module-import side effect —
# confirm the input path is always available before this module is imported.
BASE_PROMPT = loader.load_txt(config.input_path / "responder" / "prompts.txt")
9+
10+
11+
@dataclass(frozen=True)
class ResponderConfig:
    """Immutable configuration for the responder."""

    # Model settings (id, max_tokens, temperature) used for the completion call.
    model: Model
    # Prompt template; expected to contain {query} and {context} placeholders.
    base_prompt: str

    @classmethod
    def load(cls, model_config: dict) -> "ResponderConfig":
        """Load the Responder config from the "responder_model" JSON section.

        Alternate constructor (classmethod, per stdlib convention for
        from-dict builders); call sites are unchanged.

        :param model_config: Dict with "id", "max_tokens" and "temperature".
        """
        model = Model(
            model_id=model_config["id"],
            max_tokens=model_config["max_tokens"],
            temperature=model_config["temperature"],
        )
        return cls(model=model, base_prompt=BASE_PROMPT)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Answer the following question using the provided context.
2+
Include citations for supporting evidence in your answer.
3+
4+
Question: {query}
5+
6+
Context: {context}
7+
8+
Answer:
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import override
2+
3+
from flare_ai_rag.openrouter.client import OpenRouterClient
4+
from flare_ai_rag.responder.base_responder import BaseResponder
5+
from flare_ai_rag.responder.config import ResponderConfig
6+
from flare_ai_rag.utils import parser
7+
8+
9+
class OpenRouterResponder(BaseResponder):
    """Responder that answers queries via the OpenRouter chat-completion API."""

    def __init__(
        self, client: OpenRouterClient, responder_config: ResponderConfig
    ) -> None:
        """
        Initialize the responder with an OpenRouter client and its config.

        :param client: An instance of OpenRouterClient.
        :param responder_config: Model and prompt settings for the responder.
        """
        self.client = client
        self.responder_config = responder_config

    @override
    def generate_response(self, query: str, retrieved_documents: list[dict]) -> str:
        """
        Generate a final answer using the query and the retrieved context,
        and include citations.

        :param query: The input query.
        :param retrieved_documents: A list of dictionaries containing retrieved docs.
        :return: The generated answer as a string.
        """
        # Build the context block with str.join instead of quadratic `+=`
        # concatenation; each document is labelled (by filename when present)
        # so the model can cite it.
        context = "".join(
            f"Document {doc.get('metadata', {}).get('filename', f'Doc{idx}')}:\n"
            f"{doc.get('text', '')}\n\n"
            for idx, doc in enumerate(retrieved_documents, start=1)
        )

        # Fill the prompt template with the query and the assembled context.
        prompt = self.responder_config.base_prompt.format(query=query, context=context)

        # Prepare the payload for the completion endpoint.
        payload = {
            "model": self.responder_config.model.model_id,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": self.responder_config.model.max_tokens,
            "temperature": self.responder_config.model.temperature,
        }

        # Send the prompt to the OpenRouter API and parse the answer text.
        response = self.client.send_chat_completion(payload)
        return parser.parse_openrouter_response(response)

src/flare_ai_rag/retriever/qdrant_collection.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
import pandas as pd
2+
import structlog
13
from qdrant_client import QdrantClient
2-
from qdrant_client.http.models import Distance, VectorParams
4+
from qdrant_client.http.models import Distance, PointStruct, VectorParams
5+
from sentence_transformers import SentenceTransformer
6+
7+
from flare_ai_rag.retriever.config import QdrantConfig
8+
9+
logger = structlog.get_logger(__name__)
310

411

512
def create_collection(
@@ -15,3 +22,63 @@ def create_collection(
1522
collection_name=collection_name,
1623
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
1724
)
25+
26+
27+
def generate_collection(
    df_docs: pd.DataFrame,
    client: QdrantClient,
    qdrant_config: QdrantConfig,
    collection_name: str,
) -> None:
    """Routine for generating a Qdrant collection for a specific CSV file type.

    :param df_docs: Documents with "Filename", "Metadata" and "Contents" columns.
    :param client: Connected Qdrant client.
    :param qdrant_config: Vector-size and embedding-model settings.
    :param collection_name: Name of the collection to (re)create and fill.
    """
    # Create the collection.
    create_collection(client, collection_name, qdrant_config.vector_size)
    logger.info("Created the collection.", collection_name=collection_name)

    # Load the embedding model.
    embedding_model = SentenceTransformer(qdrant_config.embedding_model)

    # For each document in the CSV, compute its embedding and prepare a Qdrant point.
    points = []
    for i, row in df_docs.iterrows():
        content = row["Contents"]

        # Skip rows whose content is missing or not a string.
        if not isinstance(content, str):
            logger.warning(
                "Skipping document due to missing or invalid content.",
                filename=row["Filename"],
            )
            continue

        try:
            # Compute the embedding for the document content.
            embedding = embedding_model.encode(content).tolist()
        except Exception:
            # Best-effort batch build: log (with traceback, via .exception)
            # and skip this document rather than abort the whole upload.
            logger.exception("Error encoding document.", filename=row["Filename"])
            continue

        # Prepare the payload.
        payload = {
            "filename": row["Filename"],
            "metadata": row["Metadata"],
            "text": content,
        }

        # Qdrant point IDs must be unsigned integers or UUID strings; a plain
        # numeric string such as "0" is rejected by the server, so use the
        # integer row index directly.
        point = PointStruct(id=int(i), vector=embedding, payload=payload)
        points.append(point)

    if points:
        # Upload the points into the Qdrant collection.
        client.upsert(collection_name=collection_name, points=points)
        logger.info(
            "Collection generated and documents inserted into Qdrant successfully.",
            collection_name=collection_name,
            num_points=len(points),
        )
    else:
        logger.warning("No valid documents found to insert.")

src/flare_ai_rag/retriever/qdrant_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def semantic_search(self, query: str, top_k: int = 5) -> list[dict]:
3131
"""
3232
# Convert the query into a vector embedding using the
3333
# SentenceTransformer instance.
34-
query_vector = self.embedding_model.encode(query)
34+
query_vector = self.embedding_model.encode(query).tolist()
3535

3636
# Search Qdrant for similar vectors.
3737
results = self.client.search(

0 commit comments

Comments
 (0)