Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"cryptography>=44.0.1",
"google-genai>=1.2.0",
"google-generativeai>=0.8.4",
"httpx>=0.28.1",
"openrouter>=1.0",
Expand Down
2 changes: 1 addition & 1 deletion src/data/rag_answer.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"query": "What is the block time for the Flare blockchain?",
"answer": "The Flare blockchain generates a new block approximately every 1.8 seconds [Document 0-overview.mdx].\n"
"answer": "The provided text states that Flare's block time is approximately 1.8 seconds [Document 1-intro.mdx, Document 0-overview.mdx]. Each block triggers the Flare Time Series Oracle (FTSO) to select data providers for updates [Document 0-overview.mdx].\n"
}
3 changes: 2 additions & 1 deletion src/flare_ai_rag/ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .base import AsyncBaseClient, BaseClient
from .gemini import GeminiEmbedding, GeminiProvider
from .gemini import EmbeddingTaskType, GeminiEmbedding, GeminiProvider
from .model import Model
from .openrouter import OpenRouterClient

__all__ = [
"AsyncBaseClient",
"BaseClient",
"EmbeddingTaskType",
"GeminiEmbedding",
"GeminiProvider",
"Model",
Expand Down
57 changes: 26 additions & 31 deletions src/flare_ai_rag/ai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
from typing import Any, override

import structlog
from google import genai, generativeai
from google.generativeai.client import configure
from google.generativeai.embedding import (
EmbeddingTaskType,
)
from google.generativeai.embedding import (
embed_content as _embed_content,
)
from google.generativeai.generative_models import ChatSession, GenerativeModel
from google.generativeai.types import GenerationConfig

from flare_ai_rag.ai.base import BaseAIProvider, ModelResponse

Expand Down Expand Up @@ -55,9 +63,9 @@ def __init__(self, api_key: str, model: str, **kwargs: str) -> None:
**kwargs (str): Additional configuration parameters including:
- system_instruction: Custom system prompt for the AI personality
"""
generativeai.configure(api_key=api_key)
self.chat: generativeai.ChatSession | None = None
self.model = generativeai.GenerativeModel(
configure(api_key=api_key)
self.chat: ChatSession | None = None
self.model = GenerativeModel(
model_name=model,
system_instruction=kwargs.get("system_instruction", SYSTEM_INSTRUCTION),
)
Expand Down Expand Up @@ -90,7 +98,7 @@ def reset_model(self, model: str, **kwargs: str) -> None:
"""
new_system_instruction = kwargs.get("system_instruction", SYSTEM_INSTRUCTION)
# Reinitialize the generative model.
self.model = generativeai.GenerativeModel(
self.model = GenerativeModel(
model_name=model,
system_instruction=new_system_instruction,
)
Expand Down Expand Up @@ -126,7 +134,7 @@ def generate(
"""
response = self.model.generate_content(
prompt,
generation_config=generativeai.GenerationConfig(
generation_config=GenerationConfig(
response_mime_type=response_mime_type, response_schema=response_schema
),
)
Expand Down Expand Up @@ -175,7 +183,7 @@ def send_message(
)


class NewGeminiEmbedding:
class GeminiEmbedding:
def __init__(self, api_key: str) -> None:
"""
Initialize Gemini with API credentials.
Expand All @@ -184,9 +192,15 @@ def __init__(self, api_key: str) -> None:
Args:
api_key (str): Google API key for authentication
"""
generativeai.configure(api_key=api_key)
configure(api_key=api_key)

def embed_content(self, embedding_model: str, contents: str) -> list[float]:
def embed_content(
self,
embedding_model: str,
contents: str,
task_type: EmbeddingTaskType,
title: str | None = None,
) -> list[float]:
"""
Generate text embeddings using Gemini.

Expand All @@ -197,31 +211,12 @@ def embed_content(self, embedding_model: str, contents: str) -> list[float]:
Returns:
list[float]: The generated embedding vector.
"""
response = generativeai.embed_content(model=embedding_model, content=contents)
response = _embed_content(
model=embedding_model, content=contents, task_type=task_type, title=title
)
try:
embedding = response["embedding"]
except (KeyError, IndexError) as e:
msg = "Failed to extract embedding from response."
raise ValueError(msg) from e
return embedding


class GeminiEmbedding:
"""
Initialize Gemini with API credentials.
"""

def __init__(self, api_key: str) -> None:
"""This client uses google.genai"""
self.client = genai.Client(api_key=api_key)

def embed_content(self, embedding_model: str, contents: str) -> list[float]:
result = self.client.models.embed_content(
model=embedding_model, contents=contents
)
embedding = result.embeddings

if not embedding or embedding[0].values is None:
msg = "No embedding was returned from the API."
raise ValueError(msg)
return embedding[0].values
2 changes: 1 addition & 1 deletion src/flare_ai_rag/input_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"id": "gemini-1.5-flash"
},
"retriever_config": {
"embedding_model": "text-embedding-004",
"embedding_model": "models/text-embedding-004",
"vector_size": 768,
"collection_name": "docs_collection",
"host": "localhost",
Expand Down
7 changes: 5 additions & 2 deletions src/flare_ai_rag/retriever/qdrant_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, PointStruct, VectorParams

from flare_ai_rag.ai import GeminiEmbedding
from flare_ai_rag.ai import EmbeddingTaskType, GeminiEmbedding
from flare_ai_rag.retriever.config import RetrieverConfig

logger = structlog.get_logger(__name__)
Expand Down Expand Up @@ -56,7 +56,10 @@ def generate_collection(
try:
# Compute the embedding for the document content.
embedding = embedding_client.embed_content(
embedding_model=retriever_config.embedding_model, contents=content
embedding_model=retriever_config.embedding_model,
task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT,
contents=content,
title=str(row["Filename"]),
)
except Exception as e:
logger.exception(
Expand Down
6 changes: 4 additions & 2 deletions src/flare_ai_rag/retriever/qdrant_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from qdrant_client import QdrantClient

from flare_ai_rag.ai import GeminiEmbedding
from flare_ai_rag.ai import EmbeddingTaskType, GeminiEmbedding
from flare_ai_rag.retriever.base import BaseRetriever
from flare_ai_rag.retriever.config import RetrieverConfig

Expand Down Expand Up @@ -31,7 +31,9 @@ def semantic_search(self, query: str, top_k: int = 5) -> list[dict]:
"""
# Convert the query into a vector embedding using Gemini
query_vector = self.embedding_client.embed_content(
embedding_model="text-embedding-004", contents=query
embedding_model="models/text-embedding-004",
contents=query,
task_type=EmbeddingTaskType.RETRIEVAL_QUERY,
)

# Search Qdrant for similar vectors.
Expand Down
48 changes: 0 additions & 48 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.