1- use anyhow:: { Result , anyhow} ;
1+ use anyhow:: { Context , Result , anyhow} ;
2+ use pyo3:: exceptions:: PyRuntimeError ;
23use pyo3:: prelude:: * ;
4+ use pyo3:: types:: { PyDict , PyList , PyTuple } ;
35use pyo3_async_runtimes:: tokio:: into_future;
6+ use pythonize:: pythonize;
7+ use rmp_serde:: from_slice;
8+ use rmpv:: Value ;
49use std:: collections:: HashMap ;
5- use std:: pin:: Pin ;
610use std:: sync:: { Arc , RwLock } ;
711use tokio:: sync:: { mpsc, oneshot} ;
812
@@ -19,7 +23,7 @@ impl TaskRegistry {
1923
2024 pub fn insert ( & self , name : String , func : Py < PyAny > ) -> Result < ( ) > {
2125 let mut tasks = self . tasks . write ( ) . map_err ( |_| {
22- anyhow:: anyhow !( "Internal Error: Task registry lock poisoned (a thread panicked)" )
26+ anyhow ! ( "Internal Error: Task registry lock poisoned (a thread panicked)" )
2327 } ) ?;
2428 tasks. insert ( name, Arc :: new ( func) ) ;
2529 Ok ( ( ) )
@@ -31,51 +35,119 @@ impl TaskRegistry {
3135 }
3236}
3337
34- type PyResponse = Pin < Box < dyn Future < Output = PyResult < Py < PyAny > > > + Send > > ;
35-
3638struct TaskRequest {
37- func : Py < PyAny > ,
38- resp_tx : oneshot:: Sender < PyResult < PyResponse > > ,
39+ func : Arc < Py < PyAny > > ,
40+ task_name : Arc < String > ,
41+ raw_args : Arc < Vec < u8 > > ,
42+ raw_kwargs : Arc < Vec < u8 > > ,
43+ resp_tx : oneshot:: Sender < Result < ( ) > > ,
3944}
4045
4146pub struct PythonDispatcher {
4247 tx : mpsc:: Sender < TaskRequest > ,
4348}
4449
4550impl PythonDispatcher {
46- pub fn new ( ) -> Self {
51+ pub fn new ( ) -> Result < Self > {
4752 let logical_cores = num_cpus:: get ( ) ;
4853 let ( tx, mut rx) = mpsc:: channel :: < TaskRequest > ( logical_cores * 2 ) ;
4954
50- std :: thread :: spawn ( move || {
51- Python :: attach ( |py| {
52- let dispatcher = async move {
53- while let Some ( req ) = rx . recv ( ) . await {
54- let res = Python :: attach ( |py| into_future ( req . func . into_bound ( py ) ) ) ;
55+ let dispatcher = async move {
56+ while let Some ( req ) = rx . recv ( ) . await {
57+ run_task ( req . func , req . task_name , req . raw_args , req . raw_kwargs )
58+ . await
59+ . map_err ( |e| PyRuntimeError :: new_err ( e . to_string ( ) ) ) ? ;
5560
56- let _ = req. resp_tx . send ( res . map ( |f| Box :: pin ( f ) as PyResponse ) ) ;
57- }
58- Ok ( ( ) )
59- } ;
61+ let _ = req. resp_tx . send ( Ok ( ( ) ) ) ;
62+ }
63+ Ok ( ( ) )
64+ } ;
6065
66+ tokio:: task:: spawn_blocking ( move || {
67+ Python :: attach ( |py| {
6168 pyo3_async_runtimes:: tokio:: run ( py, dispatcher) . expect ( "Python loop failed" ) ;
6269 } ) ;
6370 } ) ;
6471
65- Self { tx }
72+ Ok ( Self { tx } )
6673 }
6774
68- pub async fn execute ( & self , func : Py < PyAny > ) -> Result < Py < PyAny > > {
75+ pub async fn execute (
76+ & self ,
77+ func : Arc < Py < PyAny > > ,
78+ task_name : Arc < String > ,
79+ raw_args : Arc < Vec < u8 > > ,
80+ raw_kwargs : Arc < Vec < u8 > > ,
81+ ) -> Result < ( ) > {
6982 let ( resp_tx, resp_rx) = oneshot:: channel ( ) ;
7083
7184 self . tx
72- . send ( TaskRequest { func, resp_tx } )
85+ . send ( TaskRequest {
86+ func,
87+ task_name,
88+ raw_args,
89+ raw_kwargs,
90+ resp_tx,
91+ } )
7392 . await
7493 . map_err ( |_| anyhow ! ( "Dispatcher channel closed" ) ) ?;
7594
76- let py_fut = resp_rx. await ??;
95+ resp_rx. await ??;
96+ Ok ( ( ) )
97+ }
98+ }
99+
100+ async fn run_task (
101+ task_function : Arc < Py < PyAny > > ,
102+ task_name : Arc < String > ,
103+ raw_args : Arc < Vec < u8 > > ,
104+ raw_kwargs : Arc < Vec < u8 > > ,
105+ ) -> Result < ( ) > {
106+ let task_args: Value = from_slice ( & raw_args) . context ( format ! (
107+ "Failed to deserialize task '{}' function args" ,
108+ & task_name
109+ ) ) ?;
110+ let task_kwargs: Value = from_slice ( & raw_kwargs) . context ( format ! (
111+ "Failed to deserialize task '{}' function kwargs" ,
112+ & task_name
113+ ) ) ?;
114+
115+ let maybe_coro = Python :: attach ( |py| -> Result < Option < Py < PyAny > > > {
116+ let py_args = pythonize ( py, & task_args) . context ( "Failed to pythonize args" ) ?;
117+ let py_kwargs = pythonize ( py, & task_kwargs) . context ( "Failed to pythonize kwargs" ) ?;
118+
119+ let args_tuple = if let Ok ( list) = py_args. cast :: < PyList > ( ) {
120+ list. to_tuple ( )
121+ } else if let Ok ( tuple) = py_args. cast :: < PyTuple > ( ) {
122+ tuple. clone ( )
123+ } else {
124+ anyhow:: bail!( "Args must be an array/tuple, found {}" , py_args. get_type( ) ) ;
125+ } ;
126+
127+ let kwargs_dict = py_kwargs
128+ . cast_into :: < PyDict > ( )
129+ . map_err ( |_| anyhow ! ( "Kwargs must be a map/dict" ) ) ?;
130+
131+ let result = task_function
132+ . call ( py, args_tuple, Some ( & kwargs_dict) )
133+ . map_err ( |e| anyhow ! ( "Failed to call Python function: {:?}" , e) ) ?;
134+
135+ let bound_result = result. bind ( py) ;
136+ let is_coroutine = bound_result
137+ . hasattr ( "__await__" )
138+ . map_err ( |_| anyhow ! ( "Failed to check if result is awaitable" ) ) ?;
139+
140+ if is_coroutine {
141+ Ok ( Some ( result) )
142+ } else {
143+ Ok ( None )
144+ }
145+ } ) ?;
77146
78- let result = py_fut. await ?;
79- Ok ( result)
147+ if let Some ( coro) = maybe_coro {
148+ let fut = Python :: attach ( |py| into_future ( coro. into_bound ( py) ) ) ?;
149+ fut. await ?;
80150 }
151+
152+ Ok ( ( ) )
81153}
0 commit comments