@@ -12,7 +12,7 @@ use tokio::task::JoinSet;
1212
1313use crate :: logger:: { Logger , initial_logs} ;
1414use crate :: redis_client:: RedisClient ;
15- use crate :: task:: TaskRegistry ;
15+ use crate :: task:: { PythonDispatcher , TaskRegistry } ;
1616use fluxqueue_common:: { Task , deserialize_raw_task_data} ;
1717
1818pub async fn run_worker (
@@ -50,6 +50,7 @@ pub async fn run_worker(
5050
5151 let queue_name = Arc :: from ( queue_name. to_string ( ) ) ;
5252 let executor_ids = generate_executor_ids ( concurrency) ;
53+ let python_dispatcher = Arc :: new ( PythonDispatcher :: new ( ) ) ;
5354 let mut executors = JoinSet :: new ( ) ;
5455
5556 for i in 0 ..concurrency {
@@ -58,6 +59,7 @@ pub async fn run_worker(
5859 let executor_id = Arc :: clone ( & executor_ids[ i] ) ;
5960 let shutdown = shutdown. clone ( ) ;
6061 let task_registry = Arc :: clone ( & task_registry) ;
62+ let python_dispatcher = Arc :: clone ( & python_dispatcher) ;
6163
6264 redis_client
6365 . register_executor ( & queue_name, & executor_id)
@@ -71,6 +73,7 @@ pub async fn run_worker(
7173 executor_id,
7274 redis_client,
7375 task_registry,
76+ python_dispatcher,
7477 ) ) ;
7578 }
7679
@@ -111,6 +114,7 @@ async fn executor_loop(
111114 executor_id : Arc < str > ,
112115 redis_client : Arc < RedisClient > ,
113116 task_registry : Arc < TaskRegistry > ,
117+ python_dispatcher : Arc < PythonDispatcher > ,
114118) -> Result < ( ) > {
115119 let logger = Logger :: new ( format ! ( "EXECUTOR {}" , & executor_id) ) ;
116120
@@ -148,7 +152,7 @@ async fn executor_loop(
148152 } ;
149153
150154 let duration_start = Instant :: now( ) ;
151- let task_result = run_task( & task, task_function) . await ;
155+ let task_result = run_task( python_dispatcher . clone ( ) , & task, task_function) . await ;
152156
153157 match task_result {
154158 Ok ( _) => {
@@ -268,7 +272,11 @@ async fn janitor_loop(
268272 }
269273}
270274
271- async fn run_task ( task : & Task , task_function : Arc < Py < PyAny > > ) -> Result < ( ) > {
275+ async fn run_task (
276+ python_dispatcher : Arc < PythonDispatcher > ,
277+ task : & Task ,
278+ task_function : Arc < Py < PyAny > > ,
279+ ) -> Result < ( ) > {
272280 let task_args: rmpv:: Value = from_slice ( & task. args ) . context ( format ! (
273281 "Failed to deserialize task {} function args" ,
274282 task. name
@@ -278,47 +286,41 @@ async fn run_task(task: &Task, task_function: Arc<Py<PyAny>>) -> Result<()> {
278286 task. name
279287 ) ) ?;
280288
281- tokio:: task:: spawn_blocking ( move || {
282- Python :: attach ( |py| -> Result < ( ) > {
283- let py_args = pythonize ( py, & task_args) . context ( "Failed to pythonize args" ) ?;
284- let py_kwargs = pythonize ( py, & task_kwargs) . context ( "Failed to pythonize kwargs" ) ?;
285-
286- let args_tuple = if let Ok ( list) = py_args. cast :: < PyList > ( ) {
287- list. to_tuple ( )
288- } else if let Ok ( tuple) = py_args. cast :: < PyTuple > ( ) {
289- tuple. clone ( )
290- } else {
291- anyhow:: bail!( "Args must be an array/tuple, found {}" , py_args. get_type( ) ) ;
292- } ;
293-
294- let kwargs_dict = py_kwargs
295- . cast_into :: < PyDict > ( )
296- . map_err ( |_| anyhow ! ( "Kwargs must be a map/dict" ) ) ?;
297-
298- let result = task_function
299- . call ( py, args_tuple, Some ( & kwargs_dict) )
300- . map_err ( |e| anyhow ! ( "Failed to call Python function: {:?}" , e) ) ?;
301-
302- let bound_result = result. bind ( py) ;
303- let is_coroutine = bound_result
304- . hasattr ( "__await__" )
305- . map_err ( |_| anyhow ! ( "Failed to check if result is awaitable" ) ) ?;
306-
307- if is_coroutine {
308- let asyncio = py. import ( "asyncio" ) ?;
309- let run_func = asyncio. getattr ( "run" ) ?;
310-
311- if !run_func. is_callable ( ) {
312- anyhow:: bail!( "asyncio.run() not callable. Python 3.7+ required" ) ;
313- }
289+ let maybe_coro = Python :: attach ( |py| -> Result < Option < Py < PyAny > > > {
290+ let py_args = pythonize ( py, & task_args) . context ( "Failed to pythonize args" ) ?;
291+ let py_kwargs = pythonize ( py, & task_kwargs) . context ( "Failed to pythonize kwargs" ) ?;
314292
315- run_func. call1 ( ( result, ) ) ?;
316- }
317- Ok ( ( ) )
318- } )
319- } )
320- . await
321- . map_err ( |e| anyhow ! ( "Task execution panicked: {}" , e) ) ??;
293+ let args_tuple = if let Ok ( list) = py_args. cast :: < PyList > ( ) {
294+ list. to_tuple ( )
295+ } else if let Ok ( tuple) = py_args. cast :: < PyTuple > ( ) {
296+ tuple. clone ( )
297+ } else {
298+ anyhow:: bail!( "Args must be an array/tuple, found {}" , py_args. get_type( ) ) ;
299+ } ;
300+
301+ let kwargs_dict = py_kwargs
302+ . cast_into :: < PyDict > ( )
303+ . map_err ( |_| anyhow ! ( "Kwargs must be a map/dict" ) ) ?;
304+
305+ let result = task_function
306+ . call ( py, args_tuple, Some ( & kwargs_dict) )
307+ . map_err ( |e| anyhow ! ( "Failed to call Python function: {:?}" , e) ) ?;
308+
309+ let bound_result = result. bind ( py) ;
310+ let is_coroutine = bound_result
311+ . hasattr ( "__await__" )
312+ . map_err ( |_| anyhow ! ( "Failed to check if result is awaitable" ) ) ?;
313+
314+ if is_coroutine {
315+ Ok ( Some ( result) )
316+ } else {
317+ Ok ( None )
318+ }
319+ } ) ?;
320+
321+ if let Some ( coro) = maybe_coro {
322+ python_dispatcher. execute ( coro) . await ?;
323+ }
322324
323325 Ok ( ( ) )
324326}
0 commit comments