Skip to content

Commit ead210f

Browse files
Added weaviate db test (#3483)
1 parent a015e2f commit ead210f

File tree

1 file changed

+150
-163
lines changed

1 file changed

+150
-163
lines changed
Lines changed: 150 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -1,220 +1,207 @@
1-
# import os
2-
# import uuid
3-
# import httpx
4-
# import unittest
5-
# from unittest.mock import MagicMock, patch
1+
import os
2+
import uuid
3+
import httpx
4+
import unittest
5+
from unittest.mock import MagicMock, patch
66

7-
# import dotenv
8-
# import weaviate
9-
# from weaviate.classes.query import MetadataQuery, Filter
10-
# from weaviate.exceptions import UnexpectedStatusCodeException
7+
import dotenv
8+
import weaviate
9+
from weaviate.exceptions import UnexpectedStatusCodeException
1110

12-
# from mem0.vector_stores.weaviate import Weaviate, OutputData
11+
from mem0.vector_stores.weaviate import Weaviate
1312

1413

15-
# class TestWeaviateDB(unittest.TestCase):
16-
# @classmethod
17-
# def setUpClass(cls):
18-
# dotenv.load_dotenv()
14+
class TestWeaviateDB(unittest.TestCase):
15+
@classmethod
16+
def setUpClass(cls):
17+
dotenv.load_dotenv()
1918

20-
# cls.original_env = {
21-
# 'WEAVIATE_CLUSTER_URL': os.getenv('WEAVIATE_CLUSTER_URL', 'http://localhost:8080'),
22-
# 'WEAVIATE_API_KEY': os.getenv('WEAVIATE_API_KEY', 'test_api_key'),
23-
# }
19+
cls.original_env = {
20+
"WEAVIATE_CLUSTER_URL": os.getenv("WEAVIATE_CLUSTER_URL", "http://localhost:8080"),
21+
"WEAVIATE_API_KEY": os.getenv("WEAVIATE_API_KEY", "test_api_key"),
22+
}
2423

25-
# os.environ['WEAVIATE_CLUSTER_URL'] = 'http://localhost:8080'
26-
# os.environ['WEAVIATE_API_KEY'] = 'test_api_key'
24+
os.environ["WEAVIATE_CLUSTER_URL"] = "http://localhost:8080"
25+
os.environ["WEAVIATE_API_KEY"] = "test_api_key"
2726

28-
# def setUp(self):
29-
# self.client_mock = MagicMock(spec=weaviate.WeaviateClient)
30-
# self.client_mock.collections = MagicMock()
31-
# self.client_mock.collections.exists.return_value = False
32-
# self.client_mock.collections.create.return_value = None
33-
# self.client_mock.collections.delete.return_value = None
27+
def setUp(self):
28+
self.client_mock = MagicMock(spec=weaviate.WeaviateClient)
29+
self.client_mock.collections = MagicMock()
30+
self.client_mock.collections.exists.return_value = False
31+
self.client_mock.collections.create.return_value = None
32+
self.client_mock.collections.delete.return_value = None
3433

35-
# patcher = patch('mem0.vector_stores.weaviate.weaviate.connect_to_local', return_value=self.client_mock)
36-
# self.mock_weaviate = patcher.start()
37-
# self.addCleanup(patcher.stop)
34+
patcher = patch("mem0.vector_stores.weaviate.weaviate.connect_to_local", return_value=self.client_mock)
35+
self.mock_weaviate = patcher.start()
36+
self.addCleanup(patcher.stop)
3837

39-
# self.weaviate_db = Weaviate(
40-
# collection_name="test_collection",
41-
# embedding_model_dims=1536,
42-
# cluster_url=os.getenv('WEAVIATE_CLUSTER_URL'),
43-
# auth_client_secret=os.getenv('WEAVIATE_API_KEY'),
44-
# additional_headers={"X-OpenAI-Api-Key": "test_key"},
45-
# )
38+
self.weaviate_db = Weaviate(
39+
collection_name="test_collection",
40+
embedding_model_dims=1536,
41+
cluster_url=os.getenv("WEAVIATE_CLUSTER_URL"),
42+
auth_client_secret=os.getenv("WEAVIATE_API_KEY"),
43+
additional_headers={"X-OpenAI-Api-Key": "test_key"},
44+
)
4645

47-
# self.client_mock.reset_mock()
46+
self.client_mock.reset_mock()
4847

49-
# @classmethod
50-
# def tearDownClass(cls):
51-
# for key, value in cls.original_env.items():
52-
# if value is not None:
53-
# os.environ[key] = value
54-
# else:
55-
# os.environ.pop(key, None)
48+
@classmethod
49+
def tearDownClass(cls):
50+
for key, value in cls.original_env.items():
51+
if value is not None:
52+
os.environ[key] = value
53+
else:
54+
os.environ.pop(key, None)
5655

57-
# def tearDown(self):
58-
# self.client_mock.reset_mock()
56+
def tearDown(self):
57+
self.client_mock.reset_mock()
5958

60-
# def test_create_col(self):
61-
# self.client_mock.collections.exists.return_value = False
62-
# self.weaviate_db.create_col(vector_size=1536)
59+
def test_create_col(self):
60+
self.client_mock.collections.exists.return_value = False
61+
self.weaviate_db.create_col(vector_size=1536)
6362

63+
self.client_mock.collections.create.assert_called_once()
6464

65-
# self.client_mock.collections.create.assert_called_once()
65+
self.client_mock.reset_mock()
6666

67+
self.client_mock.collections.exists.return_value = True
68+
self.weaviate_db.create_col(vector_size=1536)
6769

68-
# self.client_mock.reset_mock()
70+
self.client_mock.collections.create.assert_not_called()
6971

70-
# self.client_mock.collections.exists.return_value = True
71-
# self.weaviate_db.create_col(vector_size=1536)
72+
def test_insert(self):
73+
self.client_mock.batch = MagicMock()
7274

73-
# self.client_mock.collections.create.assert_not_called()
75+
self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()
7476

75-
# def test_insert(self):
76-
# self.client_mock.batch = MagicMock()
77+
self.client_mock.collections.get.return_value.data.insert_many.return_value = {
78+
"results": [{"id": "id1"}, {"id": "id2"}]
79+
}
7780

78-
# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock()
81+
vectors = [[0.1] * 1536, [0.2] * 1536]
82+
payloads = [{"key1": "value1"}, {"key2": "value2"}]
83+
ids = [str(uuid.uuid4()), str(uuid.uuid4())]
7984

80-
# self.client_mock.collections.get.return_value.data.insert_many.return_value = {
81-
# "results": [{"id": "id1"}, {"id": "id2"}]
82-
# }
85+
self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)
8386

