Skip to content

Commit a025ced

Browse files
committed
Add tests for tasks
1 parent 0ad34f9 commit a025ced

1 file changed

Lines changed: 161 additions & 74 deletions

File tree

crates/fluxqueue-worker/src/task.rs

Lines changed: 161 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ use tokio::sync::{mpsc, oneshot};
1515

1616
use crate::logger::Logger;
1717

18-
type TasksAndContexts = (
19-
HashMap<String, Arc<Py<PyAny>>>,
20-
HashMap<String, Arc<Py<PyAny>>>,
21-
);
22-
2318
#[derive(Debug)]
2419
pub struct TaskRegistry {
2520
tasks: Arc<RwLock<HashMap<String, Arc<Py<PyAny>>>>>,
@@ -28,75 +23,7 @@ pub struct TaskRegistry {
2823

2924
impl TaskRegistry {
3025
pub fn new(module_path: &str, queue_name: &str) -> Result<Self> {
31-
let script = include_str!("../scripts/get_registry.py");
32-
let script_cstr = CString::new(script)?;
33-
let filename = CString::new("get_registry.py")?;
34-
let module_name = CString::new("get_registry")?;
35-
36-
let full_current_dir = std::env::current_dir().unwrap();
37-
let full_module_path = full_current_dir.join(module_path);
38-
let clean_module_path = normalize_path(&full_module_path);
39-
let project_root = full_current_dir
40-
.ancestors()
41-
.find(|p| p.join("tests").exists())
42-
.unwrap_or(&full_current_dir);
43-
let real_module_path = path_to_module_path(project_root, &clean_module_path);
44-
45-
if !clean_module_path.exists() || real_module_path.is_none() {
46-
return Err(anyhow!(
47-
"Tasks module path {:?} doesn't exist.",
48-
clean_module_path
49-
));
50-
}
51-
52-
let real_module_path = real_module_path.unwrap();
53-
let module_dir = project_root.to_string_lossy().to_string();
54-
55-
let (tasks, contexts) = Python::attach(|py| -> Result<TasksAndContexts> {
56-
let module = PyModule::from_code(
57-
py,
58-
script_cstr.as_c_str(),
59-
filename.as_c_str(),
60-
module_name.as_c_str(),
61-
)
62-
.map_err(|e| anyhow!("Failed to import python module: {}", e))?;
63-
64-
let registry: Bound<'_, PyDict> = module
65-
.getattr("get_registry")
66-
.map_err(|e| anyhow!("Failed to get 'get_registry' script: {}", e))?
67-
.call1((real_module_path, queue_name, module_dir))
68-
.map_err(|e| anyhow!("Failed to get tasks: {}", e))?
69-
.cast_into::<PyDict>()
70-
.map_err(|_| anyhow!("Failed to cast result to a Python Dictionary"))?;
71-
72-
let tasks: HashMap<String, Arc<Py<PyAny>>> = registry
73-
.get_item("tasks")?
74-
.expect("tasks missing")
75-
.cast::<PyDict>()
76-
.map_err(|e| anyhow!("tasks is not a dict: {}", e))?
77-
.iter()
78-
.filter_map(|(key, value)| {
79-
let name: String = key.extract().ok()?;
80-
let func: Py<PyAny> = value.unbind();
81-
Some((name, Arc::new(func)))
82-
})
83-
.collect();
84-
85-
let contexts: HashMap<String, Arc<Py<PyAny>>> = registry
86-
.get_item("contexts")?
87-
.expect("contexts missing")
88-
.cast::<PyDict>()
89-
.map_err(|e| anyhow!("contexts is not a dict: {}", e))?
90-
.iter()
91-
.filter_map(|(key, value): (Bound<PyAny>, Bound<PyAny>)| {
92-
let name: String = key.extract().ok()?;
93-
let func: Py<PyAny> = value.unbind();
94-
Some((name, Arc::new(func)))
95-
})
96-
.collect();
97-
98-
Ok((tasks, contexts))
99-
})?;
26+
let (tasks, contexts) = get_registry(module_path, queue_name)?;
10027

10128
Ok(Self {
10229
tasks: Arc::new(RwLock::new(tasks)),
@@ -276,6 +203,85 @@ async fn run_task(
276203
Ok(())
277204
}
278205

206+
type TasksAndContexts = (
207+
HashMap<String, Arc<Py<PyAny>>>,
208+
HashMap<String, Arc<Py<PyAny>>>,
209+
);
210+
211+
fn get_registry(module_path: &str, queue_name: &str) -> Result<TasksAndContexts> {
212+
let script = include_str!("../scripts/get_registry.py");
213+
let script_cstr = CString::new(script)?;
214+
let filename = CString::new("get_registry.py")?;
215+
let module_name = CString::new("get_registry")?;
216+
217+
let full_current_dir = std::env::current_dir().unwrap();
218+
let full_module_path = full_current_dir.join(module_path);
219+
let clean_module_path = normalize_path(&full_module_path);
220+
let project_root = full_current_dir
221+
.ancestors()
222+
.find(|p| p.join("tests").exists())
223+
.unwrap_or(&full_current_dir);
224+
let real_module_path = path_to_module_path(project_root, &clean_module_path);
225+
226+
if !clean_module_path.exists() || real_module_path.is_none() {
227+
return Err(anyhow!(
228+
"Tasks module path {:?} doesn't exist.",
229+
clean_module_path
230+
));
231+
}
232+
233+
let real_module_path = real_module_path.unwrap();
234+
let module_dir = project_root.to_string_lossy().to_string();
235+
236+
let result = Python::attach(|py| -> Result<TasksAndContexts> {
237+
let module = PyModule::from_code(
238+
py,
239+
script_cstr.as_c_str(),
240+
filename.as_c_str(),
241+
module_name.as_c_str(),
242+
)
243+
.map_err(|e| anyhow!("Failed to import python module: {}", e))?;
244+
245+
let registry: Bound<'_, PyDict> = module
246+
.getattr("get_registry")
247+
.map_err(|e| anyhow!("Failed to get 'get_registry' script: {}", e))?
248+
.call1((real_module_path, queue_name, module_dir))
249+
.map_err(|e| anyhow!("Failed to get tasks: {}", e))?
250+
.cast_into::<PyDict>()
251+
.map_err(|_| anyhow!("Failed to cast result to a Python Dictionary"))?;
252+
253+
let tasks: HashMap<String, Arc<Py<PyAny>>> = registry
254+
.get_item("tasks")?
255+
.expect("tasks missing")
256+
.cast::<PyDict>()
257+
.map_err(|e| anyhow!("tasks is not a dict: {}", e))?
258+
.iter()
259+
.filter_map(|(key, value)| {
260+
let name: String = key.extract().ok()?;
261+
let func: Py<PyAny> = value.unbind();
262+
Some((name, Arc::new(func)))
263+
})
264+
.collect();
265+
266+
let contexts: HashMap<String, Arc<Py<PyAny>>> = registry
267+
.get_item("contexts")?
268+
.expect("contexts missing")
269+
.cast::<PyDict>()
270+
.map_err(|e| anyhow!("contexts is not a dict: {}", e))?
271+
.iter()
272+
.filter_map(|(key, value): (Bound<PyAny>, Bound<PyAny>)| {
273+
let name: String = key.extract().ok()?;
274+
let func: Py<PyAny> = value.unbind();
275+
Some((name, Arc::new(func)))
276+
})
277+
.collect();
278+
279+
Ok((tasks, contexts))
280+
})?;
281+
282+
Ok(result)
283+
}
284+
279285
fn normalize_path(path: &Path) -> PathBuf {
280286
let mut components = Vec::new();
281287
for comp in path.components() {
@@ -306,3 +312,84 @@ fn path_to_module_path(current_dir: &Path, target_path: &Path) -> Option<String>
306312

307313
Some(components.join("."))
308314
}
315+
316+
#[cfg(test)]
317+
mod tests {
318+
use super::*;
319+
320+
#[test]
321+
fn test_path_to_module_path() -> Result<()> {
322+
let current_dir = Path::new("project");
323+
let tasks_path = Path::new("../project/tasks.py");
324+
let normalized_path = normalize_path(tasks_path);
325+
let module_path = path_to_module_path(current_dir, &normalized_path);
326+
let expected_path = Path::new("project/tasks.py");
327+
328+
assert_eq!(normalized_path, expected_path);
329+
assert_eq!(module_path, Some("tasks".to_string()));
330+
331+
Ok(())
332+
}
333+
334+
fn get_test_module_path(filename: &str) -> String {
335+
let current_dir = std::env::current_dir().unwrap();
336+
let test_module_path = current_dir.join("tests").join(filename);
337+
test_module_path.to_str().unwrap().to_string()
338+
}
339+
340+
#[test]
341+
fn test_get_task_functions_valid_module() -> Result<()> {
342+
let module_path_str = get_test_module_path("test_tasks_module.py");
343+
let (tasks, _) = get_registry(&module_path_str, "default")?;
344+
345+
assert_eq!(tasks.len(), 3);
346+
347+
let task_names: Vec<String> = tasks.iter().map(|(name, _)| name.clone()).collect();
348+
assert!(task_names.contains(&"task-1".to_string()));
349+
assert!(task_names.contains(&"task-2".to_string()));
350+
assert!(task_names.contains(&"async-task".to_string()));
351+
352+
assert!(!task_names.contains(&"high-priority-task".to_string()));
353+
354+
Ok(())
355+
}
356+
357+
#[test]
358+
fn test_get_task_functions_different_queue() -> Result<()> {
359+
let module_path_str = get_test_module_path("test_tasks_module.py");
360+
let (tasks, _) = get_registry(&module_path_str, "high-priority")?;
361+
362+
let task_names: Vec<String> = tasks.iter().map(|(name, _)| name.clone()).collect();
363+
assert_eq!(tasks.len(), 1);
364+
assert!(task_names.contains(&"high-priority-task".to_string()));
365+
366+
Ok(())
367+
}
368+
369+
#[test]
370+
fn test_get_task_functions_empty_module() -> Result<()> {
371+
let module_path_str = get_test_module_path("test_tasks_empty.py");
372+
let (tasks, _) = get_registry(&module_path_str, "default")?;
373+
374+
assert_eq!(tasks.len(), 0);
375+
376+
Ok(())
377+
}
378+
379+
#[test]
380+
fn test_get_task_functions_duplicate_names() {
381+
let module_path_str = get_test_module_path("test_tasks_duplicate.py");
382+
383+
let result = get_registry(&module_path_str, "default");
384+
assert!(result.is_err());
385+
386+
let error_msg = result.unwrap_err().to_string();
387+
assert!(error_msg.contains("duplicated") || error_msg.contains("duplicate"));
388+
}
389+
390+
#[test]
391+
fn test_get_task_functions_invalid_path() {
392+
let result = get_registry("nonexistent/path/to/module.py", "default");
393+
assert!(result.is_err());
394+
}
395+
}

0 commit comments

Comments
 (0)