Skip to content

Commit 9f2005c

Browse files
committed
fix(settings): update configs for smooth transition between Gemini and OpenRouter
1 parent c4bfca6 commit 9f2005c

File tree

17 files changed

+75
-88
lines changed

17 files changed

+75
-88
lines changed

src/data/rag_answer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"query": "What is the block time for the Flare blockchain?",
3-
"answer": "Based on the provided text, a new block is generated on the Flare blockchain approximately every 1.8 seconds [Document 0].\n"
3+
"answer": "The Flare blockchain produces a block approximately every 1.8 seconds [Document 0].\n"
44
}

src/flare_ai_rag/ai/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
@dataclass(frozen=True)
55
class Model:
66
model_id: str
7-
max_tokens: int
8-
temperature: float
7+
max_tokens: int | None
8+
temperature: float | None
Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
{
22
"router_model": {
3-
"id": "qwen/qwen-vl-plus:free",
4-
"max_tokens": 50,
5-
"temperature": 0
3+
"id": "gemini-1.5-flash"
64
},
7-
"qdrant_config": {
8-
"embedding_model": "all-MiniLM-L6-v2",
9-
"collection_name": "docs_collection",
5+
"retriever_config": {
6+
"embedding_model": "text-embedding-004",
107
"vector_size": 768,
8+
"collection_name": "docs_collection",
119
"host": "localhost",
1210
"port": 6333
1311
},
1412
"responder_model": {
15-
"id": "deepseek/deepseek-chat:free",
16-
"max_tokens": 200,
17-
"temperature": 0
13+
"id": "gemini-1.5-flash"
1814
}
1915
}

src/flare_ai_rag/main.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,14 @@
44

55
from flare_ai_rag.ai import GeminiEmbedding, GeminiProvider
66
from flare_ai_rag.responder import GeminiResponder, ResponderConfig
7-
from flare_ai_rag.retriever import QdrantConfig, QdrantRetriever, generate_collection
7+
from flare_ai_rag.retriever import QdrantRetriever, RetrieverConfig, generate_collection
88
from flare_ai_rag.router import GeminiRouter, RouterConfig
99
from flare_ai_rag.settings import settings
1010
from flare_ai_rag.utils import load_json, load_txt, save_json
1111

1212
logger = structlog.get_logger(__name__)
1313

1414

15-
def setup_qdrant(input_config: dict) -> QdrantClient:
16-
"""Initialize Qdrant client."""
17-
logger.info("Setting up Qdrant client...")
18-
qdrant_config = QdrantConfig.load(input_config["qdrant_config"])
19-
qdrant_client = QdrantClient(host=qdrant_config.host, port=qdrant_config.port)
20-
logger.info("Qdrant client has been set up.")
21-
22-
return qdrant_client
23-
24-
2515
def setup_router(input_config: dict) -> GeminiRouter:
2616
"""Initialize the Gemini Provider and the Gemini Router."""
2717
# Setup router config
@@ -31,7 +21,7 @@ def setup_router(input_config: dict) -> GeminiRouter:
3121
# Setup Gemini client based on Router config
3222
gemini_provider = GeminiProvider(
3323
api_key=settings.gemini_api_key,
34-
model=settings.gemini_model,
24+
model=router_config.model.model_id,
3525
system_instruction=router_config.system_prompt,
3626
)
3727

@@ -46,7 +36,7 @@ def setup_retriever(
4636
) -> QdrantRetriever:
4737
"""Initialize the Qdrant retriever."""
4838
# Set up Qdrant config
49-
qdrant_config = QdrantConfig.load(input_config["qdrant_config"])
39+
retriever_config = RetrieverConfig.load(input_config["retriever_config"])
5040

5141
# Set up Gemini Embedding client
5242
embedding_client = GeminiEmbedding(settings.gemini_api_key)
@@ -55,7 +45,7 @@ def setup_retriever(
5545
generate_collection(
5646
df_docs,
5747
qdrant_client,
58-
qdrant_config,
48+
retriever_config,
5949
collection_name=collection_name,
6050
embedding_client=embedding_client,
6151
)
@@ -65,11 +55,21 @@ def setup_retriever(
6555
# Return retriever
6656
return QdrantRetriever(
6757
client=qdrant_client,
68-
qdrant_config=qdrant_config,
58+
retriever_config=retriever_config,
6959
embedding_client=embedding_client,
7060
)
7161

7262

63+
def setup_qdrant(input_config: dict) -> QdrantClient:
64+
"""Initialize Qdrant client."""
65+
logger.info("Setting up Qdrant client...")
66+
retriever_config = RetrieverConfig.load(input_config["retriever_config"])
67+
qdrant_client = QdrantClient(host=retriever_config.host, port=retriever_config.port)
68+
logger.info("Qdrant client has been set up.")
69+
70+
return qdrant_client
71+
72+
7373
def setup_responder(input_config: dict) -> GeminiResponder:
7474
"""Initialize the responder."""
7575
# Set up Responder Config.
@@ -79,7 +79,7 @@ def setup_responder(input_config: dict) -> GeminiResponder:
7979
# Set up a new Gemini Provider based on Responder Config.
8080
gemini_provider = GeminiProvider(
8181
api_key=settings.gemini_api_key,
82-
model=settings.gemini_model,
82+
model=responder_config.model.model_id,
8383
system_instruction=responder_config.system_prompt,
8484
)
8585
return GeminiResponder(client=gemini_provider, responder_config=responder_config)

src/flare_ai_rag/responder/config.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
11
from dataclasses import dataclass
2+
from typing import Any
23

34
from flare_ai_rag.ai import Model
45
from flare_ai_rag.responder.prompts import RESPONDER_INSTRUCTION, RESPONDER_PROMPT
56

67

78
@dataclass(frozen=True)
89
class ResponderConfig:
9-
model: Model | None
10+
model: Model
1011
system_prompt: str
1112
query_prompt: str
1213

1314
@staticmethod
14-
def load(model_config: dict | None = None) -> "ResponderConfig":
15+
def load(model_config: dict[str, Any]) -> "ResponderConfig":
1516
"""Loads the Responder config."""
16-
if not model_config:
17-
# When using Gemini
18-
model = None
19-
else:
20-
# When using OpenRouter
21-
model = Model(
22-
model_id=model_config["id"],
23-
max_tokens=model_config["max_tokens"],
24-
temperature=model_config["temperature"],
25-
)
17+
model = Model(
18+
model_id=model_config["id"],
19+
max_tokens=model_config.get("max_tokens"),
20+
temperature=model_config.get("temperature"),
21+
)
2622

2723
return ResponderConfig(
2824
model=model,

src/flare_ai_rag/responder/responder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,16 @@ def generate_response(self, query: str, retrieved_documents: list[dict]) -> str:
8181
prompt = context + f"User query: {query}\n" + self.responder_config.query_prompt
8282
# Prepare the payload for the completion endpoint.
8383
payload: dict[str, Any] = {
84+
"model": self.responder_config.model.model_id,
8485
"messages": [
8586
{"role": "system", "content": self.responder_config.system_prompt},
8687
{"role": "user", "content": prompt},
87-
]
88+
],
8889
}
8990

90-
if self.responder_config.model is not None:
91-
payload["model"] = self.responder_config.model.model_id
91+
if self.responder_config.model.max_tokens is not None:
9292
payload["max_tokens"] = self.responder_config.model.max_tokens
93+
if self.responder_config.model.temperature is not None:
9394
payload["temperature"] = self.responder_config.model.temperature
9495

9596
# Send the prompt to the OpenRouter API.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .base import BaseRetriever
2-
from .config import QdrantConfig
2+
from .config import RetrieverConfig
33
from .qdrant_collection import generate_collection
44
from .qdrant_retriever import QdrantRetriever
55

6-
__all__ = ["BaseRetriever", "QdrantConfig", "QdrantRetriever", "generate_collection"]
6+
__all__ = ["BaseRetriever", "QdrantRetriever", "RetrieverConfig", "generate_collection"]

src/flare_ai_rag/retriever/config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from dataclasses import dataclass
2+
from typing import Any
23

34

45
@dataclass(frozen=True)
5-
class QdrantConfig:
6+
class RetrieverConfig:
67
"""Configuration for the embedding model used in the retriever."""
78

89
embedding_model: str
@@ -12,8 +13,8 @@ class QdrantConfig:
1213
port: int
1314

1415
@staticmethod
15-
def load(retriever_config: dict) -> "QdrantConfig":
16-
return QdrantConfig(
16+
def load(retriever_config: dict[str, Any]) -> "RetrieverConfig":
17+
return RetrieverConfig(
1718
embedding_model=retriever_config["embedding_model"],
1819
collection_name=retriever_config["collection_name"],
1920
vector_size=retriever_config["vector_size"],

src/flare_ai_rag/retriever/qdrant_collection.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from qdrant_client.http.models import Distance, PointStruct, VectorParams
55

66
from flare_ai_rag.ai import GeminiEmbedding
7-
from flare_ai_rag.retriever.config import QdrantConfig
8-
from flare_ai_rag.settings import settings
7+
from flare_ai_rag.retriever.config import RetrieverConfig
98

109
logger = structlog.get_logger(__name__)
1110

@@ -28,13 +27,13 @@ def _create_collection(
2827
def generate_collection(
2928
df_docs: pd.DataFrame,
3029
qdrant_client: QdrantClient,
31-
qdrant_config: QdrantConfig,
30+
retriever_config: RetrieverConfig,
3231
collection_name: str,
3332
embedding_client: GeminiEmbedding,
3433
) -> None:
3534
"""Routine for generating a Qdrant collection for a specific CSV file type."""
3635
# Create the collection.
37-
_create_collection(qdrant_client, collection_name, qdrant_config.vector_size)
36+
_create_collection(qdrant_client, collection_name, retriever_config.vector_size)
3837
logger.info("Created the collection.", collection_name=collection_name)
3938

4039
# For each document in the CSV, compute its embedding and prepare a Qdrant point.
@@ -54,7 +53,7 @@ def generate_collection(
5453
try:
5554
# Compute the embedding for the document content.
5655
embedding = embedding_client.embed_content(
57-
embedding_model=settings.gemini_embedding_model, contents=content
56+
embedding_model=retriever_config.embedding_model, contents=content
5857
)
5958
except Exception as e:
6059
logger.exception(

src/flare_ai_rag/retriever/qdrant_retriever.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44

55
from flare_ai_rag.ai import GeminiEmbedding
66
from flare_ai_rag.retriever.base import BaseRetriever
7-
from flare_ai_rag.retriever.config import QdrantConfig
7+
from flare_ai_rag.retriever.config import RetrieverConfig
88

99

1010
class QdrantRetriever(BaseRetriever):
1111
def __init__(
1212
self,
1313
client: QdrantClient,
14-
qdrant_config: QdrantConfig,
14+
retriever_config: RetrieverConfig,
1515
embedding_client: GeminiEmbedding,
1616
) -> None:
1717
"""Initialize the QdrantRetriever."""
1818
self.client = client
19-
self.qdrant_config = qdrant_config
19+
self.retriever_config = retriever_config
2020
self.embedding_client = embedding_client
2121

2222
@override
@@ -36,7 +36,7 @@ def semantic_search(self, query: str, top_k: int = 5) -> list[dict]:
3636

3737
# Search Qdrant for similar vectors.
3838
results = self.client.search(
39-
collection_name=self.qdrant_config.collection_name,
39+
collection_name=self.retriever_config.collection_name,
4040
query_vector=query_vector,
4141
limit=top_k,
4242
)

0 commit comments

Comments
 (0)