diff --git a/docs/howto/advanced/middleware.md b/docs/howto/advanced/middleware.md
index f0717f98b..35f6d32c1 100644
--- a/docs/howto/advanced/middleware.md
+++ b/docs/howto/advanced/middleware.md
@@ -1,30 +1,76 @@
-# Add a task middleware
+# Middleware for workers and tasks
 
-As of today, Procrastinate has no specific way of ensuring a piece of code runs
-before or after every job. That being said, you can always decide to use
-your own decorator instead of `@app.task` and have this decorator
-implement the actions you need and delegate the rest to `@app.task`.
-It might look like this:
+Procrastinate lets you add middleware to workers and tasks. Middleware is a
+function that wraps the execution of a task, allowing you to run custom logic
+before and after the task runs. You might use it to log task activity, measure
+performance, or handle errors consistently.
 
-```
-import functools
+A middleware is a function or coroutine (see the examples below) that takes three
+arguments:
+
+- `process_task`: a function or coroutine (taking no arguments) that runs the task
+- `context`: a `JobContext` object that contains information about the job
+- `worker`: the worker that runs the job
+
+The middleware should call (or await) `process_task` to run the task and then
+return the result.
+
+:::{note}
+The `worker` instance can be used to stop the worker from within the middleware by
+calling `worker.stop()`. This stops the worker once the jobs it is currently
+processing are finished.
+:::
 
-def task(original_func=None, **kwargs):
-    def wrap(func):
-        def new_func(*job_args, **job_kwargs):
-            # This is the custom part
-            log_something()
-            result = func(*job_args, **job_kwargs)
-            log_something_else()
-            return result
+:::{warning}
+When the middleware is called, the job has already been fetched from the database
+and is in the `doing` state. After `process_task` returns, the job is still in the
+`doing` state; it is only updated to its final state after the middleware returns.
+:::
 
-        wrapped_func = functools.update_wrapper(new_func, func, updated=())
-        return app.task(**kwargs)(wrapped_func)
+## Worker middleware
 
-    if not original_func:
-        return wrap
+To add a middleware to a worker, pass a middleware coroutine to the `run_worker` or
+`run_worker_async` method. The middleware will wrap the execution of all tasks
+run by this worker.
 
-    return wrap(original_func)
+```python
+async def custom_worker_middleware(process_task, context, worker):
+    # Execute any logic before the task is processed
+    result = await process_task()
+    # Execute any logic after the task is processed
+    return result
+
+app.run_worker(middleware=custom_worker_middleware)
 ```
 
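+For example, a worker middleware might record how long each job takes. The sketch
+below only relies on the standard library; the logger and middleware names are
+just illustrative:
+
+```python
+import logging
+import time
+
+logger = logging.getLogger(__name__)
+
+
+async def timing_middleware(process_task, context, worker):
+    start = time.perf_counter()
+    try:
+        # Run the job itself
+        return await process_task()
+    finally:
+        # Runs whether the task succeeded or raised
+        logger.info("Job ran in %.3f s", time.perf_counter() - start)
+
+
+app.run_worker(middleware=timing_middleware)
+```
+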
-Then, define all of your tasks using this `@task` decorator.
+## Task middleware
+
+You can also add a middleware to a specific task. This middleware will then wrap
+only the execution of that task.
+
+:::{note}
+For a sync task, the middleware must be a sync function; for an async task, the
+middleware must be a coroutine.
+:::
+
+```python
+# middleware of a sync task
+def custom_sync_middleware(process_task, context, worker):
+    # Execute any logic before the task is processed
+    result = process_task()
+    # Execute any logic after the task is processed
+    return result
+
+@app.task(middleware=custom_sync_middleware)
+def my_task():
+    ...
+
+# or middleware of an async task
+async def custom_async_middleware(process_task, context, worker):
+    # Execute any logic before the task is processed
+    result = await process_task()
+    # Execute any logic after the task is processed
+    return result
+
+@app.task(middleware=custom_async_middleware)
+async def my_task():
+    ...
+```
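+
+Exceptions raised by the task propagate through `process_task`, and whatever the
+middleware returns is treated as the task's result. If a middleware catches an
+exception (for example, to log it), it should normally re-raise it so the worker
+still records the failure. A minimal sketch (the names are only illustrative):
+
+```python
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+async def error_logging_middleware(process_task, context, worker):
+    try:
+        return await process_task()
+    except Exception:
+        # Re-raise: swallowing the exception would make the job look successful.
+        logger.exception("Job failed")
+        raise
+
+
+@app.task(middleware=error_logging_middleware)
+async def my_other_task():
+    ...
+```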
diff --git a/procrastinate/app.py b/procrastinate/app.py
index ac2d50c35..ea0d3e0be 100644
--- a/procrastinate/app.py
+++ b/procrastinate/app.py
@@ -13,7 +13,15 @@
 
 from typing_extensions import NotRequired, Unpack
 
-from procrastinate import blueprints, exceptions, jobs, manager, schema, utils
+from procrastinate import (
+    blueprints,
+    exceptions,
+    jobs,
+    manager,
+    middleware,
+    schema,
+    utils,
+)
 from procrastinate import connector as connector_module
 
 if TYPE_CHECKING:
@@ -34,6 +42,7 @@ class WorkerOptions(TypedDict):
     delete_jobs: NotRequired[str | jobs.DeleteJobCondition]
     additional_context: NotRequired[dict[str, Any]]
     install_signal_handlers: NotRequired[bool]
+    middleware: NotRequired[middleware.WorkerMiddleware]
 
 
 class App(blueprints.Blueprint):
@@ -316,6 +325,9 @@ async def run_worker_async(self, **kwargs: Unpack[WorkerOptions]) -> None:
             worker. Use ``False`` if you want to handle signals yourself (e.g. if
             you run the work as an async task in a bigger application)
             (defaults to ``True``)
+        middleware: ``Optional[WorkerMiddleware]``
+            A coroutine that can be used to wrap the task execution. The default
+            middleware just awaits the task and returns the result.
         """
         self.perform_import_paths()
         worker = self._worker(**kwargs)
diff --git a/procrastinate/blueprints.py b/procrastinate/blueprints.py
index e39619189..714fc41f5 100644
--- a/procrastinate/blueprints.py
+++ b/procrastinate/blueprints.py
@@ -7,7 +7,7 @@
 
 from typing_extensions import Concatenate, ParamSpec, TypeVar, Unpack
 
-from procrastinate import exceptions, jobs, periodic, retry, utils
+from procrastinate import exceptions, jobs, middleware, periodic, retry, utils
 from procrastinate.job_context import JobContext
 
 if TYPE_CHECKING:
@@ -211,6 +211,7 @@ def task(
         priority: int = jobs.DEFAULT_PRIORITY,
         lock: str | None = None,
         queueing_lock: str | None = None,
+        middleware: middleware.TaskMiddleware[R] | None = None,
     ) -> Callable[[Callable[P, R]], Task[P, R, P]]:
         """Declare a function as a task. This method is meant to be used as a decorator
         Parameters
@@ -249,6 +250,11 @@ def task(
             Default is no retry.
         pass_context :
             Passes the task execution context in the task as first
+        middleware :
+            A function that can be used to wrap the task execution. The default middleware
+            just calls the task function and returns its result. If the task is synchronous,
+            the middleware must also be a sync function. If the task is async, the middleware
+            must be async, too.
         """
         ...
 
@@ -265,6 +271,7 @@ def task(
         priority: int = jobs.DEFAULT_PRIORITY,
         lock: str | None = None,
         queueing_lock: str | None = None,
+        middleware: middleware.TaskMiddleware[R] | None = None,
     ) -> Callable[
         [Callable[Concatenate[JobContext, P], R]],
         Task[Concatenate[JobContext, P], R, P],
@@ -299,6 +306,7 @@ def task(
         priority: int = jobs.DEFAULT_PRIORITY,
         lock: str | None = None,
         queueing_lock: str | None = None,
+        middleware: middleware.TaskMiddleware[R] | None = None,
     ):
         from procrastinate.tasks import Task
 
@@ -329,6 +337,7 @@ def _wrap(func: Callable[P, R]) -> Task[P, R, P]:
                 aliases=aliases,
                 retry=retry,
                 pass_context=pass_context,
+                middleware=middleware,
             )
 
             self._register_task(task)
diff --git a/procrastinate/middleware.py b/procrastinate/middleware.py
new file mode 100644
index 000000000..9530347bc
--- /dev/null
+++ b/procrastinate/middleware.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from collections.abc import Awaitable
+from typing import TYPE_CHECKING, Callable, TypeVar
+
+from procrastinate import job_context
+
+R = TypeVar("R")
+
+if TYPE_CHECKING:
+    from procrastinate import worker
+
+ProcessTask = Callable[[], R]
+WorkerMiddleware = Callable[
+    [ProcessTask[Awaitable], job_context.JobContext, "worker.Worker"], Awaitable
+]
+TaskMiddleware = Callable[[ProcessTask[R], job_context.JobContext, "worker.Worker"], R]
+
+
+async def default_worker_middleware(
+    process_task: ProcessTask,
+    context: job_context.JobContext,
+    worker: worker.Worker,
+):
+    return await process_task()
+
+
+async def default_async_task_middleware(
+    process_task: ProcessTask,
+    context: job_context.JobContext,
+    worker: worker.Worker,
+):
+    return await process_task()
+
+
+def default_sync_task_middleware(
+    process_task: ProcessTask,
+    context: job_context.JobContext,
+    worker: worker.Worker,
+):
+    return process_task()
diff --git a/procrastinate/tasks.py b/procrastinate/tasks.py
index 876da1c69..ba4cd1131 100644
--- a/procrastinate/tasks.py
+++ b/procrastinate/tasks.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import datetime
+import inspect
 import logging
 from typing import Callable, Generic, TypedDict, cast
 
@@ -8,6 +9,7 @@
 from procrastinate import app as app_module
 from procrastinate import blueprints, exceptions, jobs, manager, types, utils
+from procrastinate import middleware as middleware_module
 from procrastinate import retry as retry_module
 
 logger = logging.getLogger(__name__)
@@ -85,6 +87,7 @@ def __init__(
         priority: int = jobs.DEFAULT_PRIORITY,
         lock: str | None = None,
         queueing_lock: str | None = None,
+        middleware: middleware_module.TaskMiddleware[R] | None = None,
     ):
         #: Default queue to send deferred jobs to. The queue can be overridden
         #: when a job is deferred.
@@ -113,6 +116,14 @@ def __init__(
         #: Default queueing lock. The queuing lock can be overridden when a job
         #: is deferred.
         self.queueing_lock: str | None = queueing_lock
+        #: Middleware to be used when the task is executed.
+        if middleware is not None:
+            self.middleware = middleware
+        else:
+            if inspect.iscoroutinefunction(func):
+                self.middleware = middleware_module.default_async_task_middleware
+            else:
+                self.middleware = middleware_module.default_sync_task_middleware
 
     def add_namespace(self, namespace: str) -> None:
         """
diff --git a/procrastinate/worker.py b/procrastinate/worker.py
index f4227d9d8..d48c81ea9 100644
--- a/procrastinate/worker.py
+++ b/procrastinate/worker.py
@@ -14,6 +14,7 @@
     exceptions,
     job_context,
     jobs,
+    middleware,
     periodic,
     retry,
     signals,
@@ -45,6 +46,7 @@ def __init__(
         delete_jobs: str | jobs.DeleteJobCondition | None = None,
         additional_context: dict[str, Any] | None = None,
         install_signal_handlers: bool = True,
+        middleware: middleware.WorkerMiddleware = middleware.default_worker_middleware,
     ):
         self.app = app
         self.queues = queues
@@ -61,6 +63,7 @@ def __init__(
         ) or jobs.DeleteJobCondition.NEVER
         self.additional_context = additional_context
         self.install_signal_handlers = install_signal_handlers
+        self.middleware = middleware
 
         if self.worker_name:
             self.logger = logger.getChild(self.worker_name)
@@ -230,14 +233,32 @@ async def _process_job(self, context: job_context.JobContext):
         exc_info: bool | BaseException = False
 
         async def ensure_async() -> Callable[..., Awaitable]:
-            await_func: Callable[..., Awaitable]
+            job_args = [context] if task.pass_context else []
             if inspect.iscoroutinefunction(task.func):
-                await_func = task
+
+                async def run_task_async():
+                    return await task.func(*job_args, **job.task_kwargs)
+
+                wrapped_middleware = functools.partial(
+                    task.middleware,
+                    run_task_async,
+                    context,
+                    self,
+                )
             else:
-                await_func = functools.partial(utils.sync_to_async, task)
-            job_args = [context] if task.pass_context else []
-            task_result = await await_func(*job_args, **job.task_kwargs)
+
+                def run_task_sync():
+                    return task(*job_args, **job.task_kwargs)
+
+                wrapped_middleware = functools.partial(
+                    utils.sync_to_async,
+                    task.middleware,
+                    run_task_sync,
+                    context,
+                    self,
+                )
+
+            task_result = await wrapped_middleware()
             # In some cases, the task function might be a synchronous function
             # that returns an awaitable without actually being a
             # coroutinefunction. In that case, in the await above, we haven't
@@ -251,7 +272,7 @@ async def ensure_async() -> Callable[..., Awaitable]:
 
                 return task_result
 
-            job_result.result = await ensure_async()
+            job_result.result = await self.middleware(ensure_async, context, self)
         except BaseException as e:
             exc_info = e
diff --git a/tests/unit/test_middleware.py b/tests/unit/test_middleware.py
new file mode 100644
index 000000000..1248afa22
--- /dev/null
+++ b/tests/unit/test_middleware.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from procrastinate.app import App
+from procrastinate.job_context import JobContext
+from procrastinate.worker import Worker
+
+
+async def test_worker_middleware(app: App):
+    @app.task()
+    async def task_func():
+        return 42
+
+    await task_func.defer_async()
+
+    middleware_called = False
+
+    async def custom_worker_middleware(process_task, context, worker):
+        assert isinstance(context, JobContext)
+        assert isinstance(worker, Worker)
+        worker.stop()
+        result = await process_task()
+        assert result == 42
+        nonlocal middleware_called
+        middleware_called = True
+        return result
+
+    await app.run_worker_async(wait=True, middleware=custom_worker_middleware)
+
+    assert middleware_called
+
+
+async def test_sync_task_middleware(app: App):
+    middleware_called = False
+
+    def sync_task_middleware(process_task, context, worker):
+        assert isinstance(context, JobContext)
+        assert isinstance(worker, Worker)
+        worker.stop()
+        result = process_task()
+        assert result == 42
+        nonlocal middleware_called
+        middleware_called = True
+        return result
+
+    @app.task(middleware=sync_task_middleware)
+    def my_task(a):
+        return a
+
+    await my_task.defer_async(a=42)
+
+    await app.run_worker_async(wait=True)
+
+    assert middleware_called
+
+
+async def test_async_task_middleware(app: App):
+    middleware_called = False
+
+    async def async_task_middleware(process_task, context, worker):
+        assert isinstance(context, JobContext)
+        assert isinstance(worker, Worker)
+        worker.stop()
+        result = await process_task()
+        assert result == 42
+        nonlocal middleware_called
+        middleware_called = True
+        return result
+
+    @app.task(middleware=async_task_middleware)
+    async def my_task(a):
+        return a
+
+    await my_task.defer_async(a=42)
+
+    await app.run_worker_async(wait=True)
+
+    assert middleware_called
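+
+
+async def test_worker_and_task_middleware(app: App):
+    # Illustrative sketch: a worker middleware and a task middleware can be
+    # combined for the same job, with the worker middleware wrapping the task
+    # middleware. The test name and ordering assertion are assumptions based on
+    # the worker implementation above.
+    calls = []
+
+    async def task_middleware(process_task, context, worker):
+        calls.append("task")
+        return await process_task()
+
+    @app.task(middleware=task_middleware)
+    async def my_task():
+        return 42
+
+    await my_task.defer_async()
+
+    async def worker_middleware(process_task, context, worker):
+        calls.append("worker")
+        worker.stop()
+        result = await process_task()
+        assert result == 42
+        return result
+
+    await app.run_worker_async(wait=True, middleware=worker_middleware)
+
+    assert calls == ["worker", "task"]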