84-
# vectors = [[0.1] * 1536, [0.2] * 1536]
85-
# payloads = [{"key1": "value1"}, {"key2": "value2"}]
86-
# ids = [str(uuid.uuid4()), str(uuid.uuid4())]
87+
def test_get(self):
88+
valid_uuid = str(uuid.uuid4())
8789

88-
# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids)
90+
mock_response = MagicMock()
91+
mock_response.properties = {
92+
"hash": "abc123",
93+
"created_at": "2025-03-08T12:00:00Z",
94+
"updated_at": "2025-03-08T13:00:00Z",
95+
"user_id": "user_123",
96+
"agent_id": "agent_456",
97+
"run_id": "run_789",
98+
"data": {"key": "value"},
99+
"category": "test",
100+
}
101+
mock_response.uuid = valid_uuid
89102

90-
# def test_get(self):
91-
# valid_uuid = str(uuid.uuid4())
103+
self.client_mock.collections.get.return_value.query.fetch_object_by_id.return_value = mock_response
92104

93-
# mock_response = MagicMock()
94-
# mock_response.properties = {
95-
# "hash": "abc123",
96-
# "created_at": "2025-03-08T12:00:00Z",
97-
# "updated_at": "2025-03-08T13:00:00Z",
98-
# "user_id": "user_123",
99-
# "agent_id": "agent_456",
100-
# "run_id": "run_789",
101-
# "data": {"key": "value"},
102-
# "category": "test",
103-
# }
104-
# mock_response.uuid = valid_uuid
105+
result = self.weaviate_db.get(vector_id=valid_uuid)
105106

