|
1 | | -# import os |
2 | | -# import uuid |
3 | | -# import httpx |
4 | | -# import unittest |
5 | | -# from unittest.mock import MagicMock, patch |
| 1 | +import os |
| 2 | +import uuid |
| 3 | +import httpx |
| 4 | +import unittest |
| 5 | +from unittest.mock import MagicMock, patch |
6 | 6 |
|
7 | | -# import dotenv |
8 | | -# import weaviate |
9 | | -# from weaviate.classes.query import MetadataQuery, Filter |
10 | | -# from weaviate.exceptions import UnexpectedStatusCodeException |
| 7 | +import dotenv |
| 8 | +import weaviate |
| 9 | +from weaviate.exceptions import UnexpectedStatusCodeException |
11 | 10 |
|
12 | | -# from mem0.vector_stores.weaviate import Weaviate, OutputData |
| 11 | +from mem0.vector_stores.weaviate import Weaviate |
13 | 12 |
|
14 | 13 |
|
15 | | -# class TestWeaviateDB(unittest.TestCase): |
16 | | -# @classmethod |
17 | | -# def setUpClass(cls): |
18 | | -# dotenv.load_dotenv() |
| 14 | +class TestWeaviateDB(unittest.TestCase): |
| 15 | + @classmethod |
| 16 | + def setUpClass(cls): |
| 17 | + dotenv.load_dotenv() |
19 | 18 |
|
20 | | -# cls.original_env = { |
21 | | -# 'WEAVIATE_CLUSTER_URL': os.getenv('WEAVIATE_CLUSTER_URL', 'http://localhost:8080'), |
22 | | -# 'WEAVIATE_API_KEY': os.getenv('WEAVIATE_API_KEY', 'test_api_key'), |
23 | | -# } |
| 19 | + cls.original_env = { |
| 20 | + "WEAVIATE_CLUSTER_URL": os.getenv("WEAVIATE_CLUSTER_URL", "http://localhost:8080"), |
| 21 | + "WEAVIATE_API_KEY": os.getenv("WEAVIATE_API_KEY", "test_api_key"), |
| 22 | + } |
24 | 23 |
|
25 | | -# os.environ['WEAVIATE_CLUSTER_URL'] = 'http://localhost:8080' |
26 | | -# os.environ['WEAVIATE_API_KEY'] = 'test_api_key' |
| 24 | + os.environ["WEAVIATE_CLUSTER_URL"] = "http://localhost:8080" |
| 25 | + os.environ["WEAVIATE_API_KEY"] = "test_api_key" |
27 | 26 |
|
28 | | -# def setUp(self): |
29 | | -# self.client_mock = MagicMock(spec=weaviate.WeaviateClient) |
30 | | -# self.client_mock.collections = MagicMock() |
31 | | -# self.client_mock.collections.exists.return_value = False |
32 | | -# self.client_mock.collections.create.return_value = None |
33 | | -# self.client_mock.collections.delete.return_value = None |
| 27 | + def setUp(self): |
| 28 | + self.client_mock = MagicMock(spec=weaviate.WeaviateClient) |
| 29 | + self.client_mock.collections = MagicMock() |
| 30 | + self.client_mock.collections.exists.return_value = False |
| 31 | + self.client_mock.collections.create.return_value = None |
| 32 | + self.client_mock.collections.delete.return_value = None |
34 | 33 |
|
35 | | -# patcher = patch('mem0.vector_stores.weaviate.weaviate.connect_to_local', return_value=self.client_mock) |
36 | | -# self.mock_weaviate = patcher.start() |
37 | | -# self.addCleanup(patcher.stop) |
| 34 | + patcher = patch("mem0.vector_stores.weaviate.weaviate.connect_to_local", return_value=self.client_mock) |
| 35 | + self.mock_weaviate = patcher.start() |
| 36 | + self.addCleanup(patcher.stop) |
38 | 37 |
|
39 | | -# self.weaviate_db = Weaviate( |
40 | | -# collection_name="test_collection", |
41 | | -# embedding_model_dims=1536, |
42 | | -# cluster_url=os.getenv('WEAVIATE_CLUSTER_URL'), |
43 | | -# auth_client_secret=os.getenv('WEAVIATE_API_KEY'), |
44 | | -# additional_headers={"X-OpenAI-Api-Key": "test_key"}, |
45 | | -# ) |
| 38 | + self.weaviate_db = Weaviate( |
| 39 | + collection_name="test_collection", |
| 40 | + embedding_model_dims=1536, |
| 41 | + cluster_url=os.getenv("WEAVIATE_CLUSTER_URL"), |
| 42 | + auth_client_secret=os.getenv("WEAVIATE_API_KEY"), |
| 43 | + additional_headers={"X-OpenAI-Api-Key": "test_key"}, |
| 44 | + ) |
46 | 45 |
|
47 | | -# self.client_mock.reset_mock() |
| 46 | + self.client_mock.reset_mock() |
48 | 47 |
|
49 | | -# @classmethod |
50 | | -# def tearDownClass(cls): |
51 | | -# for key, value in cls.original_env.items(): |
52 | | -# if value is not None: |
53 | | -# os.environ[key] = value |
54 | | -# else: |
55 | | -# os.environ.pop(key, None) |
| 48 | + @classmethod |
| 49 | + def tearDownClass(cls): |
| 50 | + for key, value in cls.original_env.items(): |
| 51 | + if value is not None: |
| 52 | + os.environ[key] = value |
| 53 | + else: |
| 54 | + os.environ.pop(key, None) |
56 | 55 |
|
57 | | -# def tearDown(self): |
58 | | -# self.client_mock.reset_mock() |
| 56 | + def tearDown(self): |
| 57 | + self.client_mock.reset_mock() |
59 | 58 |
|
60 | | -# def test_create_col(self): |
61 | | -# self.client_mock.collections.exists.return_value = False |
62 | | -# self.weaviate_db.create_col(vector_size=1536) |
| 59 | + def test_create_col(self): |
| 60 | + self.client_mock.collections.exists.return_value = False |
| 61 | + self.weaviate_db.create_col(vector_size=1536) |
63 | 62 |
|
| 63 | + self.client_mock.collections.create.assert_called_once() |
64 | 64 |
|
65 | | -# self.client_mock.collections.create.assert_called_once() |
| 65 | + self.client_mock.reset_mock() |
66 | 66 |
|
| 67 | + self.client_mock.collections.exists.return_value = True |
| 68 | + self.weaviate_db.create_col(vector_size=1536) |
67 | 69 |
|
68 | | -# self.client_mock.reset_mock() |
| 70 | + self.client_mock.collections.create.assert_not_called() |
69 | 71 |
|
70 | | -# self.client_mock.collections.exists.return_value = True |
71 | | -# self.weaviate_db.create_col(vector_size=1536) |
| 72 | + def test_insert(self): |
| 73 | + self.client_mock.batch = MagicMock() |
72 | 74 |
|
73 | | -# self.client_mock.collections.create.assert_not_called() |
| 75 | + self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock() |
74 | 76 |
|
75 | | -# def test_insert(self): |
76 | | -# self.client_mock.batch = MagicMock() |
| 77 | + self.client_mock.collections.get.return_value.data.insert_many.return_value = { |
| 78 | + "results": [{"id": "id1"}, {"id": "id2"}] |
| 79 | + } |
77 | 80 |
|
78 | | -# self.client_mock.batch.fixed_size.return_value.__enter__.return_value = MagicMock() |
| 81 | + vectors = [[0.1] * 1536, [0.2] * 1536] |
| 82 | + payloads = [{"key1": "value1"}, {"key2": "value2"}] |
| 83 | + ids = [str(uuid.uuid4()), str(uuid.uuid4())] |
79 | 84 |
|
80 | | -# self.client_mock.collections.get.return_value.data.insert_many.return_value = { |
81 | | -# "results": [{"id": "id1"}, {"id": "id2"}] |
82 | | -# } |
| 85 | + self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids) |
83 | 86 |
|
84 | | -# vectors = [[0.1] * 1536, [0.2] * 1536] |
85 | | -# payloads = [{"key1": "value1"}, {"key2": "value2"}] |
86 | | -# ids = [str(uuid.uuid4()), str(uuid.uuid4())] |
| 87 | + def test_get(self): |
| 88 | + valid_uuid = str(uuid.uuid4()) |
87 | 89 |
|
88 | | -# results = self.weaviate_db.insert(vectors=vectors, payloads=payloads, ids=ids) |
| 90 | + mock_response = MagicMock() |
| 91 | + mock_response.properties = { |
| 92 | + "hash": "abc123", |
| 93 | + "created_at": "2025-03-08T12:00:00Z", |
| 94 | + "updated_at": "2025-03-08T13:00:00Z", |
| 95 | + "user_id": "user_123", |
| 96 | + "agent_id": "agent_456", |
| 97 | + "run_id": "run_789", |
| 98 | + "data": {"key": "value"}, |
| 99 | + "category": "test", |
| 100 | + } |
| 101 | + mock_response.uuid = valid_uuid |
89 | 102 |
|
90 | | -# def test_get(self): |
91 | | -# valid_uuid = str(uuid.uuid4()) |
| 103 | + self.client_mock.collections.get.return_value.query.fetch_object_by_id.return_value = mock_response |
92 | 104 |
|
93 | | -# mock_response = MagicMock() |
94 | | -# mock_response.properties = { |
95 | | -# "hash": "abc123", |
96 | | -# "created_at": "2025-03-08T12:00:00Z", |
97 | | -# "updated_at": "2025-03-08T13:00:00Z", |
98 | | -# "user_id": "user_123", |
99 | | -# "agent_id": "agent_456", |
100 | | -# "run_id": "run_789", |
101 | | -# "data": {"key": "value"}, |
102 | | -# "category": "test", |
103 | | -# } |
104 | | -# mock_response.uuid = valid_uuid |
| 105 | + result = self.weaviate_db.get(vector_id=valid_uuid) |
105 | 106 |
|
106 | | -# self.client_mock.collections.get.return_value.query.fetch_object_by_id.return_value = mock_response |
| 107 | + assert result.id == valid_uuid |
107 | 108 |
|
108 | | -# result = self.weaviate_db.get(vector_id=valid_uuid) |
| 109 | + expected_payload = mock_response.properties.copy() |
| 110 | + expected_payload["id"] = valid_uuid |
109 | 111 |
|
110 | | -# assert result.id == valid_uuid |
| 112 | + assert result.payload == expected_payload |
111 | 113 |
|
112 | | -# expected_payload = mock_response.properties.copy() |
113 | | -# expected_payload["id"] = valid_uuid |
| 114 | + def test_get_not_found(self): |
| 115 | + mock_response = httpx.Response(status_code=404, json={"error": "Not found"}) |
114 | 116 |
|
115 | | -# assert result.payload == expected_payload |
| 117 | + self.client_mock.collections.get.return_value.data.get_by_id.side_effect = UnexpectedStatusCodeException( |
| 118 | + "Not found", mock_response |
| 119 | + ) |
116 | 120 |
|
| 121 | + def test_search(self): |
| 122 | + mock_objects = [{"uuid": "id1", "properties": {"key1": "value1"}, "metadata": {"distance": 0.2}}] |
117 | 123 |
|
118 | | -# def test_get_not_found(self): |
119 | | -# mock_response = httpx.Response(status_code=404, json={"error": "Not found"}) |
| 124 | + mock_response = MagicMock() |
| 125 | + mock_response.objects = [] |
120 | 126 |
|
121 | | -# self.client_mock.collections.get.return_value.data.get_by_id.side_effect = UnexpectedStatusCodeException( |
122 | | -# "Not found", mock_response |
123 | | -# ) |
| 127 | + for obj in mock_objects: |
| 128 | + mock_obj = MagicMock() |
| 129 | + mock_obj.uuid = obj["uuid"] |
| 130 | + mock_obj.properties = obj["properties"] |
| 131 | + mock_obj.metadata = MagicMock() |
| 132 | + mock_obj.metadata.distance = obj["metadata"]["distance"] |
| 133 | + mock_response.objects.append(mock_obj) |
124 | 134 |
|
| 135 | + mock_hybrid = MagicMock() |
| 136 | + self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid |
| 137 | + mock_hybrid.return_value = mock_response |
125 | 138 |
|
126 | | -# def test_search(self): |
127 | | -# mock_objects = [ |
128 | | -# { |
129 | | -# "uuid": "id1", |
130 | | -# "properties": {"key1": "value1"}, |
131 | | -# "metadata": {"distance": 0.2} |
132 | | -# } |
133 | | -# ] |
| 139 | + vectors = [[0.1] * 1536] |
| 140 | + results = self.weaviate_db.search(query="", vectors=vectors, limit=5) |
134 | 141 |
|
135 | | -# mock_response = MagicMock() |
136 | | -# mock_response.objects = [] |
| 142 | + mock_hybrid.assert_called_once() |
137 | 143 |
|
138 | | -# for obj in mock_objects: |
139 | | -# mock_obj = MagicMock() |
140 | | -# mock_obj.uuid = obj["uuid"] |
141 | | -# mock_obj.properties = obj["properties"] |
142 | | -# mock_obj.metadata = MagicMock() |
143 | | -# mock_obj.metadata.distance = obj["metadata"]["distance"] |
144 | | -# mock_response.objects.append(mock_obj) |
| 144 | + self.assertEqual(len(results), 1) |
| 145 | + self.assertEqual(results[0].id, "id1") |
| 146 | + self.assertEqual(results[0].score, 0.8) |
145 | 147 |
|
146 | | -# mock_hybrid = MagicMock() |
147 | | -# self.client_mock.collections.get.return_value.query.hybrid = mock_hybrid |
148 | | -# mock_hybrid.return_value = mock_response |
| 148 | + def test_delete(self): |
| 149 | + self.weaviate_db.delete(vector_id="id1") |
149 | 150 |
|
150 | | -# vectors = [[0.1] * 1536] |
151 | | -# results = self.weaviate_db.search(query="", vectors=vectors, limit=5) |
| 151 | + self.client_mock.collections.get.return_value.data.delete_by_id.assert_called_once_with("id1") |
152 | 152 |
|
153 | | -# mock_hybrid.assert_called_once() |
| 153 | + def test_list(self): |
| 154 | + mock_objects = [] |
154 | 155 |
|
155 | | -# self.assertEqual(len(results), 1) |
156 | | -# self.assertEqual(results[0].id, "id1") |
157 | | -# self.assertEqual(results[0].score, 0.8) |
| 156 | + mock_obj1 = MagicMock() |
| 157 | + mock_obj1.uuid = "id1" |
| 158 | + mock_obj1.properties = {"key1": "value1"} |
| 159 | + mock_objects.append(mock_obj1) |
158 | 160 |
|
159 | | -# def test_delete(self): |
160 | | -# self.weaviate_db.delete(vector_id="id1") |
| 161 | + mock_obj2 = MagicMock() |
| 162 | + mock_obj2.uuid = "id2" |
| 163 | + mock_obj2.properties = {"key2": "value2"} |
| 164 | + mock_objects.append(mock_obj2) |
161 | 165 |
|
162 | | -# self.client_mock.collections.get.return_value.data.delete_by_id.assert_called_once_with("id1") |
| 166 | + mock_response = MagicMock() |
| 167 | + mock_response.objects = mock_objects |
163 | 168 |
|
164 | | -# def test_list(self): |
165 | | -# mock_objects = [] |
| 169 | + mock_fetch = MagicMock() |
| 170 | + self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch |
| 171 | + mock_fetch.return_value = mock_response |
166 | 172 |
|
167 | | -# mock_obj1 = MagicMock() |
168 | | -# mock_obj1.uuid = "id1" |
169 | | -# mock_obj1.properties = {"key1": "value1"} |
170 | | -# mock_objects.append(mock_obj1) |
| 173 | + results = self.weaviate_db.list(limit=10) |
171 | 174 |
|
172 | | -# mock_obj2 = MagicMock() |
173 | | -# mock_obj2.uuid = "id2" |
174 | | -# mock_obj2.properties = {"key2": "value2"} |
175 | | -# mock_objects.append(mock_obj2) |
| 175 | + mock_fetch.assert_called_once() |
176 | 176 |
|
177 | | -# mock_response = MagicMock() |
178 | | -# mock_response.objects = mock_objects |
| 177 | + # Verify results |
| 178 | + self.assertEqual(len(results), 1) |
| 179 | + self.assertEqual(len(results[0]), 2) |
| 180 | + self.assertEqual(results[0][0].id, "id1") |
| 181 | + self.assertEqual(results[0][0].payload["key1"], "value1") |
| 182 | + self.assertEqual(results[0][1].id, "id2") |
| 183 | + self.assertEqual(results[0][1].payload["key2"], "value2") |
179 | 184 |
|
180 | | -# mock_fetch = MagicMock() |
181 | | -# self.client_mock.collections.get.return_value.query.fetch_objects = mock_fetch |
182 | | -# mock_fetch.return_value = mock_response |
| 185 | + def test_list_cols(self): |
| 186 | + mock_collection1 = MagicMock() |
| 187 | + mock_collection1.name = "collection1" |
183 | 188 |
|
184 | | -# results = self.weaviate_db.list(limit=10) |
| 189 | + mock_collection2 = MagicMock() |
| 190 | + mock_collection2.name = "collection2" |
| 191 | + self.client_mock.collections.list_all.return_value = [mock_collection1, mock_collection2] |
185 | 192 |
|
186 | | -# mock_fetch.assert_called_once() |
| 193 | + result = self.weaviate_db.list_cols() |
| 194 | + expected = {"collections": [{"name": "collection1"}, {"name": "collection2"}]} |
187 | 195 |
|
188 | | -# # Verify results |
189 | | -# self.assertEqual(len(results), 1) |
190 | | -# self.assertEqual(len(results[0]), 2) |
191 | | -# self.assertEqual(results[0][0].id, "id1") |
192 | | -# self.assertEqual(results[0][0].payload["key1"], "value1") |
193 | | -# self.assertEqual(results[0][1].id, "id2") |
194 | | -# self.assertEqual(results[0][1].payload["key2"], "value2") |
| 196 | + assert result == expected |
195 | 197 |
|
| 198 | + self.client_mock.collections.list_all.assert_called_once() |
196 | 199 |
|
197 | | -# def test_list_cols(self): |
198 | | -# mock_collection1 = MagicMock() |
199 | | -# mock_collection1.name = "collection1" |
| 200 | + def test_delete_col(self): |
| 201 | + self.weaviate_db.delete_col() |
200 | 202 |
|
201 | | -# mock_collection2 = MagicMock() |
202 | | -# mock_collection2.name = "collection2" |
203 | | -# self.client_mock.collections.list_all.return_value = [mock_collection1, mock_collection2] |
| 203 | + self.client_mock.collections.delete.assert_called_once_with("test_collection") |
204 | 204 |
|
205 | | -# result = self.weaviate_db.list_cols() |
206 | | -# expected = {"collections": [{"name": "collection1"}, {"name": "collection2"}]} |
207 | 205 |
|
208 | | -# assert result == expected |
209 | | - |
210 | | -# self.client_mock.collections.list_all.assert_called_once() |
211 | | - |
212 | | - |
213 | | -# def test_delete_col(self): |
214 | | -# self.weaviate_db.delete_col() |
215 | | - |
216 | | -# self.client_mock.collections.delete.assert_called_once_with("test_collection") |
217 | | - |
218 | | - |
219 | | -# if __name__ == '__main__': |
220 | | -# unittest.main() |
| 206 | +if __name__ == "__main__": |
| 207 | + unittest.main() |
0 commit comments