fix(ai): Fix gemini tool usage #1892

Merged · 7 commits · Feb 7, 2025
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -47,3 +47,8 @@
"urllib.*",
]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["google.*", "google.genai.*"]
ignore_missing_imports = true
disable_error_code = ["attr-defined", "import-untyped"]
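Note: this override silences mypy's missing-stub and attr-defined errors for the google-genai 1.x modules, which is what lets the inline ignores come off the imports in client.py below. A minimal sketch of what now type-checks cleanly (imports as used in src/seer/automation/agent/client.py):

# With the [[tool.mypy.overrides]] entry above, these imports no longer need
# per-line "# type: ignore[...]" suppressions.
from google import genai
from google.genai.types import Content, FunctionDeclaration, Part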
2 changes: 1 addition & 1 deletion requirements-constraints.txt
@@ -116,5 +116,5 @@ rapidfuzz==3.10.*
pytest-vcr==1.*
vcrpy==6.*
chardet==5.2.*
google-genai==0.*
google-genai==1.*
httpx==0.27.2 # v0.28 is breaking OpenAI and Anthropic as of Dec 16 2024
6 changes: 3 additions & 3 deletions requirements.txt
@@ -4,7 +4,7 @@
#
# pip-compile --output-file=requirements.txt --strip-extras requirements-constraints.txt
#
aiohappyeyeballs==2.4.4
aiohappyeyeballs==2.4.6
    # via aiohttp
aiohttp==3.11.12
    # via
@@ -220,7 +220,7 @@ google-crc32c==1.6.0
    # via
    #   google-cloud-storage
    #   google-resumable-media
google-genai==0.8.0
google-genai==1.0.0
    # via -r requirements-constraints.txt
google-resumable-media==2.7.2
    # via
@@ -678,7 +678,7 @@ sentencepiece==0.2.0
    # via
    #   sentence-transformers
    #   transformers
sentry-protos==0.1.59
sentry-protos==0.1.60
    # via -r requirements-constraints.txt
sentry-sdk==2.18.0
    # via -r requirements-constraints.txt
7 changes: 6 additions & 1 deletion src/seer/automation/agent/agent.py
@@ -147,7 +147,12 @@ def call_tool(self, tool_call: ToolCall) -> Message:
        kwargs = self.parse_tool_arguments(tool, tool_call.args)
        tool_result = tool.call(**kwargs)

        return Message(role="tool", content=tool_result, tool_call_id=tool_call.id)
        return Message(
            role="tool",
            content=tool_result,
            tool_call_id=tool_call.id,
            tool_call_function=tool_call.function,
        )

    def get_tool_by_name(self, name: str) -> FunctionTool:
        try:
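Note: carrying tool_call_function on the tool-result Message matters because Gemini matches a function response to its originating call by function name rather than by call ID. A minimal sketch of the downstream use (the helper name is hypothetical; Part.from_function_response is the google-genai call used in client.py below, and the import path is inferred from this PR's file layout):

from google.genai.types import Part

from seer.automation.agent.models import Message  # path inferred from the diff

def to_function_response_part(message: Message) -> Part:  # hypothetical helper, not in this PR
    # Gemini keys a tool result to its call by function name, which is why
    # call_tool() now records tool_call_function on the returned Message.
    return Part.from_function_response(
        name=message.tool_call_function or "",
        response={"response": message.content},
    )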
106 changes: 65 additions & 41 deletions src/seer/automation/agent/client.py
@@ -15,17 +15,15 @@
)
from google import genai # type: ignore[attr-defined]
from google.api_core.exceptions import ResourceExhausted
from google.genai.types import ( # type: ignore[import-untyped]
from google.genai.types import (
    Content,
    FunctionCall,
    FunctionDeclaration,
    FunctionResponse,
    GenerateContentConfig,
    GenerateContentResponse,
    GoogleSearch,
    Part,
    Tool,
)
from google.genai.types import Tool as GeminiTool
from langfuse.decorators import langfuse_context, observe
from langfuse.openai import openai
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
@@ -722,7 +720,7 @@ def _get_config(cls, model_name: str):
    @observe(as_type="generation", name="Gemini Generation with Grounding")
    def search_the_web(self, prompt: str, temperature: float | None = None) -> str:
        client = self.get_client()
        google_search_tool = Tool(google_search=GoogleSearch())
        google_search_tool = GeminiTool(google_search=GoogleSearch())

        response = client.models.generate_content(
            model=self.model_name,
@@ -734,8 +732,14 @@ def search_the_web(self, prompt: str, temperature: float | None = None) -> str:
            ),
        )
        answer = ""
        for each in response.candidates[0].content.parts:
            answer += each.text
        if (
            response.candidates
            and response.candidates[0].content
            and response.candidates[0].content.parts
        ):
            for each in response.candidates[0].content.parts:
                if each.text:
                    answer += each.text
        return answer

    @staticmethod
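Note: with the guards above, a response that carries no candidates, no content, or no text parts now yields an empty string instead of raising. A hedged usage sketch (the .model() factory is taken from the tests below; whether it returns an instance exposing search_the_web is an assumption):

from seer.automation.agent.client import GeminiProvider  # module path inferred

provider = GeminiProvider.model("gemini-2.0-flash-exp")
# Returns "" rather than raising an IndexError/AttributeError on an
# empty candidates/parts payload.
answer = provider.search_the_web("What changed in google-genai 1.x?")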
@@ -923,54 +927,74 @@ def _prep_message_and_tools(
        prompt: str | None = None,
        system_prompt: str | None = None,
        tools: list[FunctionTool] | None = None,
    ) -> tuple[list[Content], list[Tool] | None, str | None]:
        contents = [cls.to_content(message) for message in messages] if messages else []
    ) -> tuple[list[Content], list[GeminiTool] | None, str | None]:
        contents: list[Content] = []

        if messages:
            # Group consecutive tool messages together
            grouped_messages: list[list[Message]] = []
            current_group: list[Message] = []

            for message in messages:
                if message.role == "tool":
                    current_group.append(message)
                else:
                    if current_group:
                        grouped_messages.append(current_group)
                        current_group = []
                    grouped_messages.append([message])

            if current_group:
                grouped_messages.append(current_group)

            # Convert each group into a Content object
            for group in grouped_messages:
                if len(group) == 1 and group[0].role != "tool":
                    contents.append(cls.to_content(group[0]))
                elif group[0].role == "tool":
                    # Combine multiple tool messages into a single Content
                    parts = [
                        Part.from_function_response(
                            name=msg.tool_call_function or "",
                            response={"response": msg.content},
                        )
                        for msg in group
                    ]
                    contents.append(Content(role="user", parts=parts))

        if prompt:
            contents.append(cls.to_content(Message(role="user", content=prompt)))
            contents.append(
                Content(
                    role="user",
                    parts=[Part(text=prompt)],
                )
            )

        tools = [cls.to_tool(tool) for tool in tools] if tools else []
        processed_tools = [cls.to_tool(tool) for tool in tools] if tools else []

        return contents, tools, system_prompt
        return contents, processed_tools, system_prompt
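Note: the grouping pass above appears to exist because Gemini expects parallel tool results to come back as multiple function_response parts inside a single user turn. A sketch of the resulting shape (Message/ToolCall fields and the classmethod signature are taken from this diff; import paths are inferred):

from seer.automation.agent.client import GeminiProvider  # paths inferred
from seer.automation.agent.models import Message

messages = [
    Message(role="assistant", content="Running both tools."),
    Message(role="tool", content="result A", tool_call_function="fn_a"),
    Message(role="tool", content="result B", tool_call_function="fn_b"),
]
contents, tools, system_prompt = GeminiProvider._prep_message_and_tools(messages=messages)
# One model turn plus a single user turn carrying two function_response parts.
assert len(contents) == 2
assert contents[1].parts is not None and len(contents[1].parts) == 2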

    @staticmethod
    def to_content(message: Message) -> Content:
        if message.role == "tool":
            return Content(
                role="user",
                parts=[
                    Part(
                        function_response=FunctionResponse(
                            name=message.tool_calls[0].function if message.tool_calls else "",
                            response=(
                                json.loads(message.tool_calls[0].args) if message.tool_calls else {}
                            ),
                        ),
                    )
                ],
            )
        elif message.role == "tool_use":
        if message.role == "tool_use":
            if not message.tool_calls:
                return Content(
                    role="model",
                    parts=[Part(text=message.content or "")],
                )
            tool_call = message.tool_calls[0]  # Assuming only one tool call per message

            parts = []
            if message.content:
                parts.append(Part(text=message.content))
            if tool_call:
            for tool_call in message.tool_calls:
                parts.append(
                    Part(
                        function_call=FunctionCall(
                            name=tool_call.function,
                            args=json.loads(tool_call.args),
                        ),
                    Part.from_function_call(
                        name=tool_call.function,
                        args=json.loads(tool_call.args),
                    )
                )
            return Content(
                role="model",
                parts=parts,
            )
            return Content(role="model", parts=parts)

        elif message.role == "assistant":
            return Content(
                role="model",
@@ -983,8 +1007,8 @@ def to_content(message: Message) -> Content:
            )
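Note: with the loop above, a tool_use message carrying several calls is no longer truncated to its first call. A small sketch of the expected conversion (Message/ToolCall shapes from this diff; import paths inferred):

from seer.automation.agent.client import GeminiProvider  # paths inferred
from seer.automation.agent.models import Message, ToolCall

msg = Message(
    role="tool_use",
    content="Checking both values.",
    tool_calls=[
        ToolCall(id="1", function="fn_a", args='{"x": "a"}'),
        ToolCall(id="2", function="fn_b", args='{"x": "b"}'),
    ],
)
content = GeminiProvider.to_content(msg)
# One text part followed by one function_call part per tool call.
assert content.role == "model"
assert content.parts is not None and len(content.parts) == 3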

    @staticmethod
    def to_tool(tool: FunctionTool) -> Tool:
        return Tool(
    def to_tool(tool: FunctionTool) -> GeminiTool:
        return GeminiTool(
            function_declarations=[
                FunctionDeclaration(
                    name=tool.name,
@@ -1054,7 +1078,7 @@ def _format_gemini_response_to_message(self, response: GenerateContentResponse)
                message.tool_calls.append(
                    ToolCall(
                        id=part.function_call.id,
                        function=part.function_call.name,
                        function=part.function_call.name or "",
                        args=json.dumps(part.function_call.args),
                    )
                )
4 changes: 4 additions & 0 deletions src/seer/automation/agent/models.py
@@ -48,6 +48,10 @@ class Message(BaseModel):
"""The tool calls generated by the model, such as function calls."""

tool_call_id: Optional[str] = None
"""The ID of the tool call."""

tool_call_function: Optional[str] = None
"""The function of the tool call."""


class LlmResponseMetadata(BaseModel):
39 changes: 27 additions & 12 deletions tests/automation/agent/test_client.py
@@ -510,6 +510,7 @@ def test_gemini_generate_text():
    )

    assert isinstance(response, LlmGenerateTextResponse)
    assert response.message.content is not None
    assert response.message.content.strip() == "Hello! How can I help you today?"
    assert response.message.role == "assistant"
    assert response.metadata.model == "gemini-2.0-flash-exp"
@@ -543,8 +544,10 @@ def test_gemini_generate_text_with_tools():
    )

    assert isinstance(response, LlmGenerateTextResponse)
    assert response.message.content is not None
    assert len(response.message.content) > 0
    assert response.message.role == "tool_use"
    assert response.message.tool_calls is not None
    assert response.message.tool_calls == [
        ToolCall(
            function="test_function",
@@ -599,18 +602,34 @@ def test_gemini_prep_message_and_tools():

    assert len(message_dicts) == 3
    assert message_dicts[0].role == "user"
    assert message_dicts[0].parts is not None
    assert len(message_dicts[0].parts) > 0
    assert message_dicts[0].parts[0].text == "Hello"
    assert message_dicts[1].role == "model"
    assert message_dicts[1].parts is not None
    assert len(message_dicts[1].parts) > 0
    assert message_dicts[1].parts[0].text == "Hi there!"
    assert message_dicts[2].role == "user"
    assert message_dicts[2].parts is not None
    assert len(message_dicts[2].parts) > 0
    assert message_dicts[2].parts[0].text == prompt

    assert tool_dicts
    assert tool_dicts is not None
    assert len(tool_dicts) == 1
    if tool_dicts:
        assert len(tool_dicts) == 1
    assert tool_dicts[0].function_declarations[0].name == "test_function"
    assert tool_dicts[0].function_declarations[0].description == "A test function"
    assert tool_dicts[0].function_declarations[0].parameters.properties["x"].type == "STRING"
        assert tool_dicts[0].function_declarations
        if tool_dicts[0].function_declarations:
            assert len(tool_dicts[0].function_declarations) > 0
            assert tool_dicts[0].function_declarations[0].name == "test_function"
            assert tool_dicts[0].function_declarations[0].description == "A test function"
            assert tool_dicts[0].function_declarations[0].parameters
            if tool_dicts[0].function_declarations[0].parameters:
                assert tool_dicts[0].function_declarations[0].parameters.type == "OBJECT"
                assert "x" in tool_dicts[0].function_declarations[0].parameters.properties
                assert (
                    tool_dicts[0].function_declarations[0].parameters.properties["x"].type
                    == "STRING"
                )


@pytest.mark.vcr()
Expand Down Expand Up @@ -680,20 +699,16 @@ def test_gemini_generate_text_stream_with_tools():


def test_construct_message_from_stream_gemini():
    llm_client = LlmClient()
    model = GeminiProvider.model("gemini-2.0-flash-exp")

    content_chunks = ["Hello", " world", "!"]
    tool_calls = [ToolCall(id="123", function="test_function", args='{"x": "test"}')]

    message = llm_client.construct_message_from_stream(
        content_chunks=content_chunks,
        tool_calls=tool_calls,
        model=model,
    )
    message = model.construct_message_from_stream(content_chunks, tool_calls)

    assert message.role == "tool_use"
    assert message.role == ("tool_use" if tool_calls else "assistant")
    assert message.content == "Hello world!"
    assert message.tool_calls is not None
    assert len(message.tool_calls) == 1
    assert message.tool_calls[0].id == "123"
    assert message.tool_calls[0].function == "test_function"
1 change: 1 addition & 0 deletions tests/automation/autofix/test_autofix_context.py
@@ -185,6 +185,7 @@ def test_store_memory(self):
"content": "Test message",
"tool_call_id": None,
"tool_calls": None,
"tool_call_function": None,
}
]
},