Skip to content

Commit acdb468

Browse files
committed
Add task metadata feat
1 parent 180aa29 commit acdb468

6 files changed

Lines changed: 153 additions & 56 deletions

File tree

crates/fluxqueue-worker/src/task.rs

Lines changed: 86 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use anyhow::{Context, Result, anyhow};
2+
use fluxqueue_common::Task;
23
use pyo3::exceptions::PyRuntimeError;
34
use pyo3::prelude::*;
45
use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyModule, PyTuple};
@@ -112,8 +113,7 @@ struct TaskRequest {
112113
executor_id: Arc<String>,
113114
task_data: Arc<TaskData>,
114115
task_name: Arc<String>,
115-
raw_args: Arc<Vec<u8>>,
116-
raw_kwargs: Arc<Vec<u8>>,
116+
task: Arc<Task>,
117117
resp_tx: oneshot::Sender<Result<()>>,
118118
}
119119

@@ -134,8 +134,7 @@ impl PythonDispatcher {
134134
req.executor_id,
135135
req.task_data,
136136
req.task_name,
137-
req.raw_args,
138-
req.raw_kwargs,
137+
req.task,
139138
)
140139
.await
141140
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -159,8 +158,7 @@ impl PythonDispatcher {
159158
executor_id: Arc<String>,
160159
task_data: Arc<TaskData>,
161160
task_name: Arc<String>,
162-
raw_args: Arc<Vec<u8>>,
163-
raw_kwargs: Arc<Vec<u8>>,
161+
task: Arc<Task>,
164162
) -> Result<()> {
165163
let (resp_tx, resp_rx) = oneshot::channel();
166164

@@ -170,8 +168,7 @@ impl PythonDispatcher {
170168
executor_id,
171169
task_data,
172170
task_name,
173-
raw_args,
174-
raw_kwargs,
171+
task,
175172
resp_tx,
176173
})
177174
.await
@@ -182,29 +179,39 @@ impl PythonDispatcher {
182179
}
183180
}
184181