106-
# self.client_mock.collections.get.return_value.query.fetch_object_by_id.return_value = mock_response
107+
assert result.id == valid_uuid
107108

108-
# result = self.weaviate_db.get(vector_id=valid_uuid)
109+
expected_payload = mock_response.properties.copy()
110+
expected_payload["id"] = valid_uuid
109111

110-
# assert result.id == valid_uuid
112+
assert result.payload == expected_payload
111113

112-
# expected_payload = mock_response.properties.copy()
113-
# expected_payload["id"] = valid_uuid
114+
def test_get_not_found(self):
115+
mock_response = httpx.Response(status_code=404, json={"error": "Not found"})
114116

115-
# assert result.payload == expected_payload
117+
self.client_mock.collections.get.return_value.data.get_by_id.side_effect = UnexpectedStatusCodeException(
118+
"Not found", mock_response
119+
)
116120

121+
def test_search(self):
122+
mock_objects = [{"uuid": "id1", "properties": {"key1": "value1"}, "metadata": {"distance": 0.2}}]
117123

118-
# def test_get_not_found(self):
119-
# mock_response = httpx.Response(status_code=404, json={"error": "Not found"})
124+
mock_response = MagicMock()
125+
mock_response.objects = []
120126

121-
# self.client_mock.collections.get.return_value.data.get_by_id.side_effect = UnexpectedStatusCodeException(
122-
# "Not found", mock_response
123-
# )
127+
for obj in mock_objects:
128+
mock_obj = MagicMock()
129+
mock_obj.uuid = obj["uuid"]
130+
mock_obj.properties = obj["properties"]
131+
mock_obj.metadata = MagicMock()
132+
mock_obj.metadata.distance = obj["metadata"]["distance"]
133+
mock_response.objects.append(mock_obj)
124134

135+
mock_hybrid = MagicMock()
136+
self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
137+
mock_hybrid.return_value = mock_response
125138

126-
# def test_search(self):
127-
# mock_objects = [
128-
# {
129-
# "uuid": "id1",
130-
# "properties": {"key1": "value1"},
131-
# "metadata": {"distance": 0.2}
132-
# }
133-
# ]
139+
vectors = [[0.1] * 1536]
140+
results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
134141

135-
# mock_response = MagicMock()
136-
# mock_response.objects = []
142+
mock_hybrid.assert_called_once()
137143

138-
# for obj in mock_objects:
139-
# mock_obj = MagicMock()
140-
# mock_obj.uuid = obj["uuid"]
141-
# mock_obj.properties = obj["properties"]
142-
# mock_obj.metadata = MagicMock()
143-
# mock_obj.metadata.distance = obj["metadata"]["distance"]
144-
# mock_response.objects.append(mock_obj)
144+
self.assertEqual(len(results), 1)
145+
self.assertEqual(results[0].id, "id1")
146+
self.assertEqual(results[0].score, 0.8)
145147

146-
# mock_hybrid = MagicMock()
147-
# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid
148-
# mock_hybrid.return_value = mock_response
148+
def test_delete(self):
149+
self.weaviate_db.delete(vector_id="id1")
149150

150-
# vectors = [[0.1] * 1536]
151-
# results = self.weaviate_db.search(query="", vectors=vectors, limit=5)
151+
self.client_mock.collections.get.return_value.data.delete_by_id.assert_called_once_with("id1")
152152

153-
# mock_hybrid.assert_called_once()
153+
def test_list(self):
154+
mock_objects = []
154155

155-
# self.assertEqual(len(results), 1)
156-
# self.assertEqual(results[0].id, "id1")
157-
# self.assertEqual(results[0].score, 0.8)
156+
mock_obj1 = MagicMock()
157+
mock_obj1.uuid = "id1"
158+
mock_obj1.properties = {"key1": "value1"}
159+
mock_objects.append(mock_obj1)
158160

