11use anyhow:: { Context , Result , anyhow} ;
2+ use fluxqueue_common:: Task ;
23use pyo3:: exceptions:: PyRuntimeError ;
34use pyo3:: prelude:: * ;
45use pyo3:: types:: { PyAnyMethods , PyDict , PyDictMethods , PyList , PyModule , PyTuple } ;
@@ -112,8 +113,7 @@ struct TaskRequest {
112113 executor_id : Arc < String > ,
113114 task_data : Arc < TaskData > ,
114115 task_name : Arc < String > ,
115- raw_args : Arc < Vec < u8 > > ,
116- raw_kwargs : Arc < Vec < u8 > > ,
116+ task : Arc < Task > ,
117117 resp_tx : oneshot:: Sender < Result < ( ) > > ,
118118}
119119
@@ -134,8 +134,7 @@ impl PythonDispatcher {
134134 req. executor_id ,
135135 req. task_data ,
136136 req. task_name ,
137- req. raw_args ,
138- req. raw_kwargs ,
137+ req. task ,
139138 )
140139 . await
141140 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -159,8 +158,7 @@ impl PythonDispatcher {
159158 executor_id : Arc < String > ,
160159 task_data : Arc < TaskData > ,
161160 task_name : Arc < String > ,
162- raw_args : Arc < Vec < u8 > > ,
163- raw_kwargs : Arc < Vec < u8 > > ,
161+ task : Arc < Task > ,
164162 ) -> Result < ( ) > {
165163 let ( resp_tx, resp_rx) = oneshot:: channel ( ) ;
166164
@@ -170,8 +168,7 @@ impl PythonDispatcher {
170168 executor_id,
171169 task_data,
172170 task_name,
173- raw_args,
174- raw_kwargs,
171+ task,
175172 resp_tx,
176173 } )
177174 . await
@@ -182,29 +179,39 @@ impl PythonDispatcher {
182179 }
183180}
184181
182+ struct CoroWithContext {
183+ args : Py < PyTuple > ,
184+ kwargs : Py < PyDict > ,
185+ context : Arc < Py < PyAny > > ,
186+ }
187+
188+ struct MaybeCoro {
189+ func : Arc < Py < PyAny > > ,
190+ with_context : Option < CoroWithContext > ,
191+ }
192+
185193async fn run_task (
186194 task_registry : Arc < TaskRegistry > ,
187195 executor_id : Arc < String > ,
188196 task_data : Arc < TaskData > ,
189197 task_name : Arc < String > ,
190- raw_args : Arc < Vec < u8 > > ,
191- raw_kwargs : Arc < Vec < u8 > > ,
198+ task : Arc < Task > ,
192199) -> Result < ( ) > {
193200 let logger = Logger :: new ( format ! ( "EXECUTOR {}" , & executor_id) ) ;
194201 let duration_start = Instant :: now ( ) ;
195202
196- let task_args: Value = from_slice ( & raw_args ) . context ( format ! (
203+ let task_args: Value = from_slice ( & task . args ) . context ( format ! (
197204 "Failed to deserialize task '{}' function args" ,
198205 & task_name
199206 ) ) ?;
200- let task_kwargs: Value = from_slice ( & raw_kwargs ) . context ( format ! (
207+ let task_kwargs: Value = from_slice ( & task . kwargs ) . context ( format ! (
201208 "Failed to deserialize task '{}' function kwargs" ,
202209 & task_name
203210 ) ) ?;
204211
205212 let context = task_registry. get_task_context ( task_data. clone ( ) ) ?;
206213
207- let maybe_coro = Python :: attach ( |py| -> Result < Option < Py < PyAny > > > {
214+ let maybe_coro = Python :: attach ( |py| -> Result < Option < MaybeCoro > > {
208215 let py_args = pythonize ( py, & task_args) . context ( "Failed to pythonize args" ) ?;
209216 let py_kwargs = pythonize ( py, & task_kwargs) . context ( "Failed to pythonize kwargs" ) ?;
210217
@@ -216,7 +223,11 @@ async fn run_task(
216223 anyhow:: bail!( "Args must be an array/tuple, found {}" , py_args. get_type( ) ) ;
217224 } ;
218225
219- if let Some ( context) = context {
226+ if let Some ( context) = context. as_ref ( ) {
227+ let task_metadata = get_task_metadata ( py, task. clone ( ) ) ?;
228+ let metadata_var = context. getattr ( py, "_metadata_var" ) ?;
229+ metadata_var. call_method1 ( py, "set" , ( task_metadata, ) ) ?;
230+
220231 let context = context. as_any ( ) ;
221232 let prefix = PyTuple :: new ( py, [ context] ) ?;
222233
@@ -230,19 +241,25 @@ async fn run_task(
230241 . cast_into :: < PyDict > ( )
231242 . map_err ( |_| anyhow ! ( "Kwargs must be a map/dict" ) ) ?;
232243
233- let result = task_data
234- . func
235- . call ( py, args_tuple, Some ( & kwargs_dict) )
236- . map_err ( |e| anyhow ! ( "Failed to call Python function: {:?}" , e) ) ?;
237-
238- let bound_result = result. bind ( py) ;
239- let is_coroutine = bound_result
240- . hasattr ( "__await__" )
241- . map_err ( |_| anyhow ! ( "Failed to check if result is awaitable" ) ) ?;
244+ let is_coroutine = is_coroutine ( py, task_data. func . clone ( ) ) ?;
242245
243246 if is_coroutine {
244- Ok ( Some ( result) )
247+ let with_context = context. map ( |context| CoroWithContext {
248+ args : args_tuple. unbind ( ) ,
249+ kwargs : kwargs_dict. unbind ( ) ,
250+ context,
251+ } ) ;
252+
253+ Ok ( Some ( MaybeCoro {
254+ func : task_data. func . clone ( ) ,
255+ with_context,
256+ } ) )
245257 } else {
258+ task_data
259+ . func
260+ . call ( py, args_tuple. clone ( ) , Some ( & kwargs_dict) )
261+ . map_err ( |e| anyhow ! ( "Failed to call Python function: {:?}" , e) ) ?;
262+
246263 let duration_end = duration_start. elapsed ( ) ;
247264 logger. info ( format_args ! (
248265 "Task '{}' successfully finished in {}ms" ,
@@ -264,8 +281,27 @@ async fn run_task(
264281 anyhow ! ( e. to_string( ) )
265282 } ) ?;
266283
267- if let Some ( coro) = maybe_coro {
268- let fut = Python :: attach ( |py| into_future ( coro. into_bound ( py) ) ) ?;
284+ if let Some ( maybe_coro) = maybe_coro {
285+ let fut = Python :: attach ( |py| {
286+ if let Some ( with_context) = maybe_coro. with_context {
287+ let task_metadata = get_task_metadata ( py, task. clone ( ) )
288+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
289+ let result = with_context. context . call_method1 (
290+ py,
291+ "_run_async_task" ,
292+ (
293+ maybe_coro. func . as_any ( ) ,
294+ task_metadata,
295+ with_context. args ,
296+ Some ( with_context. kwargs ) ,
297+ ) ,
298+ ) ?;
299+ into_future ( result. into_bound ( py) )
300+ } else {
301+ let func = maybe_coro. func . clone_ref ( py) ;
302+ into_future ( func. into_bound ( py) )
303+ }
304+ } ) ?;
269305 fut. await ?;
270306
271307 let duration_end = duration_start. elapsed ( ) ;
@@ -370,6 +406,30 @@ fn get_registry(module_path: &str, queue_name: &str) -> Result<TasksAndContexts>
370406 Ok ( result)
371407}
372408
409+ fn get_task_metadata ( py : Python < ' _ > , task : Arc < Task > ) -> Result < Py < PyAny > > {
410+ let module = py. import ( "fluxqueue.models" ) ?. unbind ( ) ;
411+ let task_metadata = module. call_method1 (
412+ py,
413+ "TaskMetadata" ,
414+ (
415+ task. id . clone ( ) ,
416+ task. retries ,
417+ task. max_retries ,
418+ task. created_at ,
419+ ) ,
420+ ) ?;
421+
422+ Ok ( task_metadata)
423+ }
424+
425+ fn is_coroutine ( py : Python < ' _ > , func : Arc < Py < PyAny > > ) -> Result < bool > {
426+ let inspect = py. import ( "inspect" ) ?;
427+ let is_coro: bool = inspect
428+ . call_method1 ( "iscoroutinefunction" , ( func. as_any ( ) , ) ) ?
429+ . extract ( ) ?;
430+ Ok ( is_coro)
431+ }
432+
373433fn normalize_path ( path : & Path ) -> PathBuf {
374434 let mut components = Vec :: new ( ) ;
375435 for comp in path. components ( ) {
0 commit comments