Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions llama-index-core/llama_index/core/base/llms/test_llm_types.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to commit this file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it was intended for local testing only. I have deleted the file now.

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
import base64
import pytest

# Use a relative import to refer to the local types.py file
from . import types

# --- Mock Object ---
class MockGoogleFunctionCall:
    """Minimal stand-in for Google's ``FunctionCall`` type.

    Exposes only the surface the serializer under test relies on:
    a ``name``, an ``args`` mapping, and a ``to_dict()`` conversion.
    """

    def __init__(self, name: str, args: dict):
        # Mirror the two public attributes of the real FunctionCall.
        self.name = name
        self.args = args

    def to_dict(self) -> dict:
        """Return the plain-dict form that the serializer fix relies on."""
        return dict(name=self.name, args=self.args)

    def __repr__(self) -> str:
        # Keep the repr distinctive so failures are easy to read in pytest output.
        return "<MockGoogleFunctionCall object name='%s'>" % self.name


# --- Test Cases ---

def test_chat_message_with_google_function_call_serialization(monkeypatch):
    """
    Verify that a ChatMessage whose additional_kwargs holds a (mock) Google
    FunctionCall object can be serialized to JSON without raising.
    """
    # Arrange: pretend the Google types are installed and point the module's
    # FunctionCall symbol at our serializable mock.
    monkeypatch.setattr(types, "GOOGLE_FUNCTION_CALL_AVAILABLE", True)
    monkeypatch.setattr(types, "FunctionCall", MockGoogleFunctionCall)

    call = MockGoogleFunctionCall(
        name="get_current_weather",
        args={"location": "Boston, MA"},
    )
    msg = types.ChatMessage(
        role=types.MessageRole.ASSISTANT,
        additional_kwargs={"tool_calls": [call]},
    )

    # Act: round-trip through JSON.
    payload = json.loads(msg.model_dump_json())

    # Assert: the function call survived serialization as a plain dict.
    tool_calls = payload["additional_kwargs"]["tool_calls"]
    assert tool_calls[0]["name"] == "get_current_weather"

def test_chat_message_str_method():
    """The string form of a ChatMessage is '<role>: <content>'."""
    msg = types.ChatMessage(role="user", content="Hello, world!")
    assert f"{msg}" == "user: Hello, world!"

def test_chat_message_from_str():
    """ChatMessage.from_str builds messages with default and explicit roles."""
    # With no role supplied, from_str defaults to USER.
    default_msg = types.ChatMessage.from_str("This is a test.")
    assert default_msg.role == types.MessageRole.USER
    assert default_msg.content == "This is a test."

    # A plain string role is coerced to the matching MessageRole member.
    assistant_msg = types.ChatMessage.from_str("I can help.", role="assistant")
    assert assistant_msg.role == types.MessageRole.ASSISTANT
    assert assistant_msg.content == "I can help."

def test_serialization_of_bytes_in_kwargs():
    """Raw bytes in additional_kwargs are emitted as base64 text."""
    payload = b"some binary data"
    msg = types.ChatMessage(
        role="user",
        additional_kwargs={"data": payload},
    )

    # Act: serialize and parse back.
    round_tripped = json.loads(msg.model_dump_json())

    # Assert: the bytes came out as their base64-encoded string form.
    expected = base64.b64encode(payload).decode("utf-8")
    assert round_tripped["additional_kwargs"]["data"] == expected
27 changes: 27 additions & 0 deletions llama-index-core/llama_index/core/base/llms/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@
from llama_index.core.schema import ImageDocument
from llama_index.core.utils import resolve_binary

# Optional dependency: Google's FunctionCall type. When the google package is
# not installed, a stub class is defined instead so that isinstance checks
# against FunctionCall are always valid (and simply never match).
# NOTE(review): importing Google types inside llama-index-core was questioned
# in review — confirm whether this belongs here or in the Google integration.
GOOGLE_FUNCTION_CALL_AVAILABLE = False
try:
    from google.ai.generativelanguage_v1beta.types.content import FunctionCall
except ImportError:
    # Placeholder type: keeps the name defined without the google dependency.
    class FunctionCall:
        pass
else:
    GOOGLE_FUNCTION_CALL_AVAILABLE = True


class MessageRole(str, Enum):
"""Message role."""
Expand Down Expand Up @@ -566,6 +576,23 @@ def _recursive_serialization(self, value: Any) -> Any:

@field_serializer("additional_kwargs", check_fields=False)
def serialize_additional_kwargs(self, value: Any, _info: Any) -> Any:
    """Serialize ``additional_kwargs``, converting any Google ``FunctionCall``
    objects found under the ``tool_calls`` key into plain dicts first.

    Args:
        value: The ``additional_kwargs`` mapping being serialized.
        _info: Pydantic serialization context (unused).

    Returns:
        A JSON-serializable structure produced by ``_recursive_serialization``.
    """
    if GOOGLE_FUNCTION_CALL_AVAILABLE and isinstance(value, dict):
        tool_calls = value.get("tool_calls")
        if isinstance(tool_calls, list):
            # Convert FunctionCall objects to dicts via their to_dict();
            # everything else passes through unchanged.
            sanitized = [
                tc.to_dict() if isinstance(tc, FunctionCall) else tc
                for tc in tool_calls
            ]
            # Fix: work on a shallow copy instead of writing the sanitized
            # list back into `value`. The original mutated the message's
            # additional_kwargs in place, so merely serializing a message
            # permanently replaced its FunctionCall objects with dicts.
            value = {**value, "tool_calls": sanitized}

    # Safely serialize the (possibly copied) dictionary as before.
    return self._recursive_serialization(value)


Expand Down
Loading