Skip to content

Commit 58cc44f

Browse files
authored
fix: handle missing 'data' key in memory payload during search operations (#3524)
1 parent 445286a commit 58cc44f

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

mem0/memory/main.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -382,7 +382,7 @@ def _add_to_vector_store(self, messages, metadata, filters, infer):
382382
filters=filters,
383383
)
384384
for mem in existing_memories:
385-
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
385+
retrieved_old_memory.append({"id": mem.id, "text": mem.payload.get("data", "")})
386386

387387
unique_data = {}
388388
for item in retrieved_old_memory:
@@ -518,7 +518,7 @@ def get(self, memory_id):
518518

519519
result_item = MemoryItem(
520520
id=memory.id,
521-
memory=memory.payload["data"],
521+
memory=memory.payload.get("data", ""),
522522
hash=memory.payload.get("hash"),
523523
created_at=memory.payload.get("created_at"),
524524
updated_at=memory.payload.get("updated_at"),
@@ -623,7 +623,7 @@ def _get_all_from_vector_store(self, filters, limit):
623623
for mem in actual_memories:
624624
memory_item_dict = MemoryItem(
625625
id=mem.id,
626-
memory=mem.payload["data"],
626+
memory=mem.payload.get("data", ""),
627627
hash=mem.payload.get("hash"),
628628
created_at=mem.payload.get("created_at"),
629629
updated_at=mem.payload.get("updated_at"),
@@ -735,7 +735,7 @@ def _search_vector_store(self, query, filters, limit, threshold: Optional[float]
735735
for mem in memories:
736736
memory_item_dict = MemoryItem(
737737
id=mem.id,
738-
memory=mem.payload["data"],
738+
memory=mem.payload.get("data", ""),
739739
hash=mem.payload.get("hash"),
740740
created_at=mem.payload.get("created_at"),
741741
updated_at=mem.payload.get("updated_at"),
@@ -962,7 +962,7 @@ def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
962962
def _delete_memory(self, memory_id):
963963
logger.info(f"Deleting memory with {memory_id=}")
964964
existing_memory = self.vector_store.get(vector_id=memory_id)
965-
prev_value = existing_memory.payload["data"]
965+
prev_value = existing_memory.payload.get("data", "")
966966
self.vector_store.delete(vector_id=memory_id)
967967
self.db.add_history(
968968
memory_id,
@@ -1234,7 +1234,7 @@ async def process_fact_for_search(new_mem_content):
12341234
limit=5,
12351235
filters=effective_filters, # 'filters' is query_filters_for_inference
12361236
)
1237-
return [{"id": mem.id, "text": mem.payload["data"]} for mem in existing_mems]
1237+
return [{"id": mem.id, "text": mem.payload.get("data", "")} for mem in existing_mems]
12381238

12391239
search_tasks = [process_fact_for_search(fact) for fact in new_retrieved_facts]
12401240
search_results_list = await asyncio.gather(*search_tasks)
@@ -1382,7 +1382,7 @@ async def get(self, memory_id):
13821382

13831383
result_item = MemoryItem(
13841384
id=memory.id,
1385-
memory=memory.payload["data"],
1385+
memory=memory.payload.get("data", ""),
13861386
hash=memory.payload.get("hash"),
13871387
created_at=memory.payload.get("created_at"),
13881388
updated_at=memory.payload.get("updated_at"),
@@ -1492,7 +1492,7 @@ async def _get_all_from_vector_store(self, filters, limit):
14921492
for mem in actual_memories:
14931493
memory_item_dict = MemoryItem(
14941494
id=mem.id,
1495-
memory=mem.payload["data"],
1495+
memory=mem.payload.get("data", ""),
14961496
hash=mem.payload.get("hash"),
14971497
created_at=mem.payload.get("created_at"),
14981498
updated_at=mem.payload.get("updated_at"),
@@ -1609,7 +1609,7 @@ async def _search_vector_store(self, query, filters, limit, threshold: Optional[
16091609
for mem in memories:
16101610
memory_item_dict = MemoryItem(
16111611
id=mem.id,
1612-
memory=mem.payload["data"],
1612+
memory=mem.payload.get("data", ""),
16131613
hash=mem.payload.get("hash"),
16141614
created_at=mem.payload.get("created_at"),
16151615
updated_at=mem.payload.get("updated_at"),
@@ -1860,7 +1860,7 @@ async def _update_memory(self, memory_id, data, existing_embeddings, metadata=No
18601860
async def _delete_memory(self, memory_id):
18611861
logger.info(f"Deleting memory with {memory_id=}")
18621862
existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
1863-
prev_value = existing_memory.payload["data"]
1863+
prev_value = existing_memory.payload.get("data", "")
18641864

18651865
await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
18661866
await asyncio.to_thread(

tests/test_memory.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,15 @@
66
from mem0.configs.base import MemoryConfig
77

88

9+
class MockVectorMemory:
    """Mock memory object for testing incomplete payloads.

    Exposes the three attributes the tests set on a vector-store record:
    ``id``, ``payload`` and ``score``. The payload is stored as given, so a
    test can deliberately omit keys such as ``"data"``.

    Args:
        memory_id: Identifier exposed as ``.id``.
        payload: Payload dict exposed as ``.payload``; stored by reference,
            not copied.
        score: Score exposed as ``.score``. Defaults to ``0.8``.
    """

    def __init__(self, memory_id: str, payload: dict, score: float = 0.8):
        self.id: str = memory_id
        self.payload: dict = payload
        self.score: float = score

    def __repr__(self) -> str:
        # A readable repr makes pytest assertion failures show which mock
        # record (and which payload shape) was involved.
        return (
            f"{type(self).__name__}(memory_id={self.id!r}, "
            f"payload={self.payload!r}, score={self.score!r})"
        )
16+
17+
918
@pytest.fixture
1019
def memory_client():
1120
with patch.object(Memory, "__init__", return_value=None):
@@ -95,3 +104,40 @@ def test_collection_name_preserved_after_reset(mock_sqlite, mock_llm_factory, mo
95104
if reset_calls:
96105
reset_config = reset_calls[-1][0][1]
97106
assert reset_config.collection_name == test_collection_name, f"Reset used wrong collection name: {reset_config.collection_name}"
107+
108+
109+
@patch('mem0.utils.factory.EmbedderFactory.create')
@patch('mem0.utils.factory.VectorStoreFactory.create')
@patch('mem0.utils.factory.LlmFactory.create')
@patch('mem0.memory.storage.SQLiteManager')
def test_search_handles_incomplete_payloads(mock_sqlite, mock_llm_factory, mock_vector_factory, mock_embedder_factory):
    """Search must tolerate vector-store records whose payload has no 'data' key.

    One record without ``"data"`` and one with it are returned by the mocked
    vector store; ``_search_vector_store`` is expected to return both, mapping
    the missing value to an empty string instead of raising ``KeyError``.
    """
    # Arrange: every factory hands back a mock so no real backend is touched.
    vector_store = MagicMock()
    mock_embedder_factory.return_value = MagicMock()
    mock_vector_factory.return_value = vector_store
    mock_llm_factory.return_value = MagicMock()
    mock_sqlite.return_value = MagicMock()

    from mem0.memory.main import Memory as MemoryClass

    memory = MemoryClass(MemoryConfig())

    # The first record lacks "data"; the second is a normal, complete record.
    vector_store.search.return_value = [
        MockVectorMemory("mem_1", {"hash": "abc123"}),
        MockVectorMemory("mem_2", {"data": "content", "hash": "def456"}),
    ]

    embedder = MagicMock()
    embedder.embed.return_value = [0.1, 0.2, 0.3]
    memory.embedding_model = embedder

    # Act: must not raise KeyError despite the incomplete payload.
    result = memory._search_vector_store("test", {"user_id": "test"}, 10)

    # Assert: both records come back, keyed by id for order-independent checks.
    assert len(result) == 2
    by_id = {entry["id"]: entry for entry in result}

    assert by_id["mem_1"]["memory"] == ""  # Missing data gets empty string
    assert by_id["mem_2"]["memory"] == "content"  # Normal data preserved

0 commit comments

Comments
 (0)