-
Notifications
You must be signed in to change notification settings - Fork 6.4k
fixV2(core): Resolve PydanticSerializationError for Google FunctionCall #20059
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
22ecffa
a000254
959d6c9
81a59ae
6404b32
70c4434
123d216
593e3fd
263db4b
17d9e07
dd841e3
fc1d9a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import json | ||
import base64 | ||
import pytest | ||
|
||
# Use a relative import to refer to the local types.py file | ||
from . import types | ||
|
||
# --- Mock Object --- | ||
class MockGoogleFunctionCall: | ||
"""A mock object that mimics the structure of Google's FunctionCall.""" | ||
|
||
def __init__(self, name: str, args: dict): | ||
self.name = name | ||
self.args = args | ||
|
||
def to_dict(self) -> dict: | ||
"""The method our serializer fix relies on.""" | ||
return {"name": self.name, "args": self.args} | ||
|
||
def __repr__(self) -> str: | ||
return f"<MockGoogleFunctionCall object name='{self.name}'>" | ||
|
||
|
||
# --- Test Cases --- | ||
|
||
def test_chat_message_with_google_function_call_serialization(monkeypatch): | ||
""" | ||
Tests if a ChatMessage containing a mock Google FunctionCall object | ||
can be successfully serialized to JSON. | ||
""" | ||
# 1. Arrange | ||
monkeypatch.setattr(types, "GOOGLE_FUNCTION_CALL_AVAILABLE", True) | ||
monkeypatch.setattr(types, "FunctionCall", MockGoogleFunctionCall) | ||
|
||
function_call_object = MockGoogleFunctionCall( | ||
name="get_current_weather", | ||
args={"location": "Boston, MA"}, | ||
) | ||
message = types.ChatMessage( | ||
role=types.MessageRole.ASSISTANT, | ||
additional_kwargs={"tool_calls": [function_call_object]}, | ||
) | ||
|
||
# 2. Act | ||
serialized_json = message.model_dump_json() | ||
|
||
# 3. Assert | ||
deserialized_data = json.loads(serialized_json) | ||
tool_calls = deserialized_data["additional_kwargs"]["tool_calls"] | ||
assert tool_calls[0]["name"] == "get_current_weather" | ||
|
||
def test_chat_message_str_method(): | ||
"""Tests the string representation of a ChatMessage.""" | ||
message = types.ChatMessage(role="user", content="Hello, world!") | ||
assert str(message) == "user: Hello, world!" | ||
|
||
def test_chat_message_from_str(): | ||
"""Tests creating a ChatMessage using the from_str factory method.""" | ||
# Test with default role | ||
message_user = types.ChatMessage.from_str("This is a test.") | ||
assert message_user.role == types.MessageRole.USER | ||
assert message_user.content == "This is a test." | ||
|
||
# Test with a specified string role | ||
message_asst = types.ChatMessage.from_str("I can help.", role="assistant") | ||
assert message_asst.role == types.MessageRole.ASSISTANT | ||
assert message_asst.content == "I can help." | ||
|
||
def test_serialization_of_bytes_in_kwargs(): | ||
"""Tests if raw bytes in additional_kwargs are correctly base64 encoded.""" | ||
raw_bytes = b"some binary data" | ||
message = types.ChatMessage( | ||
role="user", | ||
additional_kwargs={"data": raw_bytes} | ||
) | ||
|
||
# Act | ||
serialized_json = message.model_dump_json() | ||
deserialized_data = json.loads(serialized_json) | ||
|
||
# Assert | ||
expected_b64_string = base64.b64encode(raw_bytes).decode("utf-8") | ||
assert deserialized_data["additional_kwargs"]["data"] == expected_b64_string |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,16 @@ | |
from llama_index.core.schema import ImageDocument | ||
from llama_index.core.utils import resolve_binary | ||
|
||
GOOGLE_FUNCTION_CALL_AVAILABLE = False | ||
|
||
try: | ||
from google.ai.generativelanguage_v1beta.types.content import FunctionCall | ||
|
||
GOOGLE_FUNCTION_CALL_AVAILABLE = True | ||
except ImportError: | ||
|
||
class FunctionCall: | ||
pass | ||
|
||
|
||
class MessageRole(str, Enum): | ||
"""Message role.""" | ||
|
@@ -566,6 +576,23 @@ def _recursive_serialization(self, value: Any) -> Any: | |
|
||
@field_serializer("additional_kwargs", check_fields=False) | ||
def serialize_additional_kwargs(self, value: Any, _info: Any) -> Any: | ||
if GOOGLE_FUNCTION_CALL_AVAILABLE and isinstance(value, dict): | ||
original_tool_calls = value.get("tool_calls") | ||
if isinstance(original_tool_calls, list): | ||
# Create a new list that is guaranteed to be serializable | ||
serializable_tool_calls = [] | ||
for tc in original_tool_calls: | ||
# If we find a FunctionCall object, convert it to a dict | ||
if isinstance(tc, FunctionCall): | ||
|
||
serializable_tool_calls.append(tc.to_dict()) | ||
else: | ||
# Otherwise, append the item as is | ||
serializable_tool_calls.append(tc) | ||
|
||
# Update the dictionary with the sanitized list of tool calls | ||
value["tool_calls"] = serializable_tool_calls | ||
|
||
# Now, safely serialize the entire dictionary | ||
return self._recursive_serialization(value) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean to commit this file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no it was intented for local testing only. I have deleted the file now.