Skip to content

Commit 9a4d4b4

Browse files
authored
refactor: Move function calls to dispatcher (#116)
1 parent 8cd20cd commit 9a4d4b4

2 files changed

Lines changed: 107 additions & 74 deletions

File tree

Lines changed: 95 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
use anyhow::{Result, anyhow};
1+
use anyhow::{Context, Result, anyhow};
2+
use pyo3::exceptions::PyRuntimeError;
23
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyList, PyTuple};
35
use pyo3_async_runtimes::tokio::into_future;
6+
use pythonize::pythonize;
7+
use rmp_serde::from_slice;
8+
use rmpv::Value;
49
use std::collections::HashMap;
5-
use std::pin::Pin;
610
use std::sync::{Arc, RwLock};
711
use tokio::sync::{mpsc, oneshot};
812

@@ -19,7 +23,7 @@ impl TaskRegistry {
1923

2024
pub fn insert(&self, name: String, func: Py<PyAny>) -> Result<()> {
2125
let mut tasks = self.tasks.write().map_err(|_| {
22-
anyhow::anyhow!("Internal Error: Task registry lock poisoned (a thread panicked)")
26+
anyhow!("Internal Error: Task registry lock poisoned (a thread panicked)")
2327
})?;
2428
tasks.insert(name, Arc::new(func));
2529
Ok(())
@@ -31,51 +35,119 @@ impl TaskRegistry {
3135
}
3236
}
3337

34-
type PyResponse = Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send>>;
35-
3638
struct TaskRequest {
37-
func: Py<PyAny>,
38-
resp_tx: oneshot::Sender<PyResult<PyResponse>>,
39+
func: Arc<Py<PyAny>>,
40+
task_name: Arc<String>,
41+
raw_args: Arc<Vec<u8>>,
42+
raw_kwargs: Arc<Vec<u8>>,
43+
resp_tx: oneshot::Sender<Result<()>>,
3944
}
4045

4146
pub struct PythonDispatcher {
4247
tx: mpsc::Sender<TaskRequest>,
4348
}
4449

4550
impl PythonDispatcher {
46-
pub fn new() -> Self {
51+
pub fn new() -> Result<Self> {
4752
let logical_cores = num_cpus::get();
4853
let (tx, mut rx) = mpsc::channel::<TaskRequest>(logical_cores * 2);
4954

50-
std::thread::spawn(move || {
51-
Python::attach(|py| {
52-
let dispatcher = async move {
53-
while let Some(req) = rx.recv().await {
54-
let res = Python::attach(|py| into_future(req.func.into_bound(py)));
55+
let dispatcher = async move {
56+
while let Some(req) = rx.recv().await {
57+
run_task(req.func, req.task_name, req.raw_args, req.raw_kwargs)
58+
.await
59+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
5560

56-
let _ = req.resp_tx.send(res.map(|f| Box::pin(f) as PyResponse));
57-
}
58-
Ok(())
59-
};
61+
let _ = req.resp_tx.send(Ok(()));
62+
}
63+
Ok(())
64+
};
6065

66+
tokio::task::spawn_blocking(move || {
67+
Python::attach(|py| {
6168
pyo3_async_runtimes::tokio::run(py, dispatcher).expect("Python loop failed");
6269
});
6370
});
6471

65-
Self { tx }
72+
Ok(Self { tx })
6673
}
6774

68-
pub async fn execute(&self, func: Py<PyAny>) -> Result<Py<PyAny>> {
75+
pub async fn execute(
76+
&self,
77+
func: Arc<Py<PyAny>>,
78+
task_name: Arc<String>,
79+
raw_args: Arc<Vec<u8>>,
80+
raw_kwargs: Arc<Vec<u8>>,
81+
) -> Result<()> {
6982
let (resp_tx, resp_rx) = oneshot::channel();
7083

7184
self.tx
72-
.send(TaskRequest { func, resp_tx })
85+
.send(TaskRequest {
86+
func,
87+
task_name,
88+
raw_args,
89+
raw_kwargs,
90+
resp_tx,
91+
})
7392
.await
7493
.map_err(|_| anyhow!("Dispatcher channel closed"))?;
7594

76-
let py_fut = resp_rx.await??;
95+
resp_rx.await??;
96+
Ok(())
97+
}
98+
}
99+
100+
async fn run_task(
101+
task_function: Arc<Py<PyAny>>,
102+
task_name: Arc<String>,
103+
raw_args: Arc<Vec<u8>>,
104+
raw_kwargs: Arc<Vec<u8>>,
105+
) -> Result<()> {
106+
let task_args: Value = from_slice(&raw_args).context(format!(
107+
"Failed to deserialize task '{}' function args",
108+
&task_name
109+
))?;
110+
let task_kwargs: Value = from_slice(&raw_kwargs).context(format!(
111+
"Failed to deserialize task '{}' function kwargs",
112+
&task_name
113+
))?;
114+
115+
let maybe_coro = Python::attach(|py| -> Result<Option<Py<PyAny>>> {
116+
let py_args = pythonize(py, &task_args).context("Failed to pythonize args")?;
117+
let py_kwargs = pythonize(py, &task_kwargs).context("Failed to pythonize kwargs")?;
118+
119+
let args_tuple = if let Ok(list) = py_args.cast::<PyList>() {
120+
list.to_tuple()
121+
} else if let Ok(tuple) = py_args.cast::<PyTuple>() {
122+
tuple.clone()
123+
} else {
124+
anyhow::bail!("Args must be an array/tuple, found {}", py_args.get_type());
125+
};
126+
127+
let kwargs_dict = py_kwargs
128+
.cast_into::<PyDict>()
129+
.map_err(|_| anyhow!("Kwargs must be a map/dict"))?;
130+
131+
let result = task_function
132+
.call(py, args_tuple, Some(&kwargs_dict))
133+
.map_err(|e| anyhow!("Failed to call Python function: {:?}", e))?;
134+
135+
let bound_result = result.bind(py);
136+
let is_coroutine = bound_result
137+
.hasattr("__await__")
138+
.map_err(|_| anyhow!("Failed to check if result is awaitable"))?;
139+
140+
if is_coroutine {
141+
Ok(Some(result))
142+
} else {
143+
Ok(None)
144+
}
145+
})?;
77146

78-
let result = py_fut.await?;
79-
Ok(result)
147+
if let Some(coro) = maybe_coro {
148+
let fut = Python::attach(|py| into_future(coro.into_bound(py)))?;
149+
fut.await?;
80150
}
151+
152+
Ok(())
81153
}

crates/fluxqueue-worker/src/worker.rs

Lines changed: 12 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
use anyhow::{Context, Result, anyhow};
2-
use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyModule, PyTuple};
1+
use anyhow::{Result, anyhow};
2+
use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods, PyModule};
33
use pyo3::{Bound, Py, PyAny, Python};
4-
use pythonize::pythonize;
5-
use rmp_serde::from_slice;
64
use std::ffi::CString;
75
use std::path::{Path, PathBuf};
86
use std::sync::Arc;
@@ -58,7 +56,7 @@ pub async fn run_worker(
5856
let executor_id = Arc::clone(&executor_ids[i]);
5957
let shutdown = shutdown.clone();
6058
let task_registry = Arc::clone(&task_registry);
61-
let python_dispatcher = Arc::new(PythonDispatcher::new());
59+
let python_dispatcher = Arc::new(PythonDispatcher::new()?);
6260

6361
redis_client
6462
.register_executor(&queue_name, &executor_id)
@@ -276,50 +274,13 @@ async fn run_task(
276274
task: &Task,
277275
task_function: Arc<Py<PyAny>>,
278276
) -> Result<()> {
279-
let task_args: rmpv::Value = from_slice(&task.args).context(format!(
280-
"Failed to deserialize task '{}' function args",
281-
task.name
282-
))?;
283-
let task_kwargs: rmpv::Value = from_slice(&task.kwargs).context(format!(
284-
"Failed to deserialize task '{}' function kwargs",
285-
task.name
286-
))?;
287-
288-
let maybe_coro = Python::attach(|py| -> Result<Option<Py<PyAny>>> {
289-
let py_args = pythonize(py, &task_args).context("Failed to pythonize args")?;
290-
let py_kwargs = pythonize(py, &task_kwargs).context("Failed to pythonize kwargs")?;
291-
292-
let args_tuple = if let Ok(list) = py_args.cast::<PyList>() {
293-
list.to_tuple()
294-
} else if let Ok(tuple) = py_args.cast::<PyTuple>() {
295-
tuple.clone()
296-
} else {
297-
anyhow::bail!("Args must be an array/tuple, found {}", py_args.get_type());
298-
};
299-
300-
let kwargs_dict = py_kwargs
301-
.cast_into::<PyDict>()
302-
.map_err(|_| anyhow!("Kwargs must be a map/dict"))?;
303-
304-
let result = task_function
305-
.call(py, args_tuple, Some(&kwargs_dict))
306-
.map_err(|e| anyhow!("Failed to call Python function: {:?}", e))?;
307-
308-
let bound_result = result.bind(py);
309-
let is_coroutine = bound_result
310-
.hasattr("__await__")
311-
.map_err(|_| anyhow!("Failed to check if result is awaitable"))?;
277+
let task_name = Arc::new(task.name.clone());
278+
let raw_args = Arc::new(task.args.clone());
279+
let raw_kwargs = Arc::new(task.kwargs.clone());
312280

313-
if is_coroutine {
314-
Ok(Some(result))
315-
} else {
316-
Ok(None)
317-
}
318-
})?;
319-
320-
if let Some(coro) = maybe_coro {
321-
python_dispatcher.execute(coro).await?;
322-
}
281+
python_dispatcher
282+
.execute(task_function, task_name, raw_args, raw_kwargs)
283+
.await?;
323284

324285
Ok(())
325286
}
@@ -356,7 +317,7 @@ fn get_task_functions(module_path: &str, queue_name: &str) -> Result<Vec<(String
356317
filename.as_c_str(),
357318
module_name.as_c_str(),
358319
)
359-
.context("Failed to import python module")?;
320+
.map_err(|e| anyhow!("Failed to import python module: {}", e))?;
360321

361322
let py_funcs: Bound<'_, PyDict> = module
362323
.getattr("list_functions")
@@ -505,7 +466,7 @@ mod tests {
505466
#[tokio::test]
506467
async fn test_run_task_with_sync_function() -> Result<()> {
507468
let task_registry = TaskRegistry::new();
508-
let python_dispatcher = Arc::new(PythonDispatcher::new());
469+
let python_dispatcher = Arc::new(PythonDispatcher::new()?);
509470
let module_path_str = get_test_module_path("test_tasks_module.py");
510471
let task_functions = get_task_functions(&module_path_str, "default")?;
511472

@@ -537,7 +498,7 @@ mod tests {
537498
#[tokio::test]
538499
async fn test_run_task_with_async_function() -> Result<()> {
539500
let task_registry = TaskRegistry::new();
540-
let python_dispatcher = Arc::new(PythonDispatcher::new());
501+
let python_dispatcher = Arc::new(PythonDispatcher::new()?);
541502
let module_path_str = get_test_module_path("test_tasks_module.py");
542503
let task_functions = get_task_functions(&module_path_str, "default")?;
543504

0 commit comments

Comments
 (0)