Commit 2273ae5

Add multi-threading support for single API calls in inference module (#40)
Closes #39. Incremental change: the number of concurrent requests per port can be set with `threads_per_port` (default 20), and the overall concurrency limit with `max_workers` (default 80; values above 20 times the CPU count are clamped to that bound).
2 parents bc5ab68 + f669a42 commit 2273ae5
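
For orientation only (not part of the commit): a minimal sketch of how a caller might use the new parameters, assuming the import paths and the Channel API exercised by the test added below; hosts, ports, and the model path are placeholders, and the clamping comments assume the capping logic shown in the diff.

import asyncio

from oasis.inference import InferencerManager
from oasis.social_platform import Channel


async def main():
    channel = Channel()
    manager = InferencerManager(
        channel=channel,
        model_type="llama-3",
        model_path="/path/to/model",  # placeholder path
        stop_tokens=["\n"],
        server_url=[{"host": "localhost", "ports": [8000, 8001, 8002]}],
        threads_per_port=10,  # 10 concurrent requests per port -> 30 threads requested
        max_workers=24,       # global cap: 30 > 24, so threads_per_port is clamped to 24 // 3 = 8
    )
    task = asyncio.create_task(manager.run())
    # ... write requests with channel.write_to_receive_queue(...) and read replies
    # with channel.read_from_send_queue(...), as the tests below do ...
    await manager.stop()
    await task


asyncio.run(main())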

File tree

2 files changed: +176, -15 lines changed
oasis/inference/inference_manager.py

Lines changed: 49 additions & 15 deletions
@@ -14,6 +14,8 @@
 import asyncio
 import logging
 import threading
+from os import cpu_count
+from typing import Any
 
 from oasis.inference.inference_thread import InferenceThread, SharedMemory
 
@@ -32,32 +34,62 @@ class InferencerManager:
 
     def __init__(
         self,
-        channel,
-        model_type,
-        model_path,
-        stop_tokens,
-        server_url,
+        channel: Any,
+        model_type: str,
+        model_path: str,
+        stop_tokens: list[str],
+        server_url: list[dict[str, list[int]]],
+        threads_per_port: int = 20,
+        max_workers: int = 80,
     ):
         self.count = 0
         self.channel = channel
         self.threads = []
         self.lock = threading.Lock(
         )  # Use thread lock to protect shared resources
         self.stop_event = threading.Event()  # Event for stopping threads
+
+        # Check if max_workers is set to a reasonable value
+        if max_workers < 1:
+            inference_log.error(
+                "Max workers must be at least 1. Setting to 1.")
+            max_workers = 1
+        # For IO bound tasks, max_workers should be set to a higher value
+        # between 5 and 20 times the number of CPUs
+        elif max_workers > cpu_count() * 20:
+            inference_log.warning(
+                f"Max workers is higher than recommended value. Setting to "
+                f"{cpu_count() * 20}.")
+            max_workers = cpu_count() * 20
+
+        # Check if threads_per_port is set to a reasonable value
+        total_ports = 0
+        for url in server_url:
+            total_ports += len(url["ports"])
+        if total_ports * threads_per_port > max_workers:
+            threads_per_port = max(max_workers // total_ports, 1)
+            inference_log.warning(
+                f"Total threads exceeds max workers. Setting threads per port "
+                f"to {threads_per_port}.")
+        if threads_per_port < 1:
+            inference_log.error(
+                "Threads per port must be at least 1. Setting to 1.")
+            threads_per_port = 1
+
         for url in server_url:
             host = url["host"]
             for port in url["ports"]:
                 _url = f"http://{host}:{port}/v1"
-                shared_memory = SharedMemory()
-                thread = InferenceThread(
-                    model_path=model_path,
-                    server_url=_url,
-                    stop_tokens=stop_tokens,
-                    model_type=model_type,
-                    temperature=0.0,
-                    shared_memory=shared_memory,
-                )
-                self.threads.append(thread)
+                self.threads.extend([
+                    InferenceThread(
+                        model_path=model_path,
+                        server_url=_url,
+                        stop_tokens=stop_tokens,
+                        model_type=model_type,
+                        temperature=0.0,
+                        shared_memory=SharedMemory(),
+                    ) for _ in range(threads_per_port)
+                ])
 
     async def run(self):
         # Start threads
@@ -103,5 +135,7 @@ async def run(self):
             await self.stop()
 
     async def stop(self):
+        self.stop_event.set()
+
         for thread in self.threads:
             thread.alive = False
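
To make the capping arithmetic above easier to follow, here is an illustrative restatement of the constructor's checks as a standalone helper (not part of the commit; the function name and example values are made up):

from os import cpu_count


def effective_threads_per_port(total_ports: int,
                               threads_per_port: int = 20,
                               max_workers: int = 80) -> int:
    """Restates the constructor's capping logic, for illustration only."""
    max_workers = max(max_workers, 1)
    # Upper bound for IO-bound work; assumes cpu_count() is not None,
    # as the constructor does.
    max_workers = min(max_workers, cpu_count() * 20)
    if total_ports * threads_per_port > max_workers:
        threads_per_port = max(max_workers // total_ports, 1)
    return max(threads_per_port, 1)


# With the defaults and 3 ports: 3 * 20 = 60 <= 80, so 20 threads per port (60 total).
# Asking for 40 per port: 120 > 80, so 80 // 3 = 26 per port (78 total),
# assuming a machine with at least 4 CPUs so the cap stays at 80.
print(effective_threads_per_port(3))                       # 20
print(effective_threads_per_port(3, threads_per_port=40))  # 26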

test/inference/test_inference.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+import asyncio
+from unittest import mock
+
+import pytest
+
+from oasis.inference import InferencerManager
+from oasis.social_platform import Channel
+
+
+@pytest.mark.asyncio
+async def test_manager_run_with_mocked_response():
+    channel = Channel()
+
+    # Setup the InferencerManager with the real channel
+    manager = InferencerManager(
+        channel=channel,
+        model_type="llama-3",
+        model_path="/path/to/model",
+        stop_tokens=["\n"],
+        server_url=[{
+            "host": "localhost",
+            "ports": [8000]
+        }],
+    )
+
+    # Mocking the run method of model_backend to return a mocked response
+    mock_response = mock.Mock()
+    mock_response.choices = [
+        mock.Mock(message=mock.Mock(content="Mock Response"))
+    ]
+
+    # Mocking channel.send_to as well
+    with mock.patch.object(manager.threads[0].model_backend,
+                           'run',
+                           return_value=mock_response):
+
+        openai_messages = [{
+            "role": "assistant",
+            "content": 'mock_message',
+        }]
+
+        # Run the manager asynchronously
+        task = asyncio.create_task(manager.run())
+
+        # Add a message to the receive_queue
+        mes_id = await channel.write_to_receive_queue(openai_messages)
+        mes_id, content = await channel.read_from_send_queue(mes_id)
+        assert content == "Mock Response"
+
+        await manager.stop()
+        await task
+
+
+@pytest.mark.asyncio
+async def test_multiple_threads():
+    # Create a Channel instance
+    channel = Channel()
+
+    # Set up multiple ports to simulate multiple threads
+    server_url = [{
+        "host": "localhost",
+        "ports": [8000, 8001, 8002]
+    }  # 3 ports
+    ]
+
+    # Initialize InferencerManager with multiple threads
+    manager = InferencerManager(
+        channel=channel,
+        model_type="llama-3",
+        model_path="/path/to/model",
+        stop_tokens=["\n"],
+        server_url=server_url,
+        threads_per_port=2,  # 2 threads per port
+    )
+
+    # Mock the response for multiple threads
+    mock_response = mock.Mock()
+    mock_response.choices = [
+        mock.Mock(message=mock.Mock(content="Mock Response"))
+    ]
+
+    # Replace the model_backend.run method for all threads with the mock
+    for thread in manager.threads:
+        thread.model_backend.run = mock.Mock(return_value=mock_response)
+
+    # Start the manager
+    task = asyncio.create_task(manager.run())
+
+    # Send multiple messages to the queue
+    openai_messages = [{
+        "role": "assistant",
+        "content": f"mock_message_{i}"
+    } for i in range(10)]
+
+    # Write messages to the receive queue
+    message_ids = []
+    for message in openai_messages:
+        message_id = await channel.write_to_receive_queue([message])
+        message_ids.append(message_id)
+
+    # Read results from the send queue
+    results = []
+    for message_id in message_ids:
+        _, content = await channel.read_from_send_queue(message_id)
+        results.append(content)
+
+    # Validate the results
+    assert len(results) == 10  # Ensure all messages are processed
+    assert all(content == "Mock Response"
+               for content in results)  # Ensure all responses are correct
+
+    # Stop the manager
+    await manager.stop()
+    await task
