-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathtest_batch_requests.py
More file actions
78 lines (62 loc) · 2.44 KB
/
test_batch_requests.py
File metadata and controls
78 lines (62 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
from unittest.mock import MagicMock, patch
import pytest
import matrix
from matrix.app_server.llm import query_llm
def test_batch_requests_from_async_run():
    """Call ``batch_requests`` while an event loop is already running.

    ``asyncio.run`` drives a coroutine that calls ``batch_requests`` without
    awaiting it. The test requires the call to hand back a plain ``list`` of
    results (not a task or coroutine), ordered like the input requests.
    """
    prefix = "mocked_response"

    async def fake_make_request(_url, _model, request):
        # Async stand-in for make_request: echo the request back with a
        # fixed prefix so every result can be matched to its input.
        return f"{prefix}_{request}"

    async def run_inside_event_loop():
        target = "matrix.app_server.llm.query_llm.make_request"
        with patch(target, side_effect=fake_make_request):
            payloads = [1, 2, 3]
            # A loop is running at this point; batch_requests is expected
            # to cope with that on its own and return results directly.
            outcome = query_llm.batch_requests("", "", payloads)

            expected = [f"{prefix}_{item}" for item in payloads]
            assert isinstance(outcome, list)
            assert len(outcome) == 3
            assert outcome == expected

    asyncio.run(run_inside_event_loop())
def test_batch_requests_in_sync_context():
    """Test batch_requests when called from a synchronous context.

    The test function itself is plain (non-async), so batch_requests is
    invoked outside any coroutine. It must return the collected results as
    a concrete list, in the same order as the input requests.
    """
    mock_response = "mocked_response"

    # Async stand-in for query_llm.make_request: echoes each request back
    # with a fixed prefix so results can be matched to their inputs.
    async def mock_make_request_async(_url, _model, request):
        return f"{mock_response}_{request}"

    with patch(
        "matrix.app_server.llm.query_llm.make_request",
        side_effect=mock_make_request_async,
    ):
        requests = [1, 2, 3]
        result = query_llm.batch_requests("", "", requests)

        # Mirror the async-context test: require a real list rather than a
        # task/coroutine before checking its contents.
        assert isinstance(result, list)
        assert len(result) == 3
        assert result == [
            f"{mock_response}_1",
            f"{mock_response}_2",
            f"{mock_response}_3",
        ]
def test_batch_requests_empty_list():
    """Test batch_requests with an empty list.

    With no requests to send, the test requires that the patched
    ``make_request`` is never invoked and that the result is an empty list.
    """
    with patch("matrix.app_server.llm.query_llm.make_request") as mock_request:
        result = query_llm.batch_requests("", "", [])

        # make_request (the patched target) should not be called.
        mock_request.assert_not_called()
        # Result should be an empty list.
        assert result == []