Skip to content

Worker middleware feature #1317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 68 additions & 22 deletions docs/howto/advanced/middleware.md
Original file line number Diff line number Diff line change
@@ -1,30 +1,76 @@
# Add a task middleware
# Middleware for workers and tasks

As of today, Procrastinate has no specific way of ensuring a piece of code runs
before or after every job. That being said, you can always decide to use
your own decorator instead of `@app.task` and have this decorator
implement the actions you need and delegate the rest to `@app.task`.
It might look like this:
Procrastinate lets you add middleware to workers and tasks. Middleware is a
function that wraps the execution of a task, allowing you to execute custom
logic before and after the task runs. You might use it to log task activity,
measure performance, or handle errors consistently.

```
import functools
A middleware is a function or coroutine (see examples below) that takes three arguments:
- `process_task`: a function resp. coroutine (without arguments) that runs the task
- `context`: a `JobContext` object that contains information about the job
- `worker`: the worker that runs the job

The middleware should call resp. await `process_task` to run the task and then return the
result.

:::{note}
The `worker` instance can be used to stop the worker from within the middleware by
calling `worker.stop()`. This will stop the worker after the jobs currently being
processed by the worker are finished.
:::

def task(original_func=None, **kwargs):
def wrap(func):
def new_func(*job_args, **job_kwargs):
# This is the custom part
log_something()
result = func(*job_args, **job_kwargs)
log_something_else()
return result
:::{warning}
When the middleware is called, the job was already fetched from the database and
is in `doing` state. After `process_task` the job is still in `doing` state and will
be updated to its final state after the middleware returns.
:::

wrapped_func = functools.update_wrapper(new_func, func, updated=())
return app.task(**kwargs)(wrapped_func)
## Worker middleware

if not original_func:
return wrap
To add a middleware to a worker, pass a middleware coroutine to the `run_worker` or
`run_worker_async` method. The middleware will wrap the execution of all tasks
run by this worker.

return wrap(original_func)
```python
async def custom_worker_middleware(process_task, context, worker):
# Execute any logic before the task is processed
result = await process_task()
# Execute any logic after the task is processed
return result

app.run_worker(middleware=custom_middleware)
```

Then, define all of your tasks using this `@task` decorator.
## Task middleware

You can also add a middleware to a specific task. This middleware will only wrap
the execution of this task then.

:::{note}
For a sync task, the middleware must be a sync function, and for an async task, the
middleware should be a coroutine.
:::

```python
# middleware of a sync task
def custom_sync_middleware(process_task, context, worker):
# Execute any logic before the task is processed
result = process_task()
# Execute any logic after the task is processed
return result

@app.task(middleware=custom_sync_middleware)
def my_task():
...

# or middleware of an async task
async def custom_async_middleware(process_task, context, worker):
# Execute any logic before the task is processed
result = await process_task()
# Execute any logic after the task is processed
return result

