11import random
2+ from datetime import datetime , timedelta , timezone , UTC
3+ from typing import Any
4+ import httpx
5+ import logging
6+ import jwt
7+
8+ from taskiq import TaskiqMessage , TaskiqMiddleware , TaskiqResult
29
3- from taskiq import TaskiqMessage , TaskiqMiddleware
410
511from datalad_service import config
612
13+ logger = logging .getLogger ('datalad_service.' + __name__ )
14+
15+ _UPDATE_WORKER_TASK_MUTATION = """
16+ mutation UpdateWorkerTask(
17+ $id: ID!
18+ $args: JSON
19+ $kwargs: JSON
20+ $taskName: String
21+ $worker: String
22+ $queuedAt: DateTime
23+ $startedAt: DateTime
24+ $finishedAt: DateTime
25+ $error: String
26+ $executionTime: Int
27+ ) {
28+ updateWorkerTask(
29+ id: $id
30+ args: $args
31+ kwargs: $kwargs
32+ taskName: $taskName
33+ worker: $worker
34+ queuedAt: $queuedAt
35+ startedAt: $startedAt
36+ finishedAt: $finishedAt
37+ error: $error
38+ executionTime: $executionTime
39+ ) {
40+ id
41+ }
42+ }
43+ """
44+
45+
46+ def _update_worker_task_body (** kwargs ):
47+ """Create a GraphQL mutation body for updateWorkerTask."""
48+ return {
49+ 'query' : _UPDATE_WORKER_TASK_MUTATION ,
50+ 'variables' : kwargs ,
51+ 'operationName' : 'UpdateWorkerTask' ,
52+ }
53+
54+
55+ def generate_worker_token ():
56+ utc_now = datetime .now (timezone .utc )
57+ one_day_ahead = utc_now + timedelta (hours = 24 )
58+ return jwt .encode (
59+ {
60+ 'sub' : 'dataset-worker' ,
61+ 'iat' : int (utc_now .timestamp ()),
62+ 'exp' : int (one_day_ahead .timestamp ()),
63+ 'scopes' : ['dataset:worker' ],
64+ },
65+ config .JWT_SECRET ,
66+ algorithm = 'HS256' ,
67+ )
68+
769
870class WorkerMiddleware (TaskiqMiddleware ):
971 """
10- This middleware adds a custom worker label to outgoing tasks scheduled by workers.
72+ This middleware adds a custom worker label to outgoing tasks scheduled by workers
73+ and reports task status back to the OpenNeuro API.
1174 """
1275
1376 def __init__ (self , worker_id = None ):
1477 self .worker_id = worker_id
78+ self .api_token = generate_worker_token ()
1579
1680 async def pre_send (self , message : TaskiqMessage ) -> TaskiqMessage :
81+ """Assign new tasks to the correct worker."""
1782 if self .worker_id :
1883 message .labels ['queue_name' ] = f'worker-{ self .worker_id } '
1984 else :
@@ -22,3 +87,61 @@ async def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
2287 f'worker-{ random .randint (0 , config .DATALAD_WORKERS - 1 )} '
2388 )
2489 return message
90+
91+ async def _update_task_status (self , ** kwargs ):
92+ """Helper to send updates to the GraphQL API."""
93+ if not self .api_token :
94+ logger .warning ('DATALAD_WORKER_TOKEN not set, cannot update task status.' )
95+ return
96+
97+ body = _update_worker_task_body (** kwargs )
98+ try :
99+ async with httpx .AsyncClient () as client :
100+ response = await client .post (
101+ url = config .GRAPHQL_ENDPOINT ,
102+ json = body ,
103+ headers = {'Authorization' : f'Bearer { self .api_token } ' },
104+ )
105+ response .raise_for_status ()
106+ response_json = response .json ()
107+ if 'errors' in response_json :
108+ logger .error (
109+ f'GraphQL error updating task status: { response_json ["errors" ]} ' ,
110+ )
111+ except httpx .HTTPError as e :
112+ logger .error (f'HTTP error updating task status: { e } ' )
113+ logger .error (f'Response: { e .response .text } ' )
114+ except Exception as e :
115+ logger .error (f'Unexpected error updating task status: { e } ' )
116+
117+ async def post_send (self , message : TaskiqMessage ):
118+ """Called after a task is sent to the broker."""
119+ now = datetime .now (UTC ).isoformat ()
120+ await self ._update_task_status (
121+ id = message .task_id ,
122+ args = message .args ,
123+ kwargs = message .kwargs ,
124+ taskName = message .task_name ,
125+ worker = message .labels .get ('queue_name' ),
126+ queuedAt = now ,
127+ )
128+
129+ async def pre_execute (self , message : TaskiqMessage ) -> TaskiqMessage :
130+ """Called before a worker executes a task."""
131+ now = datetime .now (UTC ).isoformat ()
132+ await self ._update_task_status (id = message .task_id , startedAt = now )
133+ return message
134+
135+ async def post_execute (
136+ self ,
137+ message : TaskiqMessage ,
138+ result : TaskiqResult [Any ],
139+ ):
140+ """Called after a worker executes a task."""
141+ now = datetime .now (UTC ).isoformat ()
142+ await self ._update_task_status (
143+ id = message .task_id ,
144+ finishedAt = now ,
145+ error = None if result .error is None else repr (result .error ),
146+ executionTime = round (result .execution_time * 1000 ),
147+ )
0 commit comments