From 35c710ff112095cabebd242ccc6a9091392eff75 Mon Sep 17 00:00:00 2001 From: Dinesh Pinto Date: Fri, 21 Feb 2025 18:20:44 +0400 Subject: [PATCH 1/2] fix(dependencies): use only generative-ai --- pyproject.toml | 1 - src/flare_ai_rag/ai/gemini.py | 42 +++++++++--------------------- uv.lock | 48 ----------------------------------- 3 files changed, 12 insertions(+), 79 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df5cb2f..6007677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "cryptography>=44.0.1", - "google-genai>=1.2.0", "google-generativeai>=0.8.4", "httpx>=0.28.1", "openrouter>=1.0", diff --git a/src/flare_ai_rag/ai/gemini.py b/src/flare_ai_rag/ai/gemini.py index 2f6402b..bebd1fb 100644 --- a/src/flare_ai_rag/ai/gemini.py +++ b/src/flare_ai_rag/ai/gemini.py @@ -9,7 +9,10 @@ from typing import Any, override import structlog -from google import genai, generativeai +from google.generativeai.client import configure +from google.generativeai.embedding import embed_content as _embed_content +from google.generativeai.generative_models import ChatSession, GenerativeModel +from google.generativeai.types import GenerationConfig from flare_ai_rag.ai.base import BaseAIProvider, ModelResponse @@ -55,9 +58,9 @@ def __init__(self, api_key: str, model: str, **kwargs: str) -> None: **kwargs (str): Additional configuration parameters including: - system_instruction: Custom system prompt for the AI personality """ - generativeai.configure(api_key=api_key) - self.chat: generativeai.ChatSession | None = None - self.model = generativeai.GenerativeModel( + configure(api_key=api_key) + self.chat: ChatSession | None = None + self.model = GenerativeModel( model_name=model, system_instruction=kwargs.get("system_instruction", SYSTEM_INSTRUCTION), ) @@ -90,7 +93,7 @@ def reset_model(self, model: str, **kwargs: str) -> None: """ new_system_instruction = kwargs.get("system_instruction", SYSTEM_INSTRUCTION) # Reinitialize the generative model. - self.model = generativeai.GenerativeModel( + self.model = GenerativeModel( model_name=model, system_instruction=new_system_instruction, ) @@ -126,7 +129,7 @@ def generate( """ response = self.model.generate_content( prompt, - generation_config=generativeai.GenerationConfig( + generation_config=GenerationConfig( response_mime_type=response_mime_type, response_schema=response_schema ), ) @@ -175,7 +178,7 @@ def send_message( ) -class NewGeminiEmbedding: +class GeminiEmbedding: def __init__(self, api_key: str) -> None: """ Initialize Gemini with API credentials. @@ -184,7 +187,7 @@ def __init__(self, api_key: str) -> None: Args: api_key (str): Google API key for authentication """ - generativeai.configure(api_key=api_key) + configure(api_key=api_key) def embed_content(self, embedding_model: str, contents: str) -> list[float]: """ @@ -197,31 +200,10 @@ def embed_content(self, embedding_model: str, contents: str) -> list[float]: Returns: list[float]: The generated embedding vector. """ - response = generativeai.embed_content(model=embedding_model, content=contents) + response = _embed_content(model=embedding_model, content=contents) try: embedding = response["embedding"] except (KeyError, IndexError) as e: msg = "Failed to extract embedding from response." raise ValueError(msg) from e return embedding - - -class GeminiEmbedding: - """ - Initialize Gemini with API credentials. - """ - - def __init__(self, api_key: str) -> None: - """This client uses google.genai""" - self.client = genai.Client(api_key=api_key) - - def embed_content(self, embedding_model: str, contents: str) -> list[float]: - result = self.client.models.embed_content( - model=embedding_model, contents=contents - ) - embedding = result.embeddings - - if not embedding or embedding[0].values is None: - msg = "No embedding was returned from the API." - raise ValueError(msg) - return embedding[0].values diff --git a/uv.lock b/uv.lock index 9c63e00..ed93cc9 100644 --- a/uv.lock +++ b/uv.lock @@ -164,7 +164,6 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "cryptography" }, - { name = "google-genai" }, { name = "google-generativeai" }, { name = "httpx" }, { name = "openrouter" }, @@ -185,7 +184,6 @@ dev = [ [package.metadata] requires-dist = [ { name = "cryptography", specifier = ">=44.0.1" }, - { name = "google-genai", specifier = ">=1.2.0" }, { name = "google-generativeai", specifier = ">=0.8.4" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "openrouter", specifier = ">=1.0" }, @@ -283,21 +281,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/8a/fe34d2f3f9470a27b01c9e76226965863f153d5fbe276f83608562e49c04/google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d", size = 9253 }, ] -[[package]] -name = "google-genai" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "pydantic" }, - { name = "requests" }, - { name = "typing-extensions" }, - { name = "websockets" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/0d/ed/985f2d2e2b5fbd912ab0fdb11d6dc48c22553a6c4edffabb8146d53b974a/google_genai-1.2.0-py3-none-any.whl", hash = "sha256:609d61bee73f1a6ae5b47e9c7dd4b469d50318f050c5ceacf835b0f80f79d2d9", size = 130744 }, -] - [[package]] name = "google-generativeai" version = "0.8.4" @@ -961,34 +944,3 @@ sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf wheels = [ { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, ] - -[[package]] -name = "websockets" -version = "14.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/54/8359678c726243d19fae38ca14a334e740782336c9f19700858c4eb64a1e/websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5", size = 164394 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/81/04f7a397653dc8bec94ddc071f34833e8b99b13ef1a3804c149d59f92c18/websockets-14.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f20522e624d7ffbdbe259c6b6a65d73c895045f76a93719aa10cd93b3de100c", size = 163096 }, - { url = "https://files.pythonhosted.org/packages/ec/c5/de30e88557e4d70988ed4d2eabd73fd3e1e52456b9f3a4e9564d86353b6d/websockets-14.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:647b573f7d3ada919fd60e64d533409a79dcf1ea21daeb4542d1d996519ca967", size = 160758 }, - { url = "https://files.pythonhosted.org/packages/e5/8c/d130d668781f2c77d106c007b6c6c1d9db68239107c41ba109f09e6c218a/websockets-14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6af99a38e49f66be5a64b1e890208ad026cda49355661549c507152113049990", size = 160995 }, - { url = "https://files.pythonhosted.org/packages/a6/bc/f6678a0ff17246df4f06765e22fc9d98d1b11a258cc50c5968b33d6742a1/websockets-14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:091ab63dfc8cea748cc22c1db2814eadb77ccbf82829bac6b2fbe3401d548eda", size = 170815 }, - { url = "https://files.pythonhosted.org/packages/d8/b2/8070cb970c2e4122a6ef38bc5b203415fd46460e025652e1ee3f2f43a9a3/websockets-14.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b374e8953ad477d17e4851cdc66d83fdc2db88d9e73abf755c94510ebddceb95", size = 169759 }, - { url = "https://files.pythonhosted.org/packages/81/da/72f7caabd94652e6eb7e92ed2d3da818626e70b4f2b15a854ef60bf501ec/websockets-14.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a39d7eceeea35db85b85e1169011bb4321c32e673920ae9c1b6e0978590012a3", size = 170178 }, - { url = "https://files.pythonhosted.org/packages/31/e0/812725b6deca8afd3a08a2e81b3c4c120c17f68c9b84522a520b816cda58/websockets-14.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0a6f3efd47ffd0d12080594f434faf1cd2549b31e54870b8470b28cc1d3817d9", size = 170453 }, - { url = "https://files.pythonhosted.org/packages/66/d3/8275dbc231e5ba9bb0c4f93144394b4194402a7a0c8ffaca5307a58ab5e3/websockets-14.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:065ce275e7c4ffb42cb738dd6b20726ac26ac9ad0a2a48e33ca632351a737267", size = 169830 }, - { url = "https://files.pythonhosted.org/packages/a3/ae/e7d1a56755ae15ad5a94e80dd490ad09e345365199600b2629b18ee37bc7/websockets-14.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e9d0e53530ba7b8b5e389c02282f9d2aa47581514bd6049d3a7cffe1385cf5fe", size = 169824 }, - { url = "https://files.pythonhosted.org/packages/b6/32/88ccdd63cb261e77b882e706108d072e4f1c839ed723bf91a3e1f216bf60/websockets-14.2-cp312-cp312-win32.whl", hash = "sha256:20e6dd0984d7ca3037afcb4494e48c74ffb51e8013cac71cf607fffe11df7205", size = 163981 }, - { url = "https://files.pythonhosted.org/packages/b3/7d/32cdb77990b3bdc34a306e0a0f73a1275221e9a66d869f6ff833c95b56ef/websockets-14.2-cp312-cp312-win_amd64.whl", hash = "sha256:44bba1a956c2c9d268bdcdf234d5e5ff4c9b6dc3e300545cbe99af59dda9dcce", size = 164421 }, - { url = "https://files.pythonhosted.org/packages/82/94/4f9b55099a4603ac53c2912e1f043d6c49d23e94dd82a9ce1eb554a90215/websockets-14.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6f1372e511c7409a542291bce92d6c83320e02c9cf392223272287ce55bc224e", size = 163102 }, - { url = "https://files.pythonhosted.org/packages/8e/b7/7484905215627909d9a79ae07070057afe477433fdacb59bf608ce86365a/websockets-14.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4da98b72009836179bb596a92297b1a61bb5a830c0e483a7d0766d45070a08ad", size = 160766 }, - { url = "https://files.pythonhosted.org/packages/a3/a4/edb62efc84adb61883c7d2c6ad65181cb087c64252138e12d655989eec05/websockets-14.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8a86a269759026d2bde227652b87be79f8a734e582debf64c9d302faa1e9f03", size = 160998 }, - { url = "https://files.pythonhosted.org/packages/f5/79/036d320dc894b96af14eac2529967a6fc8b74f03b83c487e7a0e9043d842/websockets-14.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86cf1aaeca909bf6815ea714d5c5736c8d6dd3a13770e885aafe062ecbd04f1f", size = 170780 }, - { url = "https://files.pythonhosted.org/packages/63/75/5737d21ee4dd7e4b9d487ee044af24a935e36a9ff1e1419d684feedcba71/websockets-14.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9b0f6c3ba3b1240f602ebb3971d45b02cc12bd1845466dd783496b3b05783a5", size = 169717 }, - { url = "https://files.pythonhosted.org/packages/2c/3c/bf9b2c396ed86a0b4a92ff4cdaee09753d3ee389be738e92b9bbd0330b64/websockets-14.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669c3e101c246aa85bc8534e495952e2ca208bd87994650b90a23d745902db9a", size = 170155 }, - { url = "https://files.pythonhosted.org/packages/75/2d/83a5aca7247a655b1da5eb0ee73413abd5c3a57fc8b92915805e6033359d/websockets-14.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eabdb28b972f3729348e632ab08f2a7b616c7e53d5414c12108c29972e655b20", size = 170495 }, - { url = "https://files.pythonhosted.org/packages/79/dd/699238a92761e2f943885e091486378813ac8f43e3c84990bc394c2be93e/websockets-14.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2066dc4cbcc19f32c12a5a0e8cc1b7ac734e5b64ac0a325ff8353451c4b15ef2", size = 169880 }, - { url = "https://files.pythonhosted.org/packages/c8/c9/67a8f08923cf55ce61aadda72089e3ed4353a95a3a4bc8bf42082810e580/websockets-14.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ab95d357cd471df61873dadf66dd05dd4709cae001dd6342edafc8dc6382f307", size = 169856 }, - { url = "https://files.pythonhosted.org/packages/17/b1/1ffdb2680c64e9c3921d99db460546194c40d4acbef999a18c37aa4d58a3/websockets-14.2-cp313-cp313-win32.whl", hash = "sha256:a9e72fb63e5f3feacdcf5b4ff53199ec8c18d66e325c34ee4c551ca748623bbc", size = 163974 }, - { url = "https://files.pythonhosted.org/packages/14/13/8b7fc4cb551b9cfd9890f0fd66e53c18a06240319915533b033a56a3d520/websockets-14.2-cp313-cp313-win_amd64.whl", hash = "sha256:b439ea828c4ba99bb3176dc8d9b933392a2413c0f6b149fdcba48393f573377f", size = 164420 }, - { url = "https://files.pythonhosted.org/packages/7b/c8/d529f8a32ce40d98309f4470780631e971a5a842b60aec864833b3615786/websockets-14.2-py3-none-any.whl", hash = "sha256:7a6ceec4ea84469f15cf15807a747e9efe57e369c384fa86e022b3bea679b79b", size = 157416 }, -] From 1e928aa9ca6602cab8bb809a272e1c0405257e4a Mon Sep 17 00:00:00 2001 From: Dinesh Pinto Date: Fri, 21 Feb 2025 19:04:49 +0400 Subject: [PATCH 2/2] fix(embeddings): model name and task types --- src/data/rag_answer.json | 2 +- src/flare_ai_rag/ai/__init__.py | 3 ++- src/flare_ai_rag/ai/gemini.py | 19 ++++++++++++++++--- src/flare_ai_rag/input_parameters.json | 2 +- .../retriever/qdrant_collection.py | 7 +++++-- .../retriever/qdrant_retriever.py | 6 ++++-- 6 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/data/rag_answer.json b/src/data/rag_answer.json index b22fecc..ed2f75b 100644 --- a/src/data/rag_answer.json +++ b/src/data/rag_answer.json @@ -1,4 +1,4 @@ { "query": "What is the block time for the Flare blockchain?", - "answer": "The Flare blockchain generates a new block approximately every 1.8 seconds [Document 0-overview.mdx].\n" + "answer": "The provided text states that Flare's block time is approximately 1.8 seconds [Document 1-intro.mdx, Document 0-overview.mdx]. Each block triggers the Flare Time Series Oracle (FTSO) to select data providers for updates [Document 0-overview.mdx].\n" } \ No newline at end of file diff --git a/src/flare_ai_rag/ai/__init__.py b/src/flare_ai_rag/ai/__init__.py index ad504f4..666cf54 100644 --- a/src/flare_ai_rag/ai/__init__.py +++ b/src/flare_ai_rag/ai/__init__.py @@ -1,11 +1,12 @@ from .base import AsyncBaseClient, BaseClient -from .gemini import GeminiEmbedding, GeminiProvider +from .gemini import EmbeddingTaskType, GeminiEmbedding, GeminiProvider from .model import Model from .openrouter import OpenRouterClient __all__ = [ "AsyncBaseClient", "BaseClient", + "EmbeddingTaskType", "GeminiEmbedding", "GeminiProvider", "Model", diff --git a/src/flare_ai_rag/ai/gemini.py b/src/flare_ai_rag/ai/gemini.py index bebd1fb..d07e0ca 100644 --- a/src/flare_ai_rag/ai/gemini.py +++ b/src/flare_ai_rag/ai/gemini.py @@ -10,7 +10,12 @@ import structlog from google.generativeai.client import configure -from google.generativeai.embedding import embed_content as _embed_content +from google.generativeai.embedding import ( + EmbeddingTaskType, +) +from google.generativeai.embedding import ( + embed_content as _embed_content, +) from google.generativeai.generative_models import ChatSession, GenerativeModel from google.generativeai.types import GenerationConfig @@ -189,7 +194,13 @@ def __init__(self, api_key: str) -> None: """ configure(api_key=api_key) - def embed_content(self, embedding_model: str, contents: str) -> list[float]: + def embed_content( + self, + embedding_model: str, + contents: str, + task_type: EmbeddingTaskType, + title: str | None = None, + ) -> list[float]: """ Generate text embeddings using Gemini. @@ -200,7 +211,9 @@ def embed_content(self, embedding_model: str, contents: str) -> list[float]: Returns: list[float]: The generated embedding vector. """ - response = _embed_content(model=embedding_model, content=contents) + response = _embed_content( + model=embedding_model, content=contents, task_type=task_type, title=title + ) try: embedding = response["embedding"] except (KeyError, IndexError) as e: diff --git a/src/flare_ai_rag/input_parameters.json b/src/flare_ai_rag/input_parameters.json index 07ff9c3..adbe6b4 100644 --- a/src/flare_ai_rag/input_parameters.json +++ b/src/flare_ai_rag/input_parameters.json @@ -3,7 +3,7 @@ "id": "gemini-1.5-flash" }, "retriever_config": { - "embedding_model": "text-embedding-004", + "embedding_model": "models/text-embedding-004", "vector_size": 768, "collection_name": "docs_collection", "host": "localhost", diff --git a/src/flare_ai_rag/retriever/qdrant_collection.py b/src/flare_ai_rag/retriever/qdrant_collection.py index c871824..81e7001 100644 --- a/src/flare_ai_rag/retriever/qdrant_collection.py +++ b/src/flare_ai_rag/retriever/qdrant_collection.py @@ -3,7 +3,7 @@ from qdrant_client import QdrantClient from qdrant_client.http.models import Distance, PointStruct, VectorParams -from flare_ai_rag.ai import GeminiEmbedding +from flare_ai_rag.ai import EmbeddingTaskType, GeminiEmbedding from flare_ai_rag.retriever.config import RetrieverConfig logger = structlog.get_logger(__name__) @@ -56,7 +56,10 @@ def generate_collection( try: # Compute the embedding for the document content. embedding = embedding_client.embed_content( - embedding_model=retriever_config.embedding_model, contents=content + embedding_model=retriever_config.embedding_model, + task_type=EmbeddingTaskType.RETRIEVAL_DOCUMENT, + contents=content, + title=str(row["Filename"]), ) except Exception as e: logger.exception( diff --git a/src/flare_ai_rag/retriever/qdrant_retriever.py b/src/flare_ai_rag/retriever/qdrant_retriever.py index 06d461a..6312196 100644 --- a/src/flare_ai_rag/retriever/qdrant_retriever.py +++ b/src/flare_ai_rag/retriever/qdrant_retriever.py @@ -2,7 +2,7 @@ from qdrant_client import QdrantClient -from flare_ai_rag.ai import GeminiEmbedding +from flare_ai_rag.ai import EmbeddingTaskType, GeminiEmbedding from flare_ai_rag.retriever.base import BaseRetriever from flare_ai_rag.retriever.config import RetrieverConfig @@ -31,7 +31,9 @@ def semantic_search(self, query: str, top_k: int = 5) -> list[dict]: """ # Convert the query into a vector embedding using Gemini query_vector = self.embedding_client.embed_content( - embedding_model="text-embedding-004", contents=query + embedding_model="models/text-embedding-004", + contents=query, + task_type=EmbeddingTaskType.RETRIEVAL_QUERY, ) # Search Qdrant for similar vectors.