11use anyhow:: { Context , Result , anyhow} ;
22use pyo3:: exceptions:: PyRuntimeError ;
33use pyo3:: prelude:: * ;
4- use pyo3:: types:: { PyDict , PyList , PyTuple } ;
4+ use pyo3:: types:: { PyAnyMethods , PyDict , PyDictMethods , PyList , PyModule , PyTuple } ;
55use pyo3_async_runtimes:: tokio:: into_future;
66use pythonize:: pythonize;
77use rmp_serde:: from_slice;
88use rmpv:: Value ;
99use std:: collections:: HashMap ;
10+ use std:: ffi:: CString ;
11+ use std:: path:: { Path , PathBuf } ;
12+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1013use std:: sync:: { Arc , RwLock } ;
14+ use std:: time:: Instant ;
1115use tokio:: sync:: { mpsc, oneshot} ;
1216
17+ use crate :: logger:: Logger ;
18+
19+ #[ derive( Debug ) ]
1320pub struct TaskRegistry {
1421 tasks : Arc < RwLock < HashMap < String , Arc < Py < PyAny > > > > > ,
22+ contexts : Arc < RwLock < HashMap < String , Arc < Py < PyAny > > > > > ,
1523}
1624
1725impl TaskRegistry {
18- pub fn new ( ) -> Self {
19- Self {
20- tasks : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
26+ pub fn new ( module_path : & str , queue_name : & str ) -> Result < Self > {
27+ let script = include_str ! ( "../scripts/get_registry.py" ) ;
28+ let script_cstr = CString :: new ( script) ?;
29+ let filename = CString :: new ( "get_registry.py" ) ?;
30+ let module_name = CString :: new ( "get_registry" ) ?;
31+
32+ let full_current_dir = std:: env:: current_dir ( ) . unwrap ( ) ;
33+ let full_module_path = full_current_dir. join ( module_path) ;
34+ let clean_module_path = normalize_path ( & full_module_path) ;
35+ let project_root = full_current_dir
36+ . ancestors ( )
37+ . find ( |p| p. join ( "tests" ) . exists ( ) )
38+ . unwrap_or ( & full_current_dir) ;
39+ let real_module_path = path_to_module_path ( project_root, & clean_module_path) ;
40+
41+ if !clean_module_path. exists ( ) || real_module_path. is_none ( ) {
42+ return Err ( anyhow ! (
43+ "Tasks module path {:?} doesn't exist." ,
44+ clean_module_path
45+ ) ) ;
2146 }
47+
48+ let real_module_path = real_module_path. unwrap ( ) ;
49+ let module_dir = project_root. to_string_lossy ( ) . to_string ( ) ;
50+
51+ let ( tasks, contexts) = Python :: attach (
52+ |py| -> Result < (
53+ HashMap < String , Arc < Py < PyAny > > > ,
54+ HashMap < String , Arc < Py < PyAny > > > ,
55+ ) > {
56+ let module = PyModule :: from_code (
57+ py,
58+ script_cstr. as_c_str ( ) ,
59+ filename. as_c_str ( ) ,
60+ module_name. as_c_str ( ) ,
61+ )
62+ . map_err ( |e| anyhow ! ( "Failed to import python module: {}" , e) ) ?;
63+
64+ let registry: Bound < ' _ , PyDict > = module
65+ . getattr ( "get_registry" )
66+ . map_err ( |e| anyhow ! ( "Failed to get 'get_registry' script: {}" , e) ) ?
67+ . call1 ( ( real_module_path, queue_name, module_dir) )
68+ . map_err ( |e| anyhow ! ( "Failed to get tasks: {}" , e) ) ?
69+ . cast_into :: < PyDict > ( )
70+ . map_err ( |_| anyhow ! ( "Failed to cast result to a Python Dictionary" ) ) ?;
71+
72+ let tasks: HashMap < String , Arc < Py < PyAny > > > = registry
73+ . get_item ( "tasks" ) ?
74+ . expect ( "tasks missing" )
75+ . cast :: < PyDict > ( )
76+ . map_err ( |e| anyhow ! ( "tasks is not a dict: {}" , e) ) ?
77+ . iter ( )
78+ . filter_map ( |( key, value) | {
79+ let name: String = key. extract ( ) . ok ( ) ?;
80+ let func: Py < PyAny > = value. unbind ( ) ;
81+ Some ( ( name, Arc :: new ( func) ) )
82+ } )
83+ . collect ( ) ;
84+
85+ let contexts: HashMap < String , Arc < Py < PyAny > > > = registry
86+ . get_item ( "contexts" ) ?
87+ . expect ( "contexts missing" )
88+ . cast :: < PyDict > ( )
89+ . map_err ( |e| anyhow ! ( "contexts is not a dict: {}" , e) ) ?
90+ . iter ( )
91+ . filter_map ( |( key, value) : ( Bound < PyAny > , Bound < PyAny > ) | {
92+ let name: String = key. extract ( ) . ok ( ) ?;
93+ let func: Py < PyAny > = value. unbind ( ) ;
94+ Some ( ( name, Arc :: new ( func) ) )
95+ } )
96+ . collect ( ) ;
97+
98+ Ok ( ( tasks, contexts) )
99+ } ,
100+ ) ?;
101+
102+ Ok ( Self {
103+ tasks : Arc :: new ( RwLock :: new ( tasks) ) ,
104+ contexts : Arc :: new ( RwLock :: new ( contexts) ) ,
105+ } )
22106 }
23107
24- pub fn insert ( & self , name : String , func : Py < PyAny > ) -> Result < ( ) > {
25- let mut tasks = self . tasks . write ( ) . map_err ( |_| {
26- anyhow ! ( "Internal Error: Task registry lock poisoned (a thread panicked)" )
27- } ) ?;
28- tasks. insert ( name, Arc :: new ( func) ) ;
29- Ok ( ( ) )
108+ pub fn get_registered_tasks ( & self ) -> Result < Vec < String > > {
109+ let tasks = self . tasks . read ( ) . map_err ( |e| anyhow ! ( e. to_string( ) ) ) ?;
110+ let task_names: Vec < _ > = tasks. iter ( ) . map ( |t| t. 0 . to_string ( ) ) . collect ( ) ;
111+ Ok ( task_names)
112+ }
113+
114+ pub fn get_registered_contexts ( & self ) -> Result < Vec < String > > {
115+ let contexts = self . contexts . read ( ) . map_err ( |e| anyhow ! ( e. to_string( ) ) ) ?;
116+ let context_names: Vec < _ > = contexts. iter ( ) . map ( |t| t. 0 . to_string ( ) ) . collect ( ) ;
117+ Ok ( context_names)
30118 }
31119
32120 pub fn get ( & self , name : & str ) -> Option < Arc < Py < PyAny > > > {
@@ -36,6 +124,7 @@ impl TaskRegistry {
36124}
37125
38126struct TaskRequest {
127+ executor_id : Arc < String > ,
39128 func : Arc < Py < PyAny > > ,
40129 task_name : Arc < String > ,
41130 raw_args : Arc < Vec < u8 > > ,
@@ -54,9 +143,15 @@ impl PythonDispatcher {
54143
55144 let dispatcher = async move {
56145 while let Some ( req) = rx. recv ( ) . await {
57- run_task ( req. func , req. task_name , req. raw_args , req. raw_kwargs )
58- . await
59- . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
146+ run_task (
147+ req. executor_id ,
148+ req. func ,
149+ req. task_name ,
150+ req. raw_args ,
151+ req. raw_kwargs ,
152+ )
153+ . await
154+ . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
60155
61156 let _ = req. resp_tx . send ( Ok ( ( ) ) ) ;
62157 }
@@ -74,6 +169,7 @@ impl PythonDispatcher {
74169
75170 pub async fn execute (
76171 & self ,
172+ executor_id : Arc < String > ,
77173 func : Arc < Py < PyAny > > ,
78174 task_name : Arc < String > ,
79175 raw_args : Arc < Vec < u8 > > ,
@@ -83,6 +179,7 @@ impl PythonDispatcher {
83179
84180 self . tx
85181 . send ( TaskRequest {
182+ executor_id,
86183 func,
87184 task_name,
88185 raw_args,
@@ -98,11 +195,15 @@ impl PythonDispatcher {
98195}
99196
100197async fn run_task (
198+ executor_id : Arc < String > ,
101199 task_function : Arc < Py < PyAny > > ,
102200 task_name : Arc < String > ,
103201 raw_args : Arc < Vec < u8 > > ,
104202 raw_kwargs : Arc < Vec < u8 > > ,
105203) -> Result < ( ) > {
204+ let logger = Logger :: new ( format ! ( "EXECUTOR {}" , & executor_id) ) ;
205+ let duration_start = Instant :: now ( ) ;
206+
106207 let task_args: Value = from_slice ( & raw_args) . context ( format ! (
107208 "Failed to deserialize task '{}' function args" ,
108209 & task_name
@@ -140,14 +241,99 @@ async fn run_task(
140241 if is_coroutine {
141242 Ok ( Some ( result) )
142243 } else {
244+ let duration_end = duration_start. elapsed ( ) ;
245+ logger. info ( format_args ! (
246+ "Task '{}' successfully finished in {}ms" ,
247+ & task_name,
248+ duration_end. as_millis( )
249+ ) ) ;
250+
143251 Ok ( None )
144252 }
253+ } )
254+ . map_err ( |e| {
255+ let duration_end = duration_start. elapsed ( ) ;
256+ logger. error ( format_args ! (
257+ "Task '{}' failed in {}ms: {}" ,
258+ & task_name,
259+ duration_end. as_millis( ) ,
260+ e
261+ ) ) ;
262+ anyhow ! ( e. to_string( ) )
145263 } ) ?;
146264
147265 if let Some ( coro) = maybe_coro {
148266 let fut = Python :: attach ( |py| into_future ( coro. into_bound ( py) ) ) ?;
149267 fut. await ?;
268+
269+ let duration_end = duration_start. elapsed ( ) ;
270+ logger. info ( format_args ! (
271+ "Task '{}' successfully finished in {}ms" ,
272+ & task_name,
273+ duration_end. as_millis( )
274+ ) ) ;
150275 }
151276
152277 Ok ( ( ) )
153278}
279+
280+ pub struct DispatcherPool {
281+ dispatchers : Vec < Arc < PythonDispatcher > > ,
282+ index : AtomicUsize ,
283+ }
284+
285+ impl DispatcherPool {
286+ pub fn new ( concurrency : usize ) -> Result < Self > {
287+ // let logical_cores = num_cpus::get();
288+ // let pool_size = (concurrency / 4)
289+ // .max(1)
290+ // .min(logical_cores * 2)
291+ // .max(logical_cores);
292+
293+ let mut dispatchers = Vec :: with_capacity ( concurrency) ;
294+ for _ in 0 ..concurrency {
295+ dispatchers. push ( Arc :: new ( PythonDispatcher :: new ( ) ?) ) ;
296+ }
297+
298+ Ok ( Self {
299+ dispatchers,
300+ index : AtomicUsize :: new ( 0 ) ,
301+ } )
302+ }
303+
304+ pub fn get ( & self ) -> Arc < PythonDispatcher > {
305+ let idx = self . index . fetch_add ( 1 , Ordering :: Relaxed ) % self . dispatchers . len ( ) ;
306+ self . dispatchers [ idx] . clone ( )
307+ }
308+ }
309+
310+ fn normalize_path ( path : & Path ) -> PathBuf {
311+ let mut components = Vec :: new ( ) ;
312+ for comp in path. components ( ) {
313+ match comp {
314+ std:: path:: Component :: ParentDir => {
315+ components. pop ( ) ;
316+ }
317+ std:: path:: Component :: CurDir => { }
318+ other => components. push ( other) ,
319+ }
320+ }
321+ components. iter ( ) . collect ( )
322+ }
323+
324+ fn path_to_module_path ( current_dir : & Path , target_path : & Path ) -> Option < String > {
325+ let rel_path = target_path. strip_prefix ( current_dir) . ok ( ) ?;
326+
327+ let mut components: Vec < String > = rel_path
328+ . components ( )
329+ . map ( |c| c. as_os_str ( ) . to_string_lossy ( ) . to_string ( ) )
330+ . collect ( ) ;
331+
332+ if let Some ( last) = components. last_mut ( )
333+ && let Some ( pos) = last. rfind ( '.' )
334+ {
335+ last. truncate ( pos) ;
336+ }
337+
338+ Some ( components. join ( "." ) )
339+ }
0 commit comments