159-
# def test_delete(self):
160-
# self.weaviate_db.delete(vector_id="id1")
161+
mock_obj2 = MagicMock()
162+
mock_obj2.uuid = "id2"
163+
mock_obj2.properties = {"key2": "value2"}
164+
mock_objects.append(mock_obj2)
161165

162-
# self.client_mock.collections.get.return_value.data.delete_by_id.assert_called_once_with("id1")
166+
mock_response = MagicMock()
167+
mock_response.objects = mock_objects
163168

164-
# def test_list(self):
165-
# mock_objects = []
169+
mock_fetch = MagicMock()
170+
self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
171+
mock_fetch.return_value = mock_response
166172

167-
# mock_obj1 = MagicMock()
168-
# mock_obj1.uuid = "id1"
169-
# mock_obj1.properties = {"key1": "value1"}
170-
# mock_objects.append(mock_obj1)
173+
results = self.weaviate_db.list(limit=10)
171174

172-
# mock_obj2 = MagicMock()
173-
# mock_obj2.uuid = "id2"
174-
# mock_obj2.properties = {"key2": "value2"}
175-
# mock_objects.append(mock_obj2)
175+
mock_fetch.assert_called_once()
176176

177-
# mock_response = MagicMock()
178-
# mock_response.objects = mock_objects
177+
# Verify results
178+
self.assertEqual(len(results), 1)
179+
self.assertEqual(len(results[0]), 2)
180+
self.assertEqual(results[0][0].id, "id1")
181+
self.assertEqual(results[0][0].payload["key1"], "value1")
182+
self.assertEqual(results[0][1].id, "id2")
183+
self.assertEqual(results[0][1].payload["key2"], "value2")
179184

180-
# mock_fetch = MagicMock()
181-
# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch
182-
# mock_fetch.return_value = mock_response
185+
def test_list_cols(self):
186+
mock_collection1 = MagicMock()
187+
mock_collection1.name = "collection1"
183188

184-
# results = self.weaviate_db.list(limit=10)
189+
mock_collection2 = MagicMock()
190+
mock_collection2.name = "collection2"
191+
self.client_mock.collections.list_all.return_value = [mock_collection1, mock_collection2]
185192

186-
# mock_fetch.assert_called_once()
193+
result = self.weaviate_db.list_cols()
194+
expected = {"collections": [{"name": "collection1"}, {"name": "collection2"}]}
187195

188-
# # Verify results
189-
# self.assertEqual(len(results), 1)
190-
# self.assertEqual(len(results[0]), 2)
191-
# self.assertEqual(results[0][0].id, "id1")
192-
# self.assertEqual(results[0][0].payload["key1"], "value1")
193-
# self.assertEqual(results[0][1].id, "id2")
194-
# self.assertEqual(results[0][1].payload["key2"], "value2")
196+
assert result == expected
195197

198+
self.client_mock.collections.list_all.assert_called_once()
196199

197-
# def test_list_cols(self):
198-
# mock_collection1 = MagicMock()
199-
# mock_collection1.name = "collection1"
200+
def test_delete_col(self):
201+
self.weaviate_db.delete_col()
200202

201-
# mock_collection2 = MagicMock()
202-
# mock_collection2.name = "collection2"
203-
# self.client_mock.collections.list_all.return_value = [mock_collection1, mock_collection2]
203+
self.client_mock.collections.delete.assert_called_once_with("test_collection")
204204

205-
# result = self.weaviate_db.list_cols()
206-
# expected = {"collections": [{"name": "collection1"}, {"name": "collection2"}]}
207205

208-
# assert result == expected
209-
210-
# self.client_mock.collections.list_all.assert_called_once()
211-
212-
213-
# def test_delete_col(self):
214-
# self.weaviate_db.delete_col()
215-
216-
# self.client_mock.collections.delete.assert_called_once_with("test_collection")
217-
218-
219-
# if __name__ == '__main__':
220-
# unittest.main()
206+
if __name__ == "__main__":
207+
unittest.main()

0 commit comments

Comments
 (0)