@@ -18,6 +18,20 @@
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.storage.chat_store.base import BaseChatStore

import json


def safe_model_dump_json(model):
"""
Safely dumps a Pydantic model to JSON, even if it contains
non-serializable objects like Google FunctionCall.
"""
try:
return model.model_dump_json(exclude_none=True)
except Exception:
data = model.model_dump(exclude_none=True)
return json.dumps(data, default=lambda o: getattr(o, "__dict__", str(o)))
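# Illustration only, not part of this change: when pydantic's own JSON dump
# raises (e.g. additional_kwargs carries a google FunctionCall), the fallback
# above re-serializes via json.dumps with a permissive default, so an object
# with a __dict__ is emitted as a plain dict and anything else as str(obj):
#   json.dumps({"x": object()}, default=lambda o: getattr(o, "__dict__", str(o)))
#   # -> '{"x": "<object object at 0x...>"}'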


def get_data_model(
base: type,
@@ -192,7 +206,7 @@ def set_messages(self, key: str, messages: list[ChatMessage]) -> None:

params = {
"key": key,
"value": [message.model_dump_json() for message in messages],
Collaborator:
I think the better fix here is fixing the pydantic model for the chat message.

It already has handling for special objects in additional_kwargs, but is probably just missing a case for whatever type you are encountering with Google.

Collaborator:
@field_serializer("additional_kwargs", check_fields=False)
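A rough sketch of what that suggestion could look like, on a stand-in model rather than the real ChatMessage in llama-index-core (SketchChatMessage, FakeFunctionCall, and _serialize_additional_kwargs below are illustrative names, not existing APIs): the serializer coerces non-JSON-serializable payloads in additional_kwargs before model_dump_json runs, which would make a wrapper like safe_model_dump_json unnecessary.

from typing import Any, Dict

from pydantic import BaseModel, Field, field_serializer


class FakeFunctionCall:
    """Stand-in for a third-party object such as google's FunctionCall."""

    def __init__(self, name: str, args: Dict[str, Any]) -> None:
        self.name = name
        self.args = args


class SketchChatMessage(BaseModel):
    role: str = "user"
    content: str = ""
    additional_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @field_serializer("additional_kwargs", check_fields=False)
    def _serialize_additional_kwargs(self, value: Dict[str, Any]) -> Dict[str, Any]:
        # Recursively replace anything json can't encode with its __dict__
        # (or str(obj)), so model_dump_json() never raises on third-party payloads.
        def coerce(obj: Any) -> Any:
            if isinstance(obj, dict):
                return {k: coerce(v) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [coerce(v) for v in obj]
            if isinstance(obj, (str, int, float, bool)) or obj is None:
                return obj
            return coerce(getattr(obj, "__dict__", str(obj)))

        return coerce(value)


msg = SketchChatMessage(additional_kwargs={"tool_call": FakeFunctionCall("f", {"x": 1})})
print(msg.model_dump_json())  # succeeds; the FunctionCall-like payload is emitted as a dict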

"value": [safe_model_dump_json(message) for message in messages],
}

# Execute the bulk upsert
@@ -214,7 +228,7 @@ async def aset_messages(self, key: str, messages: list[ChatMessage]) -> None:

params = {
"key": key,
"value": [message.model_dump_json() for message in messages],
"value": [safe_model_dump_json(message) for message in messages],
}

# Execute the bulk upsert
@@ -257,7 +271,7 @@ def add_message(self, key: str, message: ChatMessage) -> None:
value = array_cat({self._table_class.__tablename__}.value, :value);
"""
)
params = {"key": key, "value": [message.model_dump_json()]}
params = {"key": key, "value": [safe_model_dump_json(message)]}
session.execute(stmt, params)
session.commit()

@@ -273,7 +287,7 @@ async def async_add_message(self, key: str, message: ChatMessage) -> None:
value = array_cat({self._table_class.__tablename__}.value, :value);
"""
)
params = {"key": key, "value": [message.model_dump_json()]}
params = {"key": key, "value": [safe_model_dump_json(message)]}
await session.execute(stmt, params)
await session.commit()

@@ -28,7 +28,7 @@ dev = [

[project]
name = "llama-index-storage-chat-store-postgres"
version = "0.3.1"
version = "0.3.2"
description = "llama-index storage-chat-store postgres integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"