Skip to content

Commit 0218109

Browse files
committed
chore: replace SentenceTransformer embeddings by Gemini embeddings
1 parent 7847746 commit 0218109

24 files changed

+386
-32
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
88
"cryptography>=44.0.1",
9+
"google>=3.0.0",
10+
"google-genai>=1.2.0",
11+
"google-generativeai>=0.8.4",
912
"httpx>=0.28.1",
1013
"openrouter>=1.0",
1114
"pandas>=2.2.3",

src/data/rag_answer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"query": "What is the block time for the Flare blockchain?",
3-
"answer": "The block time for the Flare blockchain is approximately **1.8 seconds**. This is evidenced by the context provided in **Document 2-feeds.mdx**, which states: \"FTSOv2's block-latency feeds update incrementally with each new block on Flare, approximately every 1.8 seconds\" (Document 2-feeds.mdx). This indicates that new blocks are produced on the Flare blockchain roughly every 1.8 seconds."
3+
"answer": "The block time for the Flare blockchain is approximately **1.8 seconds**. This is evidenced by the context provided in Document 0-overview.mdx, which states that the Flare Time Series Oracle (FTSOv2) features block-latency feeds that update \"with each new block on Flare, every \u22481.8 seconds\" [Document 0-overview.mdx]. This indicates that the time interval between consecutive blocks on the Flare blockchain is roughly 1.8 seconds."
44
}

src/flare_ai_rag/ai/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .base_client import AsyncBaseClient, BaseClient
2+
from .gemini import GeminiClient
3+
from .model import Model
4+
from .openrouter import OpenRouterClient
5+
6+
__all__ = ["AsyncBaseClient", "BaseClient", "GeminiClient", "Model", "OpenRouterClient"]
File renamed without changes.

src/flare_ai_rag/ai/gemini.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from google import genai
2+
3+
4+
class GeminiClient:
    """Small wrapper around a ``genai.Client`` used to request embeddings."""

    def __init__(self, api_key: str) -> None:
        """Build the underlying ``genai.Client`` from the supplied API key."""
        self.client = genai.Client(api_key=api_key)

    def embed_content(self, model: str, contents: str) -> list[float]:
        """Request an embedding for ``contents`` and return its first vector.

        Args:
            model: Model identifier forwarded to the embedding endpoint.
            contents: Text forwarded to the embedding endpoint.

        Returns:
            The ``values`` of the first embedding in the response.

        Raises:
            ValueError: If the response holds no embeddings, or the first
                embedding has ``values`` set to ``None``.
        """
        response = self.client.models.embed_content(model=model, contents=contents)
        vectors = response.embeddings

        # Only the first embedding is ever consulted; anything missing along
        # the way falls through to the single error path below.
        if vectors:
            first_values = vectors[0].values
            if first_values is not None:
                return first_values

        raise ValueError("No embedding was returned from the API.")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from flare_ai_rag.openrouter.base_client import AsyncBaseClient, BaseClient
1+
from flare_ai_rag.ai import AsyncBaseClient, BaseClient
22

33

44
class OpenRouterClient(BaseClient):

src/flare_ai_rag/input_parameters.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"qdrant_config": {
88
"embedding_model": "all-MiniLM-L6-v2",
99
"collection_name": "docs_collection",
10-
"vector_size": 384,
10+
"vector_size": 768,
1111
"host": "localhost",
1212
"port": 6333
1313
},

src/flare_ai_rag/main.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import structlog
33
from qdrant_client import QdrantClient
44

5-
from flare_ai_rag.openrouter import OpenRouterClient
5+
from flare_ai_rag.ai import GeminiClient, OpenRouterClient
66
from flare_ai_rag.responder import OpenRouterResponder, ResponderConfig
77
from flare_ai_rag.retriever import QdrantConfig, QdrantRetriever, generate_collection
88
from flare_ai_rag.router import QueryRouter, RouterConfig
@@ -54,6 +54,7 @@ def setup_retriever(
5454
qdrant_client: QdrantClient,
5555
input_config: dict,
5656
df_docs: pd.DataFrame,
57+
gemini_client: GeminiClient,
5758
collection_name: str | None = None,
5859
) -> QdrantRetriever:
5960
"""Initialize the Qdrant retriever."""
@@ -62,13 +63,19 @@ def setup_retriever(
6263
# (Re)generate qdrant collection
6364
if collection_name:
6465
generate_collection(
65-
df_docs, qdrant_client, qdrant_config, collection_name=collection_name
66+
df_docs,
67+
qdrant_client,
68+
qdrant_config,
69+
collection_name=collection_name,
70+
gemini_client=gemini_client,
6671
)
6772
logger.info(
6873
"The Qdrant collection has been generated.", collection_name=collection_name
6974
)
7075
# Return retriever
71-
return QdrantRetriever(client=qdrant_client, qdrant_config=qdrant_config)
76+
return QdrantRetriever(
77+
client=qdrant_client, qdrant_config=qdrant_config, gemini_client=gemini_client
78+
)
7279

7380

7481
def main() -> None:
@@ -93,11 +100,13 @@ def main() -> None:
93100
logger.info("Loaded CSV Data.", num_rows=len(df_docs))
94101

95102
# Retrieve docs
103+
gemini_client = GeminiClient(api_key=settings.gemini_api_key)
96104
retriever = setup_retriever(
97105
qdrant_client,
98106
input_config,
99107
df_docs,
100108
collection_name="docs_collection",
109+
gemini_client=gemini_client,
101110
)
102111
retrieved_docs = retriever.semantic_search(query, top_k=5)
103112
logger.info("Docs have been retrieved.")

src/flare_ai_rag/openrouter/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)