@app.task(middleware=custom_async_middleware)
async def my_task():
...
```
14 changes: 13 additions & 1 deletion procrastinate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@

from typing_extensions import NotRequired, Unpack

from procrastinate import blueprints, exceptions, jobs, manager, schema, utils
from procrastinate import (
blueprints,
exceptions,
jobs,
manager,
middleware,
schema,
utils,
)
from procrastinate import connector as connector_module

if TYPE_CHECKING:
Expand All @@ -34,6 +42,7 @@ class WorkerOptions(TypedDict):
delete_jobs: NotRequired[str | jobs.DeleteJobCondition]
additional_context: NotRequired[dict[str, Any]]
install_signal_handlers: NotRequired[bool]
middleware: NotRequired[middleware.WorkerMiddleware]


class App(blueprints.Blueprint):
Expand Down Expand Up @@ -316,6 +325,9 @@ async def run_worker_async(self, **kwargs: Unpack[WorkerOptions]) -> None:
worker. Use ``False`` if you want to handle signals yourself (e.g. if you
run the work as an async task in a bigger application)
(defaults to ``True``)
middleware: ``Optional[Middleware]``
A coroutine that can be used to wrap the task execution. The default middleware
just awaits the task and returns the result.
"""
self.perform_import_paths()
worker = self._worker(**kwargs)
Expand Down
11 changes: 10 additions & 1 deletion procrastinate/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing_extensions import Concatenate, ParamSpec, TypeVar, Unpack

from procrastinate import exceptions, jobs, periodic, retry, utils
from procrastinate import exceptions, jobs, middleware, periodic, retry, utils
from procrastinate.job_context import JobContext

if TYPE_CHECKING:
Expand Down Expand Up @@ -211,6 +211,7 @@ def task(
priority: int = jobs.DEFAULT_PRIORITY,
lock: str | None = None,
queueing_lock: str | None = None,
middleware: middleware.TaskMiddleware[R] | None = None,
) -> Callable[[Callable[P, R]], Task[P, R, P]]:
"""Declare a function as a task. This method is meant to be used as a decorator
Parameters
Expand Down Expand Up @@ -249,6 +250,11 @@ def task(
Default is no retry.
pass_context :
Passes the task execution context in the task as first
middleware :
A function that can be used to wrap the task execution. The default middleware
just calls the task function and returns its result. If the task is synchronous,
the middleware must also be a sync function. If the task is async, the middleware
must be async, too.
"""
...

Expand All @@ -265,6 +271,7 @@ def task(
priority: int = jobs.DEFAULT_PRIORITY,
lock: str | None = None,
queueing_lock: str | None = None,
middleware: middleware.TaskMiddleware[R] | None = None,
) -> Callable[
[Callable[Concatenate[JobContext, P], R]],
Task[Concatenate[JobContext, P], R, P],
Expand Down Expand Up @@ -299,6 +306,7 @@ def task(
priority: int = jobs.DEFAULT_PRIORITY,
lock: str | None = None,
queueing_lock: str | None = None,
middleware: middleware.TaskMiddleware[R] | None = None,
):
from procrastinate.tasks import Task

Expand Down Expand Up @@ -329,6 +337,7 @@ def _wrap(func: Callable[P, R]) -> Task[P, R, P]:
aliases=aliases,
retry=retry,
pass_context=pass_context,
middleware=middleware,
)
self._register_task(task)

Expand Down
41 changes: 41 additions & 0 deletions procrastinate/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from collections.abc import Awaitable
from typing import TYPE_CHECKING, Callable, TypeVar

from procrastinate import job_context

R = TypeVar("R")

if TYPE_CHECKING:
from procrastinate import worker

ProcessTask = Callable[[], R]
WorkerMiddleware = Callable[
[ProcessTask[Awaitable], job_context.JobContext, "worker.Worker"], Awaitable
]
TaskMiddleware = Callable[[ProcessTask[R], job_context.JobContext, "worker.Worker"], R]


async def default_worker_middleware(
process_task: ProcessTask,
context: job_context.JobContext,
worker: worker.Worker,
):
return await process_task()


async def default_async_task_middleware(
process_task: ProcessTask,
context: job_context.JobContext,
worker: worker.Worker,
):
return await process_task()


def default_sync_task_middleware(
process_task: ProcessTask,
context: job_context.JobContext,
worker: worker.Worker,
):
return process_task()
11 changes: 11 additions & 0 deletions procrastinate/tasks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import datetime
import inspect
import logging
from typing import Callable, Generic, TypedDict, cast

from typing_extensions import NotRequired, ParamSpec, TypeVar, Unpack

from procrastinate import app as app_module
from procrastinate import blueprints, exceptions, jobs, manager, types, utils
from procrastinate import middleware as middleware_module
from procrastinate import retry as retry_module

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,6 +87,7 @@ def __init__(
priority: int = jobs.DEFAULT_PRIORITY,
lock: str | None = None,
queueing_lock: str | None = None,
middleware: middleware_module.TaskMiddleware[R] | None = None,
):
#: Default queue to send deferred jobs to. The queue can be overridden
#: when a job is deferred.
Expand Down Expand Up @@ -113,6 +116,14 @@ def __init__(
#: Default queueing lock. The queuing lock can be overridden when a job
#: is deferred.
self.queueing_lock: str | None = queueing_lock
#: Middleware to be used when the task is executed.
if middleware is not None:
self.middleware = middleware
else:
if inspect.iscoroutinefunction(func):
self.middleware = middleware_module.default_async_task_middleware
else:
self.middleware = middleware_module.default_sync_task_middleware

def add_namespace(self, namespace: str) -> None:
"""
Expand Down
33 changes: 27 additions & 6 deletions procrastinate/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
exceptions,
job_context,
jobs,
middleware,
periodic,
retry,
signals,
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
delete_jobs: str | jobs.DeleteJobCondition | None = None,
additional_context: dict[str, Any] | None = None,
install_signal_handlers: bool = True,
middleware: middleware.WorkerMiddleware = middleware.default_worker_middleware,
):
self.app = app
self.queues = queues
Expand All @@ -61,6 +63,7 @@ def __init__(
) or jobs.DeleteJobCondition.NEVER
self.additional_context = additional_context
self.install_signal_handlers = install_signal_handlers
self.middleware = middleware

if self.worker_name:
self.logger = logger.getChild(self.worker_name)
Expand Down Expand Up @@ -230,14 +233,32 @@ async def _process_job(self, context: job_context.JobContext):
exc_info: bool | BaseException = False

async def ensure_async() -> Callable[..., Awaitable]:
await_func: Callable[..., Awaitable]
job_args = [context] if task.pass_context else []
if inspect.iscoroutinefunction(task.func):
await_func = task

async def run_task_async():
return await task.func(*job_args, **job.task_kwargs)

wrapped_middleware = functools.partial(
task.middleware,
run_task_async,
context,
self,
)
else:
await_func = functools.partial(utils.sync_to_async, task)

job_args = [context] if task.pass_context else []
task_result = await await_func(*job_args, **job.task_kwargs)
def run_task_sync():
return task(*job_args, **job.task_kwargs)

wrapped_middleware = functools.partial(
utils.sync_to_async,
task.middleware,
run_task_sync,
context,
self,
)

task_result = await wrapped_middleware()
# In some cases, the task function might be a synchronous function
# that returns an awaitable without actually being a
# coroutinefunction. In that case, in the await above, we haven't
Expand All @@ -251,7 +272,7 @@ async def ensure_async() -> Callable[..., Awaitable]:

return task_result

job_result.result = await ensure_async()
job_result.result = await self.middleware(ensure_async, context, self)

except BaseException as e:
exc_info = e
Expand Down
Loading
Loading