forked from agentscope-ai/agentscope-runtime
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_reme_task_memory_service.py
More file actions
506 lines (397 loc) · 16.1 KB
/
test_reme_task_memory_service.py
File metadata and controls
506 lines (397 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name, protected-access, unused-argument,
# pylint: disable=wrong-import-position
# flake8: noqa: E402
import pytest
import pytest_asyncio
from agentscope_runtime.engine.services.memory import (
ReMeTaskMemoryService,
)
from agentscope_runtime.engine.schemas.agent_schemas import (
Message,
MessageType,
TextContent,
ContentType,
Role,
)
def create_message(role: str, content: str) -> Message:
    """Build a single-text-part ``Message`` for use in these tests.

    Args:
        role: Role value stored on the message (callers pass ``Role.*``).
        content: Text placed in the message's only ``TextContent`` part.

    Returns:
        A ``Message`` of type ``MessageType.MESSAGE``.
    """
    text_part = TextContent(type=ContentType.TEXT, text=content)
    return Message(
        type=MessageType.MESSAGE,
        role=role,
        content=[text_part],
    )
@pytest_asyncio.fixture
async def mock_task_memory_service(mocker):
    """Patch reme_ai's TaskMemoryService and yield the mocked instance.

    The class is patched at the name ``reme_task_memory_service`` looks up,
    so any ``TaskMemoryService(...)`` call made there returns ``instance``.
    Each backend method is an ``AsyncMock``: ``health`` reports True and
    ``search_memory`` / ``list_memory`` return empty lists by default.
    """
    patched_cls = mocker.patch(
        "agentscope_runtime.engine.services.memory."
        "reme_task_memory_service.TaskMemoryService",
    )
    instance = patched_cls.return_value
    async_methods = {
        "start": mocker.AsyncMock(),
        "stop": mocker.AsyncMock(),
        "health": mocker.AsyncMock(return_value=True),
        "add_memory": mocker.AsyncMock(),
        "search_memory": mocker.AsyncMock(return_value=[]),
        "list_memory": mocker.AsyncMock(return_value=[]),
        "delete_memory": mocker.AsyncMock(),
    }
    for attr, async_mock in async_methods.items():
        setattr(instance, attr, async_mock)
    yield instance
@pytest.fixture
def env_vars(monkeypatch):
    """Export dummy FLOW_* credentials/endpoints so the service can be built.

    ``monkeypatch`` restores the previous environment after each test.
    """
    fake_env = {
        "FLOW_EMBEDDING_API_KEY": "test-embedding-key",
        "FLOW_EMBEDDING_BASE_URL": "https://test-embedding.com/v1",
        "FLOW_LLM_API_KEY": "test-llm-key",
        "FLOW_LLM_BASE_URL": "https://test-llm.com/v1",
    }
    for name, value in fake_env.items():
        monkeypatch.setenv(name, value)
@pytest_asyncio.fixture
async def memory_service(env_vars, mock_task_memory_service):
    """Yield a started ReMeTaskMemoryService backed by the mocked service.

    ``env_vars`` supplies dummy configuration and
    ``mock_task_memory_service`` replaces the reme_ai backend. The service
    is stopped again during fixture teardown.
    """
    svc = ReMeTaskMemoryService()
    await svc.start()
    yield svc
    await svc.stop()
@pytest.mark.asyncio
async def test_missing_env_variables(monkeypatch):
    """Constructing the service without FLOW_* settings raises ValueError.

    The variables are removed explicitly so the outcome does not depend on
    the caller's shell. If FLOW_EMBEDDING_API_KEY happened to be exported,
    the "is not set" check could not fire. With no mock in place, the
    constructor might then try to build the real reme_ai backend.
    """
    for name in (
        "FLOW_EMBEDDING_API_KEY",
        "FLOW_EMBEDDING_BASE_URL",
        "FLOW_LLM_API_KEY",
        "FLOW_LLM_BASE_URL",
    ):
        # raising=False: absence is exactly the state we want.
        monkeypatch.delenv(name, raising=False)
    with pytest.raises(ValueError, match="FLOW_EMBEDDING_API_KEY is not set"):
        ReMeTaskMemoryService()
@pytest.mark.asyncio
async def test_service_lifecycle(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """Check health() on the started service, then stop it explicitly.

    The ``memory_service`` fixture started the service, and its teardown
    calls stop() once more after this test. Health is not re-checked after
    stopping because the backend is a mock.
    """
    healthy = await memory_service.health()
    assert healthy is True
    await memory_service.stop()
@pytest.mark.asyncio
async def test_transform_message():
    """transform_message keeps the role and flattens content to text/None."""
    transform = ReMeTaskMemoryService.transform_message

    # A single text part is flattened to its text.
    result = transform(create_message(Role.USER, "hello world"))
    assert result["role"] == Role.USER
    assert result["content"] == "hello world"

    # Both an empty content list and content=None yield content None.
    for content in ([], None):
        bare = Message(
            type=MessageType.MESSAGE,
            role=Role.USER,
            content=content,
        )
        result = transform(bare)
        assert result["role"] == Role.USER
        assert result["content"] is None
@pytest.mark.asyncio
async def test_transform_messages(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """transform_messages maps each message in order to a role/content dict."""
    expected = [
        (Role.USER, "first message"),
        (Role.ASSISTANT, "second message"),
        (Role.USER, "third message"),
    ]
    messages = [create_message(role, text) for role, text in expected]

    transformed = memory_service.transform_messages(messages)

    assert len(transformed) == 3
    for item, (role, text) in zip(transformed, expected):
        assert item["role"] == role
        assert item["content"] == text
@pytest.mark.asyncio
async def test_add_memory_no_session(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """add_memory without a session forwards session_id=None positionally."""
    user_id = "user1"
    await memory_service.add_memory(
        user_id,
        [create_message(Role.USER, "hello world")],
    )

    add_mock = memory_service.service.add_memory
    add_mock.assert_called_once()
    forwarded = add_mock.call_args.args
    assert forwarded[0] == user_id
    assert forwarded[1] == [{"role": Role.USER, "content": "hello world"}]
    assert forwarded[2] is None  # session_id
@pytest.mark.asyncio
async def test_add_memory_with_session(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """add_memory forwards the given session ID as the third argument."""
    user_id, session_id = "user2", "session1"
    await memory_service.add_memory(
        user_id,
        [create_message(Role.USER, "hello from session")],
        session_id,
    )

    add_mock = memory_service.service.add_memory
    add_mock.assert_called_once()
    forwarded = add_mock.call_args.args
    assert forwarded[0] == user_id
    assert forwarded[1] == [
        {"role": Role.USER, "content": "hello from session"},
    ]
    assert forwarded[2] == session_id
@pytest.mark.asyncio
async def test_search_memory(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """search_memory passes filters=None and returns backend results as-is."""
    user_id = "user3"
    canned = [{"role": "user", "content": "found message"}]
    search_mock = memory_service.service.search_memory
    search_mock.return_value = canned

    results = await memory_service.search_memory(
        user_id,
        [create_message(Role.USER, "search query")],
    )

    search_mock.assert_called_once()
    forwarded = search_mock.call_args.args
    assert forwarded[0] == user_id
    assert forwarded[1] == [{"role": Role.USER, "content": "search query"}]
    assert forwarded[2] is None  # filters
    assert results == canned
@pytest.mark.asyncio
async def test_search_memory_with_filters(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """search_memory forwards caller-supplied filters unchanged."""
    user_id = "user4"
    search_filters = {"top_k": 5}
    canned = [{"role": "user", "content": "filtered result"}]
    search_mock = memory_service.service.search_memory
    search_mock.return_value = canned

    results = await memory_service.search_memory(
        user_id,
        [create_message(Role.USER, "search with filters")],
        search_filters,
    )

    search_mock.assert_called_once()
    forwarded = search_mock.call_args.args
    assert forwarded[0] == user_id
    assert forwarded[1] == [
        {"role": Role.USER, "content": "search with filters"},
    ]
    assert forwarded[2] == search_filters
    assert results == canned
@pytest.mark.asyncio
async def test_list_memory(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """list_memory without filters forwards (user_id, None) to the backend."""
    user_id = "user5"
    page = [
        {"role": "user", "content": "message 1"},
        {"role": "assistant", "content": "response 1"},
    ]
    list_mock = memory_service.service.list_memory
    list_mock.return_value = page

    results = await memory_service.list_memory(user_id)

    list_mock.assert_called_once_with(user_id, None)
    assert results == page
@pytest.mark.asyncio
async def test_list_memory_with_filters(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """list_memory forwards pagination filters alongside the user ID."""
    user_id = "user6"
    page_filters = {"page_size": 10, "page_num": 2}
    page = [{"role": "user", "content": "page 2 message"}]
    list_mock = memory_service.service.list_memory
    list_mock.return_value = page

    results = await memory_service.list_memory(user_id, page_filters)

    list_mock.assert_called_once_with(user_id, page_filters)
    assert results == page
@pytest.mark.asyncio
async def test_delete_memory_session(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """delete_memory with a session ID forwards both IDs to the backend."""
    user_id, session_id = "user7", "session_to_delete"

    await memory_service.delete_memory(user_id, session_id)

    delete_mock = memory_service.service.delete_memory
    delete_mock.assert_called_once_with(user_id, session_id)
@pytest.mark.asyncio
async def test_delete_memory_user(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """delete_memory without a session forwards (user_id, None)."""
    user_id = "user_to_delete"

    await memory_service.delete_memory(user_id)

    delete_mock = memory_service.service.delete_memory
    delete_mock.assert_called_once_with(user_id, None)
@pytest.mark.asyncio
async def test_multiple_messages_transformation(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """A multi-turn conversation is forwarded as ordered role/content dicts."""
    conversation = [
        (Role.USER, "first message"),
        (Role.ASSISTANT, "assistant response"),
        (Role.USER, "follow up question"),
    ]
    await memory_service.add_memory(
        "user8",
        [create_message(role, text) for role, text in conversation],
        "multi_session",
    )

    add_mock = memory_service.service.add_memory
    add_mock.assert_called_once()
    forwarded_messages = add_mock.call_args.args[1]
    assert len(forwarded_messages) == 3
    for got, (role, text) in zip(forwarded_messages, conversation):
        assert got == {"role": role, "content": text}
@pytest.mark.asyncio
async def test_empty_messages_list(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """An empty message batch is still forwarded to the backend."""
    no_messages: list = []

    await memory_service.add_memory("user9", no_messages)

    add_mock = memory_service.service.add_memory
    # The wrapper does not short-circuit on an empty list.
    add_mock.assert_called_once()
    assert add_mock.call_args.args[1] == []
@pytest.mark.asyncio
async def test_service_error_propagation(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """Exceptions raised by the backend surface unchanged to the caller."""
    payload = [create_message(Role.USER, "test message")]
    memory_service.service.add_memory.side_effect = RuntimeError(
        "Service error",
    )

    with pytest.raises(RuntimeError, match="Service error"):
        await memory_service.add_memory("error_user", payload)
@pytest.mark.asyncio
async def test_health_check_failure(env_vars, mock_task_memory_service):
    """health() reports False when the underlying service is unhealthy.

    This test starts the service itself rather than using the
    ``memory_service`` fixture. It therefore stops the service in
    ``finally``, as that fixture's teardown would, so cleanup also happens
    when the assertion fails.
    """
    mock_task_memory_service.health.return_value = False
    service = ReMeTaskMemoryService()
    await service.start()
    try:
        assert await service.health() is False
    finally:
        await service.stop()
@pytest.mark.asyncio
async def test_complex_message_content():
    """With several text parts, only the first one becomes the content."""
    parts = [
        TextContent(type=ContentType.TEXT, text=text)
        for text in ("first text", "second text")
    ]
    message = Message(
        type=MessageType.MESSAGE,
        role=Role.USER,
        content=parts,
    )

    transformed = ReMeTaskMemoryService.transform_message(message)

    assert transformed["role"] == Role.USER
    assert transformed["content"] == "first text"
@pytest.mark.asyncio
async def test_message_without_role():
    """A message built without a role transforms to role None."""
    no_role = Message(
        type=MessageType.MESSAGE,
        content=[TextContent(type=ContentType.TEXT, text="no role message")],
    )

    result = ReMeTaskMemoryService.transform_message(no_role)

    assert result["role"] is None
    assert result["content"] == "no role message"
@pytest.mark.asyncio
async def test_concurrent_operations(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """add/search/list run concurrently, each hitting the backend once.

    ``assert_called`` alone would also pass if an operation fanned out into
    repeated backend calls. Each mock is therefore checked with
    ``assert_called_once``. The gathered search/list results are also
    compared to the fixture's mock defaults.
    """
    import asyncio

    user_id = "concurrent_user"
    # Three independent coroutines for the same user, awaited together.
    _, search_results, list_results = await asyncio.gather(
        memory_service.add_memory(
            user_id,
            [create_message(Role.USER, "message 1")],
        ),
        memory_service.search_memory(
            user_id,
            [create_message(Role.USER, "search")],
        ),
        memory_service.list_memory(user_id),
    )

    memory_service.service.add_memory.assert_called_once()
    memory_service.service.search_memory.assert_called_once()
    memory_service.service.list_memory.assert_called_once()
    # The mock fixture configures search_memory/list_memory to return [].
    assert search_results == []
    assert list_results == []
@pytest.mark.asyncio
async def test_service_instance_type(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
    mock_task_memory_service,
):
    """The wrapper's ``service`` attribute is the patched TaskMemoryService.

    The mock fixture patches ``TaskMemoryService``, so every construction
    returns the same mock instance. Checking identity, not merely
    ``is not None``, confirms the wrapper holds the reme_ai service object
    it built rather than some other value.
    """
    assert hasattr(memory_service, "service")
    assert memory_service.service is mock_task_memory_service
@pytest.mark.asyncio
async def test_task_specific_operations(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """A task-style message is forwarded verbatim with its session ID."""
    user_id, session_id = "task_user", "task_session"
    task_text = "Complete task: analyze data"

    await memory_service.add_memory(
        user_id,
        [create_message(Role.USER, task_text)],
        session_id,
    )

    add_mock = memory_service.service.add_memory
    add_mock.assert_called_once()
    forwarded = add_mock.call_args.args
    assert forwarded[0] == user_id
    assert forwarded[1] == [{"role": Role.USER, "content": task_text}]
    assert forwarded[2] == session_id
@pytest.mark.asyncio
async def test_task_memory_search_with_task_filters(
    memory_service: ReMeTaskMemoryService,  # type: ignore[valid-type]
):
    """Task-oriented filters are passed through to the backend search."""
    user_id = "task_search_user"
    query = "find completed tasks"
    task_filters = {"task_status": "completed", "top_k": 10}
    canned = [{"task_id": "123", "status": "completed"}]
    search_mock = memory_service.service.search_memory
    search_mock.return_value = canned

    results = await memory_service.search_memory(
        user_id,
        [create_message(Role.USER, query)],
        task_filters,
    )

    search_mock.assert_called_once()
    forwarded = search_mock.call_args.args
    assert forwarded[0] == user_id
    assert forwarded[1] == [{"role": Role.USER, "content": query}]
    assert forwarded[2] == task_filters
    assert results == canned