182+
struct CoroWithContext {
183+
args: Py<PyTuple>,
184+
kwargs: Py<PyDict>,
185+
context: Arc<Py<PyAny>>,
186+
}
187+
188+
struct MaybeCoro {
189+
func: Arc<Py<PyAny>>,
190+
with_context: Option<CoroWithContext>,
191+
}
192+
185193
async fn run_task(
186194
task_registry: Arc<TaskRegistry>,
187195
executor_id: Arc<String>,
188196
task_data: Arc<TaskData>,
189197
task_name: Arc<String>,
190-
raw_args: Arc<Vec<u8>>,
191-
raw_kwargs: Arc<Vec<u8>>,
198+
task: Arc<Task>,
192199
) -> Result<()> {
193200
let logger = Logger::new(format!("EXECUTOR {}", &executor_id));
194201
let duration_start = Instant::now();
195202

196-
let task_args: Value = from_slice(&raw_args).context(format!(
203+
let task_args: Value = from_slice(&task.args).context(format!(
197204
"Failed to deserialize task '{}' function args",
198205
&task_name
199206
))?;
200-
let task_kwargs: Value = from_slice(&raw_kwargs).context(format!(
207+
let task_kwargs: Value = from_slice(&task.kwargs).context(format!(
201208
"Failed to deserialize task '{}' function kwargs",
202209
&task_name
203210
))?;
204211

205212
let context = task_registry.get_task_context(task_data.clone())?;
206213

207-
let maybe_coro = Python::attach(|py| -> Result<Option<Py<PyAny>>> {
214+
let maybe_coro = Python::attach(|py| -> Result<Option<MaybeCoro>> {
208215
let py_args = pythonize(py, &task_args).context("Failed to pythonize args")?;
209216
let py_kwargs = pythonize(py, &task_kwargs).context("Failed to pythonize kwargs")?;
210217

@@ -216,7 +223,11 @@ async fn run_task(
216223
anyhow::bail!("Args must be an array/tuple, found {}", py_args.get_type());
217224
};
218225

219-
if let Some(context) = context {
226+
if let Some(context) = context.as_ref() {
227+
let task_metadata = get_task_metadata(py, task.clone())?;
228+
let metadata_var = context.getattr(py, "_metadata_var")?;
229+
metadata_var.call_method1(py, "set", (task_metadata,))?;
230+
220231
let context = context.as_any();
221232
let prefix = PyTuple::new(py, [context])?;
222233

@@ -230,19 +241,25 @@ async fn run_task(
230241
.cast_into::<PyDict>()
231242
.map_err(|_| anyhow!("Kwargs must be a map/dict"))?;
232243

233-
let result = task_data
234-
.func
235-
.call(py, args_tuple, Some(&kwargs_dict))
236-
.map_err(|e| anyhow!("Failed to call Python function: {:?}", e))?;
237-
238-
let bound_result = result.bind(py);
239-
let is_coroutine = bound_result
240-
.hasattr("__await__")
241-
.map_err(|_| anyhow!("Failed to check if result is awaitable"))?;
244+
let is_coroutine = is_coroutine(py, task_data.func.clone())?;
242245

243246
if is_coroutine {
244-
Ok(Some(result))
247+
let with_context = context.map(|context| CoroWithContext {
248+
args: args_tuple.unbind(),
249+
kwargs: kwargs_dict.unbind(),
250+
context,
251+
});
252+
253+
Ok(Some(MaybeCoro {
254+
func: task_data.func.clone(),
255+
with_context,
256+
}))
245257
} else {
258+
task_data
259+
.func
260+
.call(py, args_tuple.clone(), Some(&kwargs_dict))
261+
.map_err(|e| anyhow!("Failed to call Python function: {:?}", e))?;
262+
246263
let duration_end = duration_start.elapsed();
247264
logger.info(format_args!(
248265
"Task '{}' successfully finished in {}ms",
@@ -264,8 +281,27 @@ async fn run_task(
264281
anyhow!(e.to_string())
265282
})?;
266283

267-
if let Some(coro) = maybe_coro {
268-
let fut = Python::attach(|py| into_future(coro.into_bound(py)))?;
284+
if let Some(maybe_coro) = maybe_coro {
285+
let fut = Python::attach(|py| {
286+
if let Some(with_context) = maybe_coro.with_context {
287+
let task_metadata = get_task_metadata(py, task.clone())
288+
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
289+
let result = with_context.context.call_method1(
290+
py,
291+
"_run_async_task",
292+
(
293+
maybe_coro.func.as_any(),
294+
task_metadata,
295+
with_context.args,
296+
Some(with_context.kwargs),
297+
),
298+
)?;
299+
into_future(result.into_bound(py))
300+
} else {
301+
let func = maybe_coro.func.clone_ref(py);
302+
into_future(func.into_bound(py))
303+
}
304+
})?;
269305
fut.await?;
270306

271307
let duration_end = duration_start.elapsed();
@@ -370,6 +406,30 @@ fn get_registry(module_path: &str, queue_name: &str) -> Result<TasksAndContexts>
370406
Ok(result)
371407
}
372408

409+
fn get_task_metadata(py: Python<'_>, task: Arc<Task>) -> Result<Py<PyAny>> {
410+
let module = py.import("fluxqueue.models")?.unbind();
411+
let task_metadata = module.call_method1(
412+
py,
413+
"TaskMetadata",
414+
(
415+
task.id.clone(),
416+
task.retries,
417+
task.max_retries,
418+
task.created_at,
419+
),
420+
)?;
421+
422+
Ok(task_metadata)
423+
}
424+
425+
fn is_coroutine(py: Python<'_>, func: Arc<Py<PyAny>>) -> Result<bool> {
426+
let inspect = py.import("inspect")?;
427+
let is_coro: bool = inspect
428+
.call_method1("iscoroutinefunction", (func.as_any(),))?
429+
.extract()?;
430+
Ok(is_coro)
431+
}
432+
373433
fn normalize_path(path: &Path) -> PathBuf {
374434
let mut components = Vec::new();
375435
for comp in path.components() {

crates/fluxqueue-worker/src/worker.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ pub async fn run_worker(
2424
})?;
2525
let redis_client = Arc::new(redis_client);
2626

27-
let task_registry = Arc::new(TaskRegistry::new(&tasks_module_path, &queue_name).map_err(|e| {
28-
tracing::error!("{}", e);
29-
std::process::exit(1);
30-
})?);
27+
let task_registry = Arc::new(TaskRegistry::new(&tasks_module_path, &queue_name).map_err(
28+
|e| {
29+
tracing::error!("{}", e);
30+
std::process::exit(1);
31+
},
32+
)?);
3133
let registered_tasks = task_registry.get_registered_tasks()?;
3234
let registered_contexts = task_registry.get_registered_contexts()?;
3335

@@ -164,6 +166,7 @@ async fn executor_loop(
164166
Ok(Some(raw_data)) => {
165167
let task = deserialize_raw_task_data(&raw_data)?;
166168
let task_name = format!("{}:{}", &task.name, &task.id);
169+
let actual_task_name = task.name.clone();
167170

168171
logger.info(format_args!(
169172
"Received a task '{}' with a total of {} Bytes",
@@ -181,7 +184,7 @@ async fn executor_loop(
181184
return Ok(());
182185
};
183186

184-
let task_result = run_task(ctx.executor_id.clone(), ctx.python_dispatcher.clone(), &task, task_data.clone()).await;
187+
let task_result = run_task(ctx.executor_id.clone(), ctx.python_dispatcher.clone(), Arc::new(task), task_data.clone()).await;
185188

186189
match task_result {
187190
Ok(_) => {
@@ -195,7 +198,7 @@ async fn executor_loop(
195198
if let Err(err) = ctx.redis_client
196199
.mark_as_failed(&ctx.queue_name, &ctx.executor_id, &raw_data)
197200
.await {
198-
logger.error(format_args!("Failed to mark the task '{}' as failed: {}", &task.name, err));
201+
logger.error(format_args!("Failed to mark the task '{}' as failed: {}", actual_task_name, err));
199202
}
200203
}
201204
}
@@ -294,15 +297,13 @@ async fn janitor_loop(
294297
async fn run_task(
295298
executor_id: Arc<String>,
296299
python_dispatcher: Arc<PythonDispatcher>,
297-
task: &Task,
300+
task: Arc<Task>,
298301
task_data: Arc<TaskData>,
299302
) -> Result<()> {
300303
let task_name = Arc::new(format!("{}:{}", &task.name, &task.id));
301-
let raw_args = Arc::new(task.args.clone());
302-
let raw_kwargs = Arc::new(task.kwargs.clone());
303304

304305
python_dispatcher
305-
.execute(executor_id, task_data, task_name, raw_args, raw_kwargs)
306+
.execute(executor_id, task_data, task_name, task)
306307
.await?;
307308

308309
Ok(())
@@ -365,7 +366,7 @@ mod tests {
365366
let result = run_task(
366367
Arc::new("test".to_string()),
367368
dispatcher_pool.clone(),
368-
&task,
369+
Arc::new(task),
369370
task_func,
370371
)
371372
.await;
@@ -398,7 +399,7 @@ mod tests {
398399
let result = run_task(
399400
Arc::new("test".to_string()),
400401
dispatcher_pool.clone(),
401-
&task,
402+
Arc::new(task),
402403
task_func,
403404
)
404405
.await;

python/fluxqueue/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
__all__ = ["Context", "FluxQueue"]
2-
3-
from .client import FluxQueue
4-
from .context import Context
1+
from .client import FluxQueue as FluxQueue
2+
from .context import Context as Context
3+
from .models import TaskMetadata as TaskMetadata

python/fluxqueue/_task.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,13 @@ def _task_decorator(
4545
if return_type and return_type is not type(None):
4646
raise TypeError(f"Task function must return None, got {return_type}")
4747

48-
is_async = inspect.iscoroutinefunction(func)
4948
task_name = get_task_name(func, name)
5049

5150
# TODO: Add unique identifier 'fluxqueue' just to be 100% sure
5251
cast(Any, func).task_name = task_name
5352
cast(Any, func).queue = queue
5453

55-
if is_async:
54+
if inspect.iscoroutinefunction(func):
5655

5756
@wraps(func)
5857
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> None:

python/fluxqueue/context.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import inspect
2+
import threading
23
from collections.abc import Callable, Coroutine
34
from contextvars import ContextVar
45
from typing import Any, Concatenate, ParamSpec, TypeVar, cast, get_type_hints, overload
56

67
from ._core import FluxQueueCore
78
from ._task import _task_decorator
9+
from .models import TaskMetadata
810

911
P = ParamSpec("P")
1012

@@ -13,28 +15,47 @@ class Context:
1315
__fluxqueue_context__: str | None = None
1416

1517
def __init__(self) -> None:
16-
self._thread_storage: ContextVar[dict[str, Any]] = ContextVar(
17-
"thread_storage", default=None
18+
self._thread_local = threading.local()
19+
self._metadata_var: ContextVar[TaskMetadata] = ContextVar(
20+
"task_metadata", default=None
1821
)
1922

2023
@property
2124
def thread_storage(self) -> dict[str, Any]:
2225
"""
23-
Context-scoped storage dictionary.
26+
Retrieves the thread-persistent storage dictionary.
2427
25-
Returns a mutable dictionary associated with the current
26-
execution context. The storage is isolated per thread or
27-
async task using ContextVar, ensuring safe concurrent usage.
28+
Returns a dictionary that persists across all tasks executed by the current worker.
29+
Used for storing long-lived resources like database engines and connection
30+
pools to avoid re-initialization overhead.
31+
"""
32+
if not hasattr(self._thread_local, "storage"):
33+
self._thread_local.storage = {}
34+
35+
return self._thread_local.storage
2836

29-
The dictionary is initialized lazily on first access.
37+
@property
38+
def metadata(self) -> TaskMetadata:
3039
"""
31-
storage = self._thread_storage.get(None)
40+
Returns metadata isolated to the current task.
3241
33-
if storage is None:
34-
storage = {}
35-
self._thread_storage.set(storage)
42+
Returns a TaskMetadata instance containing execution details like
43+
retry counts and task IDs. This property uses ContextVars to ensure
44+
data isolation during concurrent task execution on the same thread.
45+
"""
46+
return self._metadata_var.get()
3647

37-
return storage
48+
async def _run_async_task(
49+
self, func: Callable, metadata: TaskMetadata, args, kwargs
50+
):
51+
"""
52+
This function is for internal use only.
53+
"""
54+
token = self._metadata_var.set(metadata)
55+
try:
56+
await func(*args, **kwargs)
57+
finally:
58+
self._metadata_var.reset(token)
3859

3960
def __init_subclass__(cls) -> None:
4061
if not cls.__fluxqueue_context__:

python/fluxqueue/models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass(slots=True)
5+
class TaskMetadata:
6+
"""
7+
Read-only metadata for a FluxQueue task.
8+
"""
9+
10+
task_id: str
11+
"""Unique identifier for the current task execution."""
12+
retry_count: int
13+
"""Number of times this task has been retried."""
14+
max_retries: int
15+
"""Maximum number of retry attempts allowed before failure."""
16+
enqueued_at: int
17+
"""ISO 8601 timestamp of when the task was originally enqueued."""

0 commit comments

Comments
 (0)