Skip to content

Commit dfa85c8

Browse files
feat: Add Google Gemini support
- Introduced `GeminiEmbedder` and `GeminiLLM` classes for embedding and LLM functionalities using Google's Gemini API. - Updated `pyproject.toml` to include the `google-genai` dependency. - Added example scripts for using `GeminiEmbedder` and `GeminiLLM`. - Enhanced documentation to include the new classes and their usage. - Updated tests for the new functionalities, ensuring proper error handling and response validation.
1 parent 8bc6a62 commit dfa85c8

File tree

11 files changed

+771
-5
lines changed

11 files changed

+771
-5
lines changed

docs/source/api.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,12 @@ SentenceTransformerEmbeddings
278278
.. autoclass:: neo4j_graphrag.embeddings.sentence_transformers.SentenceTransformerEmbeddings
279279
:members:
280280

281+
GeminiEmbedder
282+
==============
283+
284+
.. autoclass:: neo4j_graphrag.embeddings.google_genai.GeminiEmbedder
285+
:members:
286+
281287
OpenAIEmbeddings
282288
================
283289

@@ -336,6 +342,13 @@ OpenAILLM
336342
:undoc-members: get_messages, client_class, async_client_class
337343

338344

345+
GeminiLLM
346+
---------
347+
348+
.. autoclass:: neo4j_graphrag.llm.google_genai_llm.GeminiLLM
349+
:members:
350+
351+
339352
AzureOpenAILLM
340353
--------------
341354

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Example: generate an embedding with Google's Gemini API via GeminiEmbedder."""
from neo4j_graphrag.embeddings import GeminiEmbedder

# Set the API key here, or in the GOOGLE_API_KEY env var.
api_key = None

embedder = GeminiEmbedder(
    model="gemini-embedding-001",
    api_key=api_key,
)
res = embedder.embed_query("my question")
# Print only a prefix of the vector; the full embedding is long.
print(res[:10])
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Example: invoke Google's Gemini API through the GeminiLLM wrapper."""
from neo4j_graphrag.llm import GeminiLLM

# Set the API key here, or in the GOOGLE_API_KEY env var.
api_key = None

llm = GeminiLLM(
    model_name="gemini-2.5-flash",
    api_key=api_key,
)
res = llm.invoke("say something")
print(res.content)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Repository = "https://github.com/neo4j/neo4j-graphrag-python"
4646
weaviate = ["weaviate-client>=4.6.1,<5.0.0"]
4747
pinecone = ["pinecone-client>=4.1.0,<5.0.0"]
4848
google = ["google-cloud-aiplatform>=1.66.0,<2.0.0"]
49+
google-genai = ["google-genai>=1.62.0,<2.0.0"]
4950
cohere = ["cohere>=5.9.0,<6.0.0"]
5051
anthropic = ["anthropic>=0.49.0,<0.50.0"]
5152
ollama = ["ollama>=0.4.4,<0.5.0"]

src/neo4j_graphrag/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from .base import Embedder
1616
from .cohere import CohereEmbeddings
17+
from .google_genai import GeminiEmbedder
1718
from .mistral import MistralAIEmbeddings
1819
from .ollama import OllamaEmbeddings
1920
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
@@ -29,4 +30,5 @@
2930
"VertexAIEmbeddings",
3031
"MistralAIEmbeddings",
3132
"CohereEmbeddings",
33+
"GeminiEmbedder",
3234
]
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
# built-in dependencies
18+
from typing import Any, Optional
19+
20+
# project dependencies
21+
from neo4j_graphrag.embeddings.base import Embedder
22+
from neo4j_graphrag.exceptions import EmbeddingsGenerationError
23+
from neo4j_graphrag.utils.rate_limit import (
24+
RateLimitHandler,
25+
async_rate_limit_handler,
26+
rate_limit_handler,
27+
)
28+
29+
try:
30+
from google import genai
31+
from google.genai import types
32+
except ImportError:
33+
genai = None
34+
types = None
35+
36+
# Defaults for the Gemini embedding endpoint.
DEFAULT_EMBEDDING_MODEL = "text-embedding-004"
DEFAULT_EMBEDDING_DIM = 768


class GeminiEmbedder(Embedder):
    """Embedder that uses Google's Gemini API via the google.genai SDK.

    Args:
        model: Embedding model name. Defaults to "text-embedding-004".
        embedding_dim: Output dimensionality. Defaults to 768.
        rate_limit_handler: Optional rate limit handler.
        **kwargs: Arguments passed to the genai.Client.

    Raises:
        ImportError: If the google-genai package is not installed.
    """

    def __init__(
        self,
        model: str = DEFAULT_EMBEDDING_MODEL,
        embedding_dim: int = DEFAULT_EMBEDDING_DIM,
        rate_limit_handler: Optional[RateLimitHandler] = None,
        **kwargs: Any,
    ) -> None:
        # google-genai is an optional dependency; fail with install guidance.
        if genai is None or types is None:
            raise ImportError(
                "Could not import google-genai python client. "
                'Please install it with `pip install "neo4j-graphrag[google-genai]"`.'
            )
        super().__init__(rate_limit_handler)
        self.model = model
        self.embedding_dim = embedding_dim
        self.client = genai.Client(**kwargs)

    def _embed_config(self) -> Any:
        # Request config shared by the sync and async paths: pin the
        # output dimensionality chosen at construction time.
        return types.EmbedContentConfig(output_dimensionality=self.embedding_dim)

    @staticmethod
    def _extract_embedding(result: Any) -> list[float]:
        # We always send a single text, so exactly one embedding is
        # expected back; fail loudly if it is missing or empty.
        if not result.embeddings or not result.embeddings[0].values:
            raise ValueError("No embeddings returned from Gemini API")
        return list(result.embeddings[0].values)

    @rate_limit_handler
    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Embed a single query string.

        Args:
            text: The text to embed.
            **kwargs: Extra arguments forwarded to embed_content.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            EmbeddingsGenerationError: If the API call fails or returns
                no embedding.
        """
        try:
            result = self.client.models.embed_content(
                model=self.model,
                contents=[text],
                config=self._embed_config(),
                **kwargs,
            )
            return self._extract_embedding(result)
        except Exception as e:
            raise EmbeddingsGenerationError(
                f"Failed to generate embedding with Gemini: {e}"
            ) from e

    @async_rate_limit_handler
    async def async_embed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Asynchronously embed a single query string.

        Args:
            text: The text to embed.
            **kwargs: Extra arguments forwarded to embed_content.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            EmbeddingsGenerationError: If the API call fails or returns
                no embedding.
        """
        try:
            result = await self.client.aio.models.embed_content(
                model=self.model,
                contents=[text],
                config=self._embed_config(),
                **kwargs,
            )
            return self._extract_embedding(result)
        except Exception as e:
            raise EmbeddingsGenerationError(
                f"Failed to generate embedding with Gemini: {e}"
            ) from e

src/neo4j_graphrag/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .anthropic_llm import AnthropicLLM
1919
from .base import LLMInterface, LLMInterfaceV2
2020
from .cohere_llm import CohereLLM
21+
from .google_genai_llm import GeminiLLM
2122
from .mistralai_llm import MistralAILLM
2223
from .ollama_llm import OllamaLLM
2324
from .openai_llm import AzureOpenAILLM, OpenAILLM
@@ -28,6 +29,7 @@
2829
__all__ = [
2930
"AnthropicLLM",
3031
"CohereLLM",
32+
"GeminiLLM",
3133
"LLMResponse",
3234
"LLMInterface",
3335
"LLMInterfaceV2",

0 commit comments

Comments
 (0)