Skip to content

Commit dc883b0

Browse files
[test] Add comprehensive test suite for SQLiteManager (#3494)
1 parent 5616844 commit dc883b0

File tree

1 file changed

+282
-0
lines changed

1 file changed

+282
-0
lines changed

tests/memory/test_storage.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
import os
2+
import sqlite3
3+
import tempfile
4+
import uuid
5+
from datetime import datetime
6+
7+
import pytest
8+
9+
from mem0.memory.storage import SQLiteManager
10+
11+
12+
class TestSQLiteManager:
13+
"""Comprehensive test cases for SQLiteManager class."""
14+
15+
@pytest.fixture
16+
def temp_db_path(self):
17+
"""Create temporary database file."""
18+
temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
19+
temp_db.close()
20+
yield temp_db.name
21+
if os.path.exists(temp_db.name):
22+
os.unlink(temp_db.name)
23+
24+
@pytest.fixture
25+
def sqlite_manager(self, temp_db_path):
26+
"""Create SQLiteManager instance with temporary database."""
27+
manager = SQLiteManager(temp_db_path)
28+
yield manager
29+
if manager.connection:
30+
manager.close()
31+
32+
@pytest.fixture
33+
def memory_manager(self):
34+
"""Create in-memory SQLiteManager instance."""
35+
manager = SQLiteManager(":memory:")
36+
yield manager
37+
if manager.connection:
38+
manager.close()
39+
40+
@pytest.fixture
41+
def sample_data(self):
42+
"""Sample test data."""
43+
now = datetime.now().isoformat()
44+
return {
45+
"memory_id": str(uuid.uuid4()),
46+
"old_memory": "Old memory content",
47+
"new_memory": "New memory content",
48+
"event": "ADD",
49+
"created_at": now,
50+
"updated_at": now,
51+
"actor_id": "test_actor",
52+
"role": "user",
53+
}
54+
55+
# ========== Initialization Tests ==========
56+
57+
@pytest.mark.parametrize("db_type,path", [("file", "temp_db_path"), ("memory", ":memory:")])
58+
def test_initialization(self, db_type, path, request):
59+
"""Test SQLiteManager initialization with different database types."""
60+
if db_type == "file":
61+
db_path = request.getfixturevalue(path)
62+
else:
63+
db_path = path
64+
65+
manager = SQLiteManager(db_path)
66+
assert manager.connection is not None
67+
assert manager.db_path == db_path
68+
manager.close()
69+
70+
def test_table_schema_creation(self, sqlite_manager):
71+
"""Test that history table is created with correct schema."""
72+
cursor = sqlite_manager.connection.cursor()
73+
cursor.execute("PRAGMA table_info(history)")
74+
columns = {row[1] for row in cursor.fetchall()}
75+
76+
expected_columns = {
77+
"id",
78+
"memory_id",
79+
"old_memory",
80+
"new_memory",
81+
"event",
82+
"created_at",
83+
"updated_at",
84+
"is_deleted",
85+
"actor_id",
86+
"role",
87+
}
88+
assert columns == expected_columns
89+
90+
# ========== Add History Tests ==========
91+
92+
def test_add_history_basic(self, sqlite_manager, sample_data):
93+
"""Test basic add_history functionality."""
94+
sqlite_manager.add_history(
95+
memory_id=sample_data["memory_id"],
96+
old_memory=sample_data["old_memory"],
97+
new_memory=sample_data["new_memory"],
98+
event=sample_data["event"],
99+
created_at=sample_data["created_at"],
100+
actor_id=sample_data["actor_id"],
101+
role=sample_data["role"],
102+
)
103+
104+
cursor = sqlite_manager.connection.cursor()
105+
cursor.execute("SELECT * FROM history WHERE memory_id = ?", (sample_data["memory_id"],))
106+
result = cursor.fetchone()
107+
108+
assert result is not None
109+
assert result[1] == sample_data["memory_id"]
110+
assert result[2] == sample_data["old_memory"]
111+
assert result[3] == sample_data["new_memory"]
112+
assert result[4] == sample_data["event"]
113+
assert result[8] == sample_data["actor_id"]
114+
assert result[9] == sample_data["role"]
115+
116+
@pytest.mark.parametrize(
117+
"old_memory,new_memory,is_deleted", [(None, "New memory", 0), ("Old memory", None, 1), (None, None, 1)]
118+
)
119+
def test_add_history_optional_params(self, sqlite_manager, sample_data, old_memory, new_memory, is_deleted):
120+
"""Test add_history with various optional parameter combinations."""
121+
sqlite_manager.add_history(
122+
memory_id=sample_data["memory_id"],
123+
old_memory=old_memory,
124+
new_memory=new_memory,
125+
event="UPDATE",
126+
updated_at=sample_data["updated_at"],
127+
is_deleted=is_deleted,
128+
actor_id=sample_data["actor_id"],
129+
role=sample_data["role"],
130+
)
131+
132+
cursor = sqlite_manager.connection.cursor()
133+
cursor.execute("SELECT * FROM history WHERE memory_id = ?", (sample_data["memory_id"],))
134+
result = cursor.fetchone()
135+
136+
assert result[2] == old_memory
137+
assert result[3] == new_memory
138+
assert result[6] == sample_data["updated_at"]
139+
assert result[7] == is_deleted
140+
141+
def test_add_history_generates_unique_ids(self, sqlite_manager, sample_data):
142+
"""Test that add_history generates unique IDs for each record."""
143+
for i in range(3):
144+
sqlite_manager.add_history(
145+
memory_id=sample_data["memory_id"],
146+
old_memory=f"Memory {i}",
147+
new_memory=f"Updated Memory {i}",
148+
event="ADD" if i == 0 else "UPDATE",
149+
)
150+
151+
cursor = sqlite_manager.connection.cursor()
152+
cursor.execute("SELECT id FROM history WHERE memory_id = ?", (sample_data["memory_id"],))
153+
ids = [row[0] for row in cursor.fetchall()]
154+
155+
assert len(ids) == 3
156+
assert len(set(ids)) == 3
157+
158+
# ========== Get History Tests ==========
159+
160+
def test_get_history_empty(self, sqlite_manager):
161+
"""Test get_history for non-existent memory_id."""
162+
result = sqlite_manager.get_history("non-existent-id")
163+
assert result == []
164+
165+
def test_get_history_single_record(self, sqlite_manager, sample_data):
166+
"""Test get_history for single record."""
167+
sqlite_manager.add_history(
168+
memory_id=sample_data["memory_id"],
169+
old_memory=sample_data["old_memory"],
170+
new_memory=sample_data["new_memory"],
171+
event=sample_data["event"],
172+
created_at=sample_data["created_at"],
173+
actor_id=sample_data["actor_id"],
174+
role=sample_data["role"],
175+
)
176+
177+
result = sqlite_manager.get_history(sample_data["memory_id"])
178+
179+
assert len(result) == 1
180+
record = result[0]
181+
assert record["memory_id"] == sample_data["memory_id"]
182+
assert record["old_memory"] == sample_data["old_memory"]
183+
assert record["new_memory"] == sample_data["new_memory"]
184+
assert record["event"] == sample_data["event"]
185+
assert record["created_at"] == sample_data["created_at"]
186+
assert record["actor_id"] == sample_data["actor_id"]
187+
assert record["role"] == sample_data["role"]
188+
assert record["is_deleted"] is False
189+
190+
def test_get_history_chronological_ordering(self, sqlite_manager, sample_data):
191+
"""Test get_history returns records in chronological order."""
192+
import time
193+
194+
timestamps = []
195+
for i in range(3):
196+
ts = datetime.now().isoformat()
197+
timestamps.append(ts)
198+
sqlite_manager.add_history(
199+
memory_id=sample_data["memory_id"],
200+
old_memory=f"Memory {i}",
201+
new_memory=f"Memory {i+1}",
202+
event="ADD" if i == 0 else "UPDATE",
203+
created_at=ts,
204+
updated_at=ts if i > 0 else None,
205+
)
206+
time.sleep(0.01)
207+
208+
result = sqlite_manager.get_history(sample_data["memory_id"])
209+
result_timestamps = [r["created_at"] for r in result]
210+
assert result_timestamps == sorted(timestamps)
211+
212+
def test_migration_preserves_data(self, temp_db_path, sample_data):
213+
"""Test that migration preserves existing data."""
214+
manager1 = SQLiteManager(temp_db_path)
215+
manager1.add_history(
216+
memory_id=sample_data["memory_id"],
217+
old_memory=sample_data["old_memory"],
218+
new_memory=sample_data["new_memory"],
219+
event=sample_data["event"],
220+
created_at=sample_data["created_at"],
221+
)
222+
original_data = manager1.get_history(sample_data["memory_id"])
223+
manager1.close()
224+
225+
manager2 = SQLiteManager(temp_db_path)
226+
migrated_data = manager2.get_history(sample_data["memory_id"])
227+
manager2.close()
228+
229+
assert len(migrated_data) == len(original_data)
230+
assert migrated_data[0]["memory_id"] == original_data[0]["memory_id"]
231+
assert migrated_data[0]["new_memory"] == original_data[0]["new_memory"]
232+
233+
def test_large_batch_operations(self, sqlite_manager):
234+
"""Test performance with large batch of operations."""
235+
batch_size = 1000
236+
memory_ids = [str(uuid.uuid4()) for _ in range(batch_size)]
237+
for i, memory_id in enumerate(memory_ids):
238+
sqlite_manager.add_history(
239+
memory_id=memory_id, old_memory=None, new_memory=f"Batch memory {i}", event="ADD"
240+
)
241+
242+
cursor = sqlite_manager.connection.cursor()
243+
cursor.execute("SELECT COUNT(*) FROM history")
244+
count = cursor.fetchone()[0]
245+
assert count == batch_size
246+
247+
for memory_id in memory_ids[:10]:
248+
result = sqlite_manager.get_history(memory_id)
249+
assert len(result) == 1
250+
251+
# ========== Tests for Migration, Reset, and Close ==========
252+
253+
def test_explicit_old_schema_migration(self, temp_db_path):
254+
"""Test migration path from a legacy schema to new schema."""
255+
# Create a legacy 'history' table missing new columns
256+
legacy_conn = sqlite3.connect(temp_db_path)
257+
legacy_conn.execute("""
258+
CREATE TABLE history (
259+
id TEXT PRIMARY KEY,
260+
memory_id TEXT,
261+
old_memory TEXT,
262+
new_memory TEXT,
263+
event TEXT,
264+
created_at DATETIME
265+
)
266+
""")
267+
legacy_id = str(uuid.uuid4())
268+
legacy_conn.execute(
269+
"INSERT INTO history (id, memory_id, old_memory, new_memory, event, created_at) VALUES (?, ?, ?, ?, ?, ?)",
270+
(legacy_id, "m1", "o", "n", "ADD", datetime.now().isoformat()),
271+
)
272+
legacy_conn.commit()
273+
legacy_conn.close()
274+
275+
# Trigger migration
276+
mgr = SQLiteManager(temp_db_path)
277+
history = mgr.get_history("m1")
278+
assert len(history) == 1
279+
assert history[0]["id"] == legacy_id
280+
assert history[0]["actor_id"] is None
281+
assert history[0]["is_deleted"] is False
282+
mgr.close()

0 commit comments

Comments
 (0)