@@ -15,11 +15,6 @@ use tokio::sync::{mpsc, oneshot};
1515
1616use crate :: logger:: Logger ;
1717
18- type TasksAndContexts = (
19- HashMap < String , Arc < Py < PyAny > > > ,
20- HashMap < String , Arc < Py < PyAny > > > ,
21- ) ;
22-
2318#[ derive( Debug ) ]
2419pub struct TaskRegistry {
2520 tasks : Arc < RwLock < HashMap < String , Arc < Py < PyAny > > > > > ,
@@ -28,75 +23,7 @@ pub struct TaskRegistry {
2823
2924impl TaskRegistry {
3025 pub fn new ( module_path : & str , queue_name : & str ) -> Result < Self > {
31- let script = include_str ! ( "../scripts/get_registry.py" ) ;
32- let script_cstr = CString :: new ( script) ?;
33- let filename = CString :: new ( "get_registry.py" ) ?;
34- let module_name = CString :: new ( "get_registry" ) ?;
35-
36- let full_current_dir = std:: env:: current_dir ( ) . unwrap ( ) ;
37- let full_module_path = full_current_dir. join ( module_path) ;
38- let clean_module_path = normalize_path ( & full_module_path) ;
39- let project_root = full_current_dir
40- . ancestors ( )
41- . find ( |p| p. join ( "tests" ) . exists ( ) )
42- . unwrap_or ( & full_current_dir) ;
43- let real_module_path = path_to_module_path ( project_root, & clean_module_path) ;
44-
45- if !clean_module_path. exists ( ) || real_module_path. is_none ( ) {
46- return Err ( anyhow ! (
47- "Tasks module path {:?} doesn't exist." ,
48- clean_module_path
49- ) ) ;
50- }
51-
52- let real_module_path = real_module_path. unwrap ( ) ;
53- let module_dir = project_root. to_string_lossy ( ) . to_string ( ) ;
54-
55- let ( tasks, contexts) = Python :: attach ( |py| -> Result < TasksAndContexts > {
56- let module = PyModule :: from_code (
57- py,
58- script_cstr. as_c_str ( ) ,
59- filename. as_c_str ( ) ,
60- module_name. as_c_str ( ) ,
61- )
62- . map_err ( |e| anyhow ! ( "Failed to import python module: {}" , e) ) ?;
63-
64- let registry: Bound < ' _ , PyDict > = module
65- . getattr ( "get_registry" )
66- . map_err ( |e| anyhow ! ( "Failed to get 'get_registry' script: {}" , e) ) ?
67- . call1 ( ( real_module_path, queue_name, module_dir) )
68- . map_err ( |e| anyhow ! ( "Failed to get tasks: {}" , e) ) ?
69- . cast_into :: < PyDict > ( )
70- . map_err ( |_| anyhow ! ( "Failed to cast result to a Python Dictionary" ) ) ?;
71-
72- let tasks: HashMap < String , Arc < Py < PyAny > > > = registry
73- . get_item ( "tasks" ) ?
74- . expect ( "tasks missing" )
75- . cast :: < PyDict > ( )
76- . map_err ( |e| anyhow ! ( "tasks is not a dict: {}" , e) ) ?
77- . iter ( )
78- . filter_map ( |( key, value) | {
79- let name: String = key. extract ( ) . ok ( ) ?;
80- let func: Py < PyAny > = value. unbind ( ) ;
81- Some ( ( name, Arc :: new ( func) ) )
82- } )
83- . collect ( ) ;
84-
85- let contexts: HashMap < String , Arc < Py < PyAny > > > = registry
86- . get_item ( "contexts" ) ?
87- . expect ( "contexts missing" )
88- . cast :: < PyDict > ( )
89- . map_err ( |e| anyhow ! ( "contexts is not a dict: {}" , e) ) ?
90- . iter ( )
91- . filter_map ( |( key, value) : ( Bound < PyAny > , Bound < PyAny > ) | {
92- let name: String = key. extract ( ) . ok ( ) ?;
93- let func: Py < PyAny > = value. unbind ( ) ;
94- Some ( ( name, Arc :: new ( func) ) )
95- } )
96- . collect ( ) ;
97-
98- Ok ( ( tasks, contexts) )
99- } ) ?;
26+ let ( tasks, contexts) = get_registry ( module_path, queue_name) ?;
10027
10128 Ok ( Self {
10229 tasks : Arc :: new ( RwLock :: new ( tasks) ) ,
@@ -276,6 +203,85 @@ async fn run_task(
276203 Ok ( ( ) )
277204}
278205
206+ type TasksAndContexts = (
207+ HashMap < String , Arc < Py < PyAny > > > ,
208+ HashMap < String , Arc < Py < PyAny > > > ,
209+ ) ;
210+
211+ fn get_registry ( module_path : & str , queue_name : & str ) -> Result < TasksAndContexts > {
212+ let script = include_str ! ( "../scripts/get_registry.py" ) ;
213+ let script_cstr = CString :: new ( script) ?;
214+ let filename = CString :: new ( "get_registry.py" ) ?;
215+ let module_name = CString :: new ( "get_registry" ) ?;
216+
217+ let full_current_dir = std:: env:: current_dir ( ) . unwrap ( ) ;
218+ let full_module_path = full_current_dir. join ( module_path) ;
219+ let clean_module_path = normalize_path ( & full_module_path) ;
220+ let project_root = full_current_dir
221+ . ancestors ( )
222+ . find ( |p| p. join ( "tests" ) . exists ( ) )
223+ . unwrap_or ( & full_current_dir) ;
224+ let real_module_path = path_to_module_path ( project_root, & clean_module_path) ;
225+
226+ if !clean_module_path. exists ( ) || real_module_path. is_none ( ) {
227+ return Err ( anyhow ! (
228+ "Tasks module path {:?} doesn't exist." ,
229+ clean_module_path
230+ ) ) ;
231+ }
232+
233+ let real_module_path = real_module_path. unwrap ( ) ;
234+ let module_dir = project_root. to_string_lossy ( ) . to_string ( ) ;
235+
236+ let result = Python :: attach ( |py| -> Result < TasksAndContexts > {
237+ let module = PyModule :: from_code (
238+ py,
239+ script_cstr. as_c_str ( ) ,
240+ filename. as_c_str ( ) ,
241+ module_name. as_c_str ( ) ,
242+ )
243+ . map_err ( |e| anyhow ! ( "Failed to import python module: {}" , e) ) ?;
244+
245+ let registry: Bound < ' _ , PyDict > = module
246+ . getattr ( "get_registry" )
247+ . map_err ( |e| anyhow ! ( "Failed to get 'get_registry' script: {}" , e) ) ?
248+ . call1 ( ( real_module_path, queue_name, module_dir) )
249+ . map_err ( |e| anyhow ! ( "Failed to get tasks: {}" , e) ) ?
250+ . cast_into :: < PyDict > ( )
251+ . map_err ( |_| anyhow ! ( "Failed to cast result to a Python Dictionary" ) ) ?;
252+
253+ let tasks: HashMap < String , Arc < Py < PyAny > > > = registry
254+ . get_item ( "tasks" ) ?
255+ . expect ( "tasks missing" )
256+ . cast :: < PyDict > ( )
257+ . map_err ( |e| anyhow ! ( "tasks is not a dict: {}" , e) ) ?
258+ . iter ( )
259+ . filter_map ( |( key, value) | {
260+ let name: String = key. extract ( ) . ok ( ) ?;
261+ let func: Py < PyAny > = value. unbind ( ) ;
262+ Some ( ( name, Arc :: new ( func) ) )
263+ } )
264+ . collect ( ) ;
265+
266+ let contexts: HashMap < String , Arc < Py < PyAny > > > = registry
267+ . get_item ( "contexts" ) ?
268+ . expect ( "contexts missing" )
269+ . cast :: < PyDict > ( )
270+ . map_err ( |e| anyhow ! ( "contexts is not a dict: {}" , e) ) ?
271+ . iter ( )
272+ . filter_map ( |( key, value) : ( Bound < PyAny > , Bound < PyAny > ) | {
273+ let name: String = key. extract ( ) . ok ( ) ?;
274+ let func: Py < PyAny > = value. unbind ( ) ;
275+ Some ( ( name, Arc :: new ( func) ) )
276+ } )
277+ . collect ( ) ;
278+
279+ Ok ( ( tasks, contexts) )
280+ } ) ?;
281+
282+ Ok ( result)
283+ }
284+
279285fn normalize_path ( path : & Path ) -> PathBuf {
280286 let mut components = Vec :: new ( ) ;
281287 for comp in path. components ( ) {
@@ -306,3 +312,84 @@ fn path_to_module_path(current_dir: &Path, target_path: &Path) -> Option<String>
306312
307313 Some ( components. join ( "." ) )
308314}
315+
316+ #[ cfg( test) ]
317+ mod tests {
318+ use super :: * ;
319+
320+ #[ test]
321+ fn test_path_to_module_path ( ) -> Result < ( ) > {
322+ let current_dir = Path :: new ( "project" ) ;
323+ let tasks_path = Path :: new ( "../project/tasks.py" ) ;
324+ let normalized_path = normalize_path ( tasks_path) ;
325+ let module_path = path_to_module_path ( current_dir, & normalized_path) ;
326+ let expected_path = Path :: new ( "project/tasks.py" ) ;
327+
328+ assert_eq ! ( normalized_path, expected_path) ;
329+ assert_eq ! ( module_path, Some ( "tasks" . to_string( ) ) ) ;
330+
331+ Ok ( ( ) )
332+ }
333+
334+ fn get_test_module_path ( filename : & str ) -> String {
335+ let current_dir = std:: env:: current_dir ( ) . unwrap ( ) ;
336+ let test_module_path = current_dir. join ( "tests" ) . join ( filename) ;
337+ test_module_path. to_str ( ) . unwrap ( ) . to_string ( )
338+ }
339+
340+ #[ test]
341+ fn test_get_task_functions_valid_module ( ) -> Result < ( ) > {
342+ let module_path_str = get_test_module_path ( "test_tasks_module.py" ) ;
343+ let ( tasks, _) = get_registry ( & module_path_str, "default" ) ?;
344+
345+ assert_eq ! ( tasks. len( ) , 3 ) ;
346+
347+ let task_names: Vec < String > = tasks. iter ( ) . map ( |( name, _) | name. clone ( ) ) . collect ( ) ;
348+ assert ! ( task_names. contains( & "task-1" . to_string( ) ) ) ;
349+ assert ! ( task_names. contains( & "task-2" . to_string( ) ) ) ;
350+ assert ! ( task_names. contains( & "async-task" . to_string( ) ) ) ;
351+
352+ assert ! ( !task_names. contains( & "high-priority-task" . to_string( ) ) ) ;
353+
354+ Ok ( ( ) )
355+ }
356+
357+ #[ test]
358+ fn test_get_task_functions_different_queue ( ) -> Result < ( ) > {
359+ let module_path_str = get_test_module_path ( "test_tasks_module.py" ) ;
360+ let ( tasks, _) = get_registry ( & module_path_str, "high-priority" ) ?;
361+
362+ let task_names: Vec < String > = tasks. iter ( ) . map ( |( name, _) | name. clone ( ) ) . collect ( ) ;
363+ assert_eq ! ( tasks. len( ) , 1 ) ;
364+ assert ! ( task_names. contains( & "high-priority-task" . to_string( ) ) ) ;
365+
366+ Ok ( ( ) )
367+ }
368+
369+ #[ test]
370+ fn test_get_task_functions_empty_module ( ) -> Result < ( ) > {
371+ let module_path_str = get_test_module_path ( "test_tasks_empty.py" ) ;
372+ let ( tasks, _) = get_registry ( & module_path_str, "default" ) ?;
373+
374+ assert_eq ! ( tasks. len( ) , 0 ) ;
375+
376+ Ok ( ( ) )
377+ }
378+
379+ #[ test]
380+ fn test_get_task_functions_duplicate_names ( ) {
381+ let module_path_str = get_test_module_path ( "test_tasks_duplicate.py" ) ;
382+
383+ let result = get_registry ( & module_path_str, "default" ) ;
384+ assert ! ( result. is_err( ) ) ;
385+
386+ let error_msg = result. unwrap_err ( ) . to_string ( ) ;
387+ assert ! ( error_msg. contains( "duplicated" ) || error_msg. contains( "duplicate" ) ) ;
388+ }
389+
390+ #[ test]
391+ fn test_get_task_functions_invalid_path ( ) {
392+ let result = get_registry ( "nonexistent/path/to/module.py" , "default" ) ;
393+ assert ! ( result. is_err( ) ) ;
394+ }
395+ }
0 commit comments