Skip to content

Commit b22f680

Browse files
authored
Merge pull request #83 from zhuyanhuazhuyanhua/main
feat: add thread pool
2 parents f22b205 + 53c94dc commit b22f680

File tree

2 files changed

+176
-3
lines changed

2 files changed

+176
-3
lines changed

oxygent/oxy/function_tools/function_hub.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import functools
10+
import concurrent.futures
1011

1112
from pydantic import Field
1213

@@ -29,6 +30,18 @@ class FunctionHub(BaseTool):
2930
default_factory=dict, description="Registry of functions and their metadata"
3031
)
3132

33+
def __init__(self, **fields):
    """Create the hub and prepare lazy thread-pool state.

    Args:
        **fields: Field values forwarded unchanged to the pydantic base class.
    """
    super().__init__(**fields)
    # The executor is created on first access of the ``thread_pool`` property.
    # NOTE(review): assigning an undeclared underscore attribute on a pydantic
    # model may require a PrivateAttr declaration depending on the base class
    # configuration — confirm against BaseTool.
    self._thread_pool = None
37+
38+
@property
def thread_pool(self):
    """Return the shared executor, creating it on first access.

    The pool is built lazily so hubs that never wrap a synchronous
    function do not pay for spinning up worker threads.
    """
    pool = self._thread_pool
    if pool is None:
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
        self._thread_pool = pool
    return pool
44+
3245
async def init(self):
3346
"""Initialize the hub by creating FunctionTool instances for all registered
3447
functions.
@@ -67,14 +80,35 @@ def decorator(func):
6780
if asyncio.iscoroutinefunction(func):
6881
async_func = func
6982
else:
70-
# Wrap synchronous function to make it asynchronous
83+
# Wrap synchronous function to make it asynchronous using the thread pool
@functools.wraps(func)
async def async_func(*args, **kwargs):
    """Await the blocking ``func`` by running it on the hub's executor.

    Args:
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated here and can bind the wrong loop.
    loop = asyncio.get_running_loop()
    # functools.partial binds positional AND keyword arguments, so a
    # single code path replaces the previous kwargs/no-kwargs branches.
    return await loop.run_in_executor(
        self.thread_pool, functools.partial(func, *args, **kwargs)
    )
75103

76104
# Register function in the hub's dictionary
77105
self.func_dict[func.__name__] = (description, async_func)
78106
return async_func # Return the async version
79107

80108
return decorator
109+
110+
async def cleanup(self):
    """Release hub resources, shutting down the thread pool if one exists.

    Safe to call multiple times; subsequent calls are no-ops.
    """
    pool = self._thread_pool
    if pool:
        pool.shutdown(wait=True)
        self._thread_pool = None

test/unittest/test_function_hub.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44

55
import asyncio
6+
import time
7+
import concurrent.futures
68

79
import pytest
810

@@ -92,3 +94,140 @@ def inc(x: int):
9294

9395
result = asyncio.run(async_inc(41))
9496
assert result == 42
97+
98+
99+
# ────────────────────────────────────────────────────────────────────────────
100+
# Thread Pool Tests
101+
# ────────────────────────────────────────────────────────────────────────────
102+
def test_thread_pool_lazy_initialization(func_hub):
    """The executor must not exist until the property is first read."""
    assert func_hub._thread_pool is None

    # First access builds the pool ...
    first = func_hub.thread_pool
    assert isinstance(first, concurrent.futures.ThreadPoolExecutor)
    assert func_hub._thread_pool is not None

    # ... and every later access returns the very same instance.
    assert func_hub.thread_pool is first
115+
116+
117+
@pytest.mark.asyncio
async def test_sync_function_execution_with_thread_pool(func_hub):
    """A wrapped sync function must run on a worker thread, not the caller's."""
    import threading

    seen = {"worker": None}

    @func_hub.tool("test sync function")
    def blocking_function(duration: float):
        """Simulate blocking operation."""
        import threading
        seen["worker"] = threading.current_thread().ident
        time.sleep(duration)
        return f"completed in thread {seen['worker']}"

    main_ident = threading.current_thread().ident

    _, wrapped = func_hub.func_dict["blocking_function"]
    result = await wrapped(0.1)

    # The body ran, and on a thread different from the test's main thread.
    assert seen["worker"] is not None
    assert seen["worker"] != main_ident
    assert "completed in thread" in result
142+
143+
144+
@pytest.mark.asyncio
async def test_sync_function_with_kwargs_in_thread_pool(func_hub):
    """Keyword arguments must survive the trip through the executor."""
    @func_hub.tool("test function with kwargs")
    def function_with_kwargs(a: int, b: int, multiplier: float = 1.0):
        """Function that uses kwargs."""
        time.sleep(0.01)  # simulate a little work
        return (a + b) * multiplier

    _, wrapped = func_hub.func_dict["function_with_kwargs"]
    # (2 + 3) * 2.0 == 10.0
    assert await wrapped(2, 3, multiplier=2.0) == 10.0
158+
159+
160+
@pytest.mark.asyncio
async def test_cleanup_shuts_down_thread_pool(func_hub):
    """cleanup() must shut the executor down and drop the reference."""
    # Force lazy initialization.
    pool = func_hub.thread_pool
    assert isinstance(pool, concurrent.futures.ThreadPoolExecutor)
    assert func_hub._thread_pool is not None

    await func_hub.cleanup()

    assert func_hub._thread_pool is None
175+
176+
177+
@pytest.mark.asyncio
async def test_multiple_cleanup_calls_safe(func_hub):
    """Calling cleanup() repeatedly must be idempotent and never raise."""
    func_hub.thread_pool  # force initialization

    for _ in range(2):
        await func_hub.cleanup()
        assert func_hub._thread_pool is None
190+
191+
192+
@pytest.mark.asyncio
async def test_cleanup_without_thread_pool_initialization(func_hub):
    """cleanup() must be a harmless no-op when the pool was never created."""
    assert func_hub._thread_pool is None

    await func_hub.cleanup()

    assert func_hub._thread_pool is None
201+
202+
203+
@pytest.mark.asyncio
async def test_concurrent_sync_function_execution(func_hub):
    """Several blocking calls should overlap when routed through the pool."""
    @func_hub.tool("concurrent task")
    def concurrent_task(task_id: int, duration: float):
        """Simulate concurrent blocking operation."""
        time.sleep(duration)
        return f"task_{task_id}_completed"

    _, wrapped = func_hub.func_dict["concurrent_task"]
    jobs = [wrapped(1, 0.1), wrapped(2, 0.15), wrapped(3, 0.05)]

    started = time.time()
    results = await asyncio.gather(*jobs)
    elapsed = time.time() - started

    assert len(results) == 3
    for expected in ("task_1_completed", "task_2_completed", "task_3_completed"):
        assert expected in results

    # Wall time should track the longest task (~0.15s), well under the 0.3s
    # serial sum. NOTE(review): timing asserts can flake on loaded CI hosts.
    assert elapsed < 0.3

0 commit comments

Comments
 (0)