Skip to content

Commit d420f31

Browse files
committed
Start working on context in the worker
1 parent ee13e3e commit d420f31

7 files changed

Lines changed: 344 additions & 248 deletions

File tree

crates/fluxqueue-worker/scripts/get_functions.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import importlib
2+
import inspect
3+
import sys
4+
from pathlib import Path
5+
6+
7+
def get_registry(module_path: str, queue: str, module_dir: str | None = None):
8+
if module_dir:
9+
module_dir_path = Path(module_dir).resolve()
10+
if str(module_dir_path) not in sys.path:
11+
sys.path.insert(0, str(module_dir_path))
12+
13+
module = importlib.import_module(module_path)
14+
registry = {"tasks": {}, "contexts": {}}
15+
for _name, obj in inspect.getmembers(module):
16+
if inspect.isfunction(obj):
17+
task_name = getattr(obj, "task_name", None)
18+
task_queue = getattr(obj, "queue", None)
19+
if not task_queue or task_queue != queue:
20+
continue
21+
22+
if registry["tasks"].get(task_name):
23+
raise ValueError(f"Task '{task_name}' is duplicated")
24+
25+
original_func = getattr(obj, "__wrapped__", obj)
26+
registry["tasks"][task_name] = original_func
27+
elif inspect.isclass(obj):
28+
context_name = getattr(obj, "__fluxqueue_context__", None)
29+
if not context_name:
30+
continue
31+
32+
if registry["contexts"].get(context_name):
33+
raise ValueError(f"Context '{context_name}' is duplicated")
34+
35+
registry["contexts"][context_name] = obj
36+
37+
return registry

crates/fluxqueue-worker/src/logger.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ pub fn initial_logs(
3434
concurrency: usize,
3535
redis_url: &str,
3636
tasks_module_path: &str,
37-
tasks: &Vec<&String>,
37+
tasks: Vec<String>,
38+
contexts: Vec<String>,
3839
) {
3940
info!("Queue: {}", queue_name);
4041
info!("Concurrency: {}", concurrency);
4142
info!("Redis: {}", redis_url);
4243
info!("Tasks module: {}", tasks_module_path);
4344
info!("Tasks found: {:?}", tasks);
45+
info!("Contexts found: {:?}", contexts);
4446
info!("Starting up the executors...");
4547
}
4648

crates/fluxqueue-worker/src/redis_client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl RedisClient {
3333
Ok(())
3434
}
3535

