Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions griptape/drivers/embedding/google_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,28 @@ class GoogleEmbeddingDriver(BaseEmbeddingDriver):
title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
genai = import_optional_dependency("google.generativeai")
genai.configure(api_key=self.api_key)
genai = import_optional_dependency("google.genai")
types = import_optional_dependency("google.genai.types")

result = genai.embed_content(model=self.model, content=chunk, task_type=self.task_type, title=self.title)
client = genai.Client(api_key=self.api_key)

return result["embedding"]
# Build config with task_type and title if provided
config_params = {}
if self.task_type:
config_params["task_type"] = self.task_type
if self.title:
config_params["title"] = self.title

config = types.EmbedContentConfig(**config_params) if config_params else None

result = client.models.embed_content(
model=self.model,
contents=chunk,
config=config,
)

# The new SDK returns embeddings in result.embeddings[0].values
return result.embeddings[0].values

def _params(self, chunk: str) -> dict:
return {"input": chunk, "model": self.model}
Loading
Loading