Skip to content

Commit 712bd16

Browse files
authored
add batch request api (#2)
* add issue templates * add a batch request interface * format * fix batch_requests for both sync and async callers
1 parent 9462b4e commit 712bd16

File tree

10 files changed

+191
-9
lines changed

10 files changed

+191
-9
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
---
2+
name: Bug Report
3+
about: Submit a report to help us improve.
4+
labels: 'bug, needs triage'
5+
---
6+
7+
**Describe the bug:**
8+
A clear and concise description of what the bug is.
9+
10+
**Describe how to reproduce:**
11+
Steps to reproduce the behavior. Ideally attach a minimal code sample to reproduce the described issue.
12+
13+
**Describe the expected behavior:**
14+
A clear and concise description of what you expected to happen.
15+
16+
**Environment:**
17+
At the very least, specify the versions of matrix, PyTorch, Python, and CUDA, along with your operating system and, if relevant, the GPU model.
18+
19+
**Additional Context:**
20+
Add any other context about the bug here.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
---
2+
name: Feature Request
3+
about: Submit a request for a new feature.
4+
labels: 'enhancement, needs triage'
5+
---
6+
7+
**Is your feature request related to a problem? Please describe:**
8+
A clear and concise description of what the problem is.
9+
10+
**Describe the solution you would like:**
11+
A clear and concise description of what you want to happen.
12+
13+
**Describe the alternatives you have considered:**
14+
A clear and concise description of any alternative solutions or features you have considered.
15+
16+
**Additional Context:**
17+
Add any other context about the feature request here.

.github/ISSUE_TEMPLATE/question.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
name: Question
3+
about: Ask a question to the users and contributors.
4+
labels: 'question, needs triage'
5+
---
6+
7+
Please make sure that you first search existing issues and documentation before asking a question. If you cannot find an answer, be clear and concise. Ideally attach a minimal code sample if it is relevant to your question.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
name: Typo or Documentation Issue
3+
about: Report a typo or an issue related to documentation.
4+
labels: 'documentation, needs triage'
5+
---
6+
7+
For typos, please go ahead; fix the typo and submit a PR. For documentation issues, please describe the issue here and wait for approval before submitting a PR.
File renamed without changes.

matrix/app_server/llm/query_llm.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,54 @@ async def make_request(
471471
}
472472

473473

474+
def batch_requests(
    url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
    model: str,
    requests: tp.List[tp.Dict[str, tp.Any]],
    **kwargs,
) -> tp.List[tp.Dict[str, tp.Any]]:
    """
    Process multiple requests concurrently by calling make_request for each.

    Works whether called from a plain synchronous context or from inside a
    running event loop (e.g. within ``asyncio.run``).

    Args:
        url: Endpoint URL, or an async callable that resolves to one
            (forwarded unchanged to ``make_request``).
        model: Model name forwarded to ``make_request``.
        requests: Request payloads, one dict per request.
        **kwargs: Extra keyword arguments forwarded to ``make_request``.

    Returns:
        The responses, in the same order as ``requests``.
    """

    async def _process_requests():
        """Run all requests concurrently and gather results in input order."""
        return await asyncio.gather(
            *[make_request(url, model, request, **kwargs) for request in requests]
        )

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running: plain sync context. asyncio.run creates,
        # runs, and cleanly closes a fresh event loop. (The previous
        # get_event_loop()/run_until_complete pattern is deprecated since
        # Python 3.10 and leaked the loop it created.)
        return asyncio.run(_process_requests())

    # A loop is already running in this thread: run_until_complete would
    # raise. Run the coroutines on a private loop in a worker thread so
    # the caller's loop is neither blocked nor re-entered.
    import concurrent.futures

    def _run_in_new_loop():
        # Fresh loop owned entirely by this worker thread.
        new_loop = asyncio.new_event_loop()
        try:
            return new_loop.run_until_complete(_process_requests())
        finally:
            new_loop.close()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(_run_in_new_loop).result()
520+
521+
474522
async def main(
475523
url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
476524
output_file: str,

matrix/cli.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -357,15 +357,13 @@ def check_health(
357357
else:
358358
if not use_chat:
359359
data_payload = {"prompt": prompt}
360-
response = asyncio.run(
361-
query_llm.make_request(
362-
metadata["endpoints"]["head"],
363-
metadata["model_name"],
364-
data_payload,
365-
app_name=app_name,
366-
**kwargs,
367-
)
368-
)
360+
response = query_llm.batch_requests(
361+
metadata["endpoints"]["head"],
362+
metadata["model_name"],
363+
[data_payload],
364+
app_name=app_name,
365+
**kwargs,
366+
)[0]
369367
print(response)
370368
return "error" not in response["response"]
371369

matrix/cluster/ray_dashboard_job.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
import logging
28
import os
39
import shutil

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ classifiers=[
4242
dev = [
4343
# Test
4444
"pytest>=4.3.0",
45+
"pytest-asyncio>=0.26.0",
4546
"coverage[toml]>=5.1",
4647
# Format
4748
"black==24.10.0",
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import asyncio
8+
from unittest.mock import MagicMock, patch
9+
10+
import pytest
11+
12+
import matrix
13+
from matrix.app_server.llm import query_llm
14+
15+
16+
def test_batch_requests_from_async_run():
    """batch_requests must return a plain list even inside a running loop."""
    prefix = "mocked_response"

    async def fake_make_request(_url, _model, request):
        return f"{prefix}_{request}"

    async def exercise():
        with patch(
            "matrix.app_server.llm.query_llm.make_request",
            side_effect=fake_make_request,
        ):
            # Called from inside asyncio.run: batch_requests must detect
            # the running loop and still hand back results directly,
            # not a task or coroutine.
            results = query_llm.batch_requests("", "", [1, 2, 3])

            assert isinstance(results, list)
            assert len(results) == 3
            assert results == [f"{prefix}_{n}" for n in (1, 2, 3)]

    # Drive the async wrapper from a fresh event loop.
    asyncio.run(exercise())
44+
45+
46+
def test_batch_requests_in_sync_context():
    """batch_requests works when no event loop is running."""
    prefix = "mocked_response"

    async def fake_make_request(_url, _model, request):
        return f"{prefix}_{request}"

    with patch(
        "matrix.app_server.llm.query_llm.make_request",
        side_effect=fake_make_request,
    ):
        # Plain synchronous call with several payloads.
        results = query_llm.batch_requests("", "", [1, 2, 3])

        # One response per request, in input order.
        assert len(results) == 3
        assert results == [f"{prefix}_1", f"{prefix}_2", f"{prefix}_3"]
69+
70+
71+
def test_batch_requests_empty_list():
    """An empty request list yields an empty result without any backend calls."""
    with patch("matrix.app_server.llm.query_llm.make_request") as fake_request:
        # No payloads: nothing should be dispatched.
        assert query_llm.batch_requests("", "", []) == []
        fake_request.assert_not_called()

0 commit comments

Comments
 (0)