36-
pub async fn set_executors_heartbeat(&self, executor_ids: Arc<Vec<Arc<str>>>) -> Result<()> {
36+
pub async fn set_executors_heartbeat(&self, executor_ids: Arc<Vec<Arc<String>>>) -> Result<()> {
3737
for id in executor_ids.iter() {
3838
self.set_executor_heartbeat(id).await?;
3939
}
@@ -60,7 +60,7 @@ impl RedisClient {
6060
pub async fn cleanup_executors_registry(
6161
&self,
6262
queue_name: &str,
63-
ids: Arc<Vec<Arc<str>>>,
63+
ids: Arc<Vec<Arc<String>>>,
6464
) -> Result<()> {
6565
let mut conn = self.redis_pool.get().await?;
6666
let executors_key = keys::get_executors_key(queue_name);

crates/fluxqueue-worker/src/task.rs

Lines changed: 199 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,120 @@
11
use anyhow::{Context, Result, anyhow};
22
use pyo3::exceptions::PyRuntimeError;
33
use pyo3::prelude::*;
4-
use pyo3::types::{PyDict, PyList, PyTuple};
4+
use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyModule, PyTuple};
55
use pyo3_async_runtimes::tokio::into_future;
66
use pythonize::pythonize;
77
use rmp_serde::from_slice;
88
use rmpv::Value;
99
use std::collections::HashMap;
10+
use std::ffi::CString;
11+
use std::path::{Path, PathBuf};
12+
use std::sync::atomic::{AtomicUsize, Ordering};
1013
use std::sync::{Arc, RwLock};
14+
use std::time::Instant;
1115
use tokio::sync::{mpsc, oneshot};
1216

17+
use crate::logger::Logger;
18+
19+
/// Name-indexed store of the Python objects the worker can look up.
///
/// Both maps sit behind `Arc<RwLock<..>>`, and each stored Python object is
/// wrapped in an `Arc`.
#[derive(Debug)]
1320
pub struct TaskRegistry {
1421
/// Python task objects keyed by task name.
tasks: Arc<RwLock<HashMap<String, Arc<Py<PyAny>>>>>,
22+
/// Python context objects keyed by context name.
contexts: Arc<RwLock<HashMap<String, Arc<Py<PyAny>>>>>,
1523
}
1624

1725
impl TaskRegistry {
18-
pub fn new() -> Self {
19-
Self {
20-
tasks: Arc::new(RwLock::new(HashMap::new())),
26+
pub fn new(module_path: &str, queue_name: &str) -> Result<Self> {
27+
let script = include_str!("../scripts/get_registry.py");
28+
let script_cstr = CString::new(script)?;
29+
let filename = CString::new("get_registry.py")?;
30+
let module_name = CString::new("get_registry")?;
31+
32+
let full_current_dir = std::env::current_dir().unwrap();
33+
let full_module_path = full_current_dir.join(module_path);
34+
let clean_module_path = normalize_path(&full_module_path);
35+
let project_root = full_current_dir
36+
.ancestors()
37+
.find(|p| p.join("tests").exists())
38+
.unwrap_or(&full_current_dir);
39+
let real_module_path = path_to_module_path(project_root, &clean_module_path);
40+
41+
if !clean_module_path.exists() || real_module_path.is_none() {
42+
return Err(anyhow!(
43+
"Tasks module path {:?} doesn't exist.",
44+
clean_module_path
45+
));
2146
}
47+
48+
let real_module_path = real_module_path.unwrap();
49+
let module_dir = project_root.to_string_lossy().to_string();
50+
51+
let (tasks, contexts) = Python::attach(
52+
|py| -> Result<(
53+
HashMap<String, Arc<Py<PyAny>>>,
54+
HashMap<String, Arc<Py<PyAny>>>,
55+
)> {
56+
let module = PyModule::from_code(
57+
py,
58+
script_cstr.as_c_str(),
59+
filename.as_c_str(),
60+
module_name.as_c_str(),
61+
)
62+
.map_err(|e| anyhow!("Failed to import python module: {}", e))?;
63+
64+
let registry: Bound<'_, PyDict> = module
65+
.getattr("get_registry")
66+
.map_err(|e| anyhow!("Failed to get 'get_registry' script: {}", e))?
67+
.call1((real_module_path, queue_name, module_dir))
68+
.map_err(|e| anyhow!("Failed to get tasks: {}", e))?
69+
.cast_into::<PyDict>()
70+
.map_err(|_| anyhow!("Failed to cast result to a Python Dictionary"))?;
71+
72+
let tasks: HashMap<String, Arc<Py<PyAny>>> = registry
73+
.get_item("tasks")?
74+
.expect("tasks missing")
75+
.cast::<PyDict>()
76+
.map_err(|e| anyhow!("tasks is not a dict: {}", e))?
77+
.iter()
78+
.filter_map(|(key, value)| {
79+
let name: String = key.extract().ok()?;
80+
let func: Py<PyAny> = value.unbind();
81+
Some((name, Arc::new(func)))
82+
})
83+
.collect();
84+
85+
let contexts: HashMap<String, Arc<Py<PyAny>>> = registry
86+
.get_item("contexts")?
87+
.expect("contexts missing")
88+
.cast::<PyDict>()
89+
.map_err(|e| anyhow!("contexts is not a dict: {}", e))?
90+
.iter()
91+
.filter_map(|(key, value): (Bound<PyAny>, Bound<PyAny>)| {
92+
let name: String = key.extract().ok()?;
93+
let func: Py<PyAny> = value.unbind();
94+
Some((name, Arc::new(func)))
95+
})
96+
.collect();
97+
98+
Ok((tasks, contexts))
99+
},
100+
)?;
101+
102+
Ok(Self {
103+
tasks: Arc::new(RwLock::new(tasks)),
104+
contexts: Arc::new(RwLock::new(contexts)),
105+
})
22106
}
23107

24-
pub fn insert(&self, name: String, func: Py<PyAny>) -> Result<()> {
25-
let mut tasks = self.tasks.write().map_err(|_| {
26-
anyhow!("Internal Error: Task registry lock poisoned (a thread panicked)")
27-
})?;
28-
tasks.insert(name, Arc::new(func));
29-
Ok(())
108+
/// Returns the names of every task currently held by the registry.
///
/// # Errors
///
/// Fails when the read lock on the task map cannot be acquired; the lock
/// error's message is forwarded.
pub fn get_registered_tasks(&self) -> Result<Vec<String>> {
    match self.tasks.read() {
        Ok(guard) => Ok(guard.keys().cloned().collect()),
        Err(lock_err) => Err(anyhow!(lock_err.to_string())),
    }
}
113+
114+
/// Returns the names of every context currently held by the registry.
///
/// # Errors
///
/// Fails when the read lock on the context map cannot be acquired; the lock
/// error's message is forwarded.
pub fn get_registered_contexts(&self) -> Result<Vec<String>> {
    match self.contexts.read() {
        Ok(guard) => Ok(guard.keys().cloned().collect()),
        Err(lock_err) => Err(anyhow!(lock_err.to_string())),
    }
}
31119

32120
pub fn get(&self, name: &str) -> Option<Arc<Py<PyAny>>> {
@@ -36,6 +124,7 @@ impl TaskRegistry {
36124
}
37125

38126
struct TaskRequest {
127+
executor_id: Arc<String>,
39128
func: Arc<Py<PyAny>>,
40129
task_name: Arc<String>,
41130
raw_args: Arc<Vec<u8>>,
@@ -54,9 +143,15 @@ impl PythonDispatcher {
54143

55144
let dispatcher = async move {
56145
while let Some(req) = rx.recv().await {
57-
run_task(req.func, req.task_name, req.raw_args, req.raw_kwargs)
58-
.await
59-
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
146+
run_task(
147+
req.executor_id,
148+
req.func,
149+
req.task_name,
150+
req.raw_args,
151+
req.raw_kwargs,
152+
)
153+
.await
154+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
60155

61156
let _ = req.resp_tx.send(Ok(()));
62157
}
@@ -74,6 +169,7 @@ impl PythonDispatcher {
74169

75170
pub async fn execute(
76171
&self,
172+
executor_id: Arc<String>,
77173
func: Arc<Py<PyAny>>,
78174
task_name: Arc<String>,
79175
raw_args: Arc<Vec<u8>>,
@@ -83,6 +179,7 @@ impl PythonDispatcher {
83179

84180
self.tx
85181
.send(TaskRequest {
182+
executor_id,
86183
func,
87184
task_name,
88185
raw_args,
@@ -98,11 +195,15 @@ impl PythonDispatcher {
98195
}
99196

100197
async fn run_task(
198+
executor_id: Arc<String>,
101199
task_function: Arc<Py<PyAny>>,
102200
task_name: Arc<String>,
103201
raw_args: Arc<Vec<u8>>,
104202
raw_kwargs: Arc<Vec<u8>>,
105203
) -> Result<()> {
204+
let logger = Logger::new(format!("EXECUTOR {}", &executor_id));
205+
let duration_start = Instant::now();
206+
106207
let task_args: Value = from_slice(&raw_args).context(format!(
107208
"Failed to deserialize task '{}' function args",
108209
&task_name
@@ -140,14 +241,99 @@ async fn run_task(
140241
if is_coroutine {
141242
Ok(Some(result))
142243
} else {
244+
let duration_end = duration_start.elapsed();
245+
logger.info(format_args!(
246+
"Task '{}' successfully finished in {}ms",
247+
&task_name,
248+
duration_end.as_millis()
249+
));
250+
143251
Ok(None)
144252
}
253+
})
254+
.map_err(|e| {
255+
let duration_end = duration_start.elapsed();
256+
logger.error(format_args!(
257+
"Task '{}' failed in {}ms: {}",
258+
&task_name,
259+
duration_end.as_millis(),
260+
e
261+
));
262+
anyhow!(e.to_string())
145263
})?;
146264

147265
if let Some(coro) = maybe_coro {
148266
let fut = Python::attach(|py| into_future(coro.into_bound(py)))?;
149267
fut.await?;
268+
269+
let duration_end = duration_start.elapsed();
270+
logger.info(format_args!(
271+
"Task '{}' successfully finished in {}ms",
272+
&task_name,
273+
duration_end.as_millis()
274+
));
150275
}
151276

152277
Ok(())
153278
}
279+
280+
pub struct DispatcherPool {
281+
dispatchers: Vec<Arc<PythonDispatcher>>,
282+
index: AtomicUsize,
283+
}
284+
285+
impl DispatcherPool {
286+
pub fn new(concurrency: usize) -> Result<Self> {
287+
// let logical_cores = num_cpus::get();
288+
// let pool_size = (concurrency / 4)
289+
// .max(1)
290+
// .min(logical_cores * 2)
291+
// .max(logical_cores);
292+
293+
let mut dispatchers = Vec::with_capacity(concurrency);
294+
for _ in 0..concurrency {
295+
dispatchers.push(Arc::new(PythonDispatcher::new()?));
296+
}
297+
298+
Ok(Self {
299+
dispatchers,
300+
index: AtomicUsize::new(0),
301+
})
302+
}
303+
304+
pub fn get(&self) -> Arc<PythonDispatcher> {
305+
let idx = self.index.fetch_add(1, Ordering::Relaxed) % self.dispatchers.len();
306+
self.dispatchers[idx].clone()
307+
}
308+
}
309+
310+
/// Lexically normalizes `path` without touching the filesystem.
///
/// `.` components are dropped and each `..` cancels the preceding normal
/// component. A `..` directly below the root (or a prefix) is discarded so an
/// absolute path stays absolute, and a `..` with nothing left to cancel is
/// kept so a relative path does not silently change meaning.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    let mut kept = Vec::new();
    for comp in path.components() {
        match comp {
            Component::CurDir => {}
            Component::ParentDir => match kept.last() {
                Some(Component::Normal(_)) => {
                    kept.pop();
                }
                // The parent of the root is the root itself: never pop it.
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Nothing to cancel: the `..` is significant and must survive.
                _ => kept.push(comp),
            },
            other => kept.push(other),
        }
    }
    kept.iter().collect()
}
323+
324+
/// Converts `target_path` into a dotted module path relative to `current_dir`.
///
/// The extension of the final component is removed and the remaining
/// components are joined with `.`, e.g. `pkg/tasks.py` becomes `pkg.tasks`.
///
/// Returns `None` when `target_path` is not located under `current_dir`, when
/// it equals `current_dir` (there is no module to name), or when any resulting
/// segment is empty or still contains a `.`, since such a segment cannot be
/// represented as one part of a dotted module path.
fn path_to_module_path(current_dir: &Path, target_path: &Path) -> Option<String> {
    let rel_path = target_path.strip_prefix(current_dir).ok()?;

    // Strip the extension through `Path` so that a leading-dot file name is
    // not mistaken for "empty stem plus extension".
    let without_ext = rel_path.with_extension("");
    let segments: Vec<String> = without_ext
        .components()
        .map(|c| c.as_os_str().to_string_lossy().into_owned())
        .collect();

    let is_invalid = |segment: &String| segment.is_empty() || segment.contains('.');
    if segments.is_empty() || segments.iter().any(is_invalid) {
        return None;
    }

    Some(segments.join("."))
}

0 commit comments

Comments
 (0)