Skip to content

Commit 5ac5708

Browse files
committed
feat(worker): Report task runtime state to API
This allows for showing status of tasks and debugging tasks more easily in production.
1 parent b3ff9f4 commit 5ac5708

File tree

3 files changed

+186
-2
lines changed

3 files changed

+186
-2
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import WorkerTask from "../../models/worker-task"
2+
import { checkWorker } from "../permissions"
3+
4+
/**
 * Create or update the stored record for a worker task.
 *
 * The record is looked up by `args.id`; every other non-null argument is
 * written with `$set`. Because the update is an upsert, the same mutation
 * serves both the first report for a task and any later reports.
 *
 * `checkWorker(userInfo)` runs first — presumably it rejects callers without
 * worker permissions (defined in ../permissions; confirm there).
 *
 * @returns The task document as it exists after the update.
 */
export const updateWorkerTask = async (obj, args, { userInfo }) => {
  checkWorker(userInfo)

  const { id, ...fields } = args

  // A null/undefined argument means "not provided": it must never clear a
  // value that an earlier report already stored.
  const isProvided = ([, value]: [string, unknown]): boolean =>
    value !== null && value !== undefined
  const changes = Object.fromEntries(Object.entries(fields).filter(isProvided))

  // new: return the post-update document; upsert: insert when no record matches
  const options = { new: true, upsert: true }
  return await WorkerTask.findOneAndUpdate(
    { id },
    { $set: changes },
    options,
  ).exec()
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import mongoose from "mongoose"
2+
import type { Document } from "mongoose"
3+
const { Schema, model } = mongoose
4+
5+
/**
 * Document shape for one worker task status record.
 *
 * Only `id` is required; the remaining fields are filled in incrementally as
 * the worker reports the task being queued, started and finished.
 */
export interface WorkerTaskDocument extends Document {
  /** Task identifier reported by the worker; one record per id. */
  id: string
  /**
   * Positional task arguments. The worker forwards the task message's `args`,
   * which is a list, so this is an array rather than a keyed record.
   */
  args?: unknown[]
  /** Keyword task arguments, keyed by parameter name. */
  kwargs?: Record<string, unknown>
  /** Name of the task that was scheduled. */
  taskName?: string
  /** Queue label the task was routed to (the worker sends its `queue_name` label). */
  worker?: string
  /** When the task was sent to the broker. */
  queuedAt?: Date
  /** When a worker began executing the task. */
  startedAt?: Date
  /** When execution finished. */
  finishedAt?: Date
  /** String form of the error raised by the task, if any. */
  error?: string
  /** Execution time as reported by the worker, rounded to whole milliseconds. */
  executionTime?: number
}
17+
18+
// Persistence schema for worker task status records.
const workerTaskSchema = new Schema({
  // `unique: true` already makes Mongoose build a (unique) index on `id`.
  // Do not add a separate `schema.index({ id: 1 })`: it declares a second
  // index on the same key, which Mongoose reports as a duplicate schema index
  // and which conflicts with the unique index's options at build time.
  id: { type: String, required: true, unique: true },
  // Stored as Mixed: task arguments are arbitrary JSON supplied by the worker
  args: { type: Object },
  kwargs: { type: Object },
  taskName: { type: String },
  worker: { type: String },
  queuedAt: { type: Date },
  startedAt: { type: Date },
  finishedAt: { type: Date },
  error: { type: String },
  executionTime: { type: Number },
})
32+
33+
// Model bound to the "WorkerTask" name, typed with WorkerTaskDocument so
// query results expose the task fields.
const WorkerTask = model<WorkerTaskDocument>("WorkerTask", workerTaskSchema)

export default WorkerTask
Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,84 @@
11
import random
2+
from datetime import datetime, timedelta, timezone, UTC
3+
from typing import Any
4+
import httpx
5+
import logging
6+
import jwt
7+
8+
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
29

3-
from taskiq import TaskiqMessage, TaskiqMiddleware
410

511
from datalad_service import config
612

13+
# Module logger; its name is this module's __name__ prefixed with 'datalad_service.'.
# NOTE(review): this module imports `config` from the datalad_service package, so
# __name__ may already start with 'datalad_service.' — confirm the doubled prefix
# is intended.
logger = logging.getLogger('datalad_service.' + __name__)
14+
15+
_UPDATE_WORKER_TASK_MUTATION = """
16+
mutation UpdateWorkerTask(
17+
$id: ID!
18+
$args: JSON
19+
$kwargs: JSON
20+
$taskName: String
21+
$worker: String
22+
$queuedAt: DateTime
23+
$startedAt: DateTime
24+
$finishedAt: DateTime
25+
$error: String
26+
$executionTime: Int
27+
) {
28+
updateWorkerTask(
29+
id: $id
30+
args: $args
31+
kwargs: $kwargs
32+
taskName: $taskName
33+
worker: $worker
34+
queuedAt: $queuedAt
35+
startedAt: $startedAt
36+
finishedAt: $finishedAt
37+
error: $error
38+
executionTime: $executionTime
39+
) {
40+
id
41+
}
42+
}
43+
"""
44+
45+
46+
def _update_worker_task_body(**kwargs):
    """Build the JSON request body for an updateWorkerTask GraphQL call.

    The keyword arguments are passed through unchanged as the GraphQL
    variables; the query text and operation name are fixed.
    """
    return dict(
        query=_UPDATE_WORKER_TASK_MUTATION,
        variables=kwargs,
        operationName='UpdateWorkerTask',
    )
53+
54+
55+
def generate_worker_token():
    """Create a signed JWT identifying this process as a dataset worker.

    The token is signed with config.JWT_SECRET using HS256, carries the
    'dataset:worker' scope, and expires 24 hours after it is issued.
    """
    issued = datetime.now(timezone.utc)
    expires = issued + timedelta(days=1)
    claims = {
        'sub': 'dataset-worker',
        # Claims are written as whole-second Unix timestamps
        'iat': int(issued.timestamp()),
        'exp': int(expires.timestamp()),
        'scopes': ['dataset:worker'],
    }
    return jwt.encode(claims, config.JWT_SECRET, algorithm='HS256')
68+
769

870
class WorkerMiddleware(TaskiqMiddleware):
971
"""
10-
This middleware adds a custom worker label to outgoing tasks scheduled by workers.
72+
This middleware adds a custom worker label to outgoing tasks scheduled by workers
73+
and reports task status back to the OpenNeuro API.
1174
"""
1275

1376
def __init__(self, worker_id=None):
    """Set up the middleware for one process.

    worker_id: identifier of the worker this process runs as, or None when
        no specific worker is assigned.
    """
    # The API token is generated once, here, and reused for every status
    # update this instance sends.
    # NOTE(review): generate_worker_token() issues a token with a fixed
    # expiry — confirm what happens to updates from a process that outlives it.
    self.api_token = generate_worker_token()
    self.worker_id = worker_id
1579

1680
async def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
81+
"""Assign new tasks to the correct worker."""
1782
if self.worker_id:
1883
message.labels['queue_name'] = f'worker-{self.worker_id}'
1984
else:
@@ -22,3 +87,61 @@ async def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
2287
f'worker-{random.randint(0, config.DATALAD_WORKERS - 1)}'
2388
)
2489
return message
90+
91+
async def _update_task_status(self, **kwargs):
    """Send one updateWorkerTask mutation to the GraphQL API (best effort).

    kwargs are forwarded as the mutation variables (id, startedAt, ...).
    Failures are logged and never raised, so a reporting problem cannot
    break sending or executing the task itself.
    """
    if not self.api_token:
        # The token is generated by generate_worker_token(), not read from
        # an environment variable, so report it in those terms.
        logger.warning('Worker API token not available, cannot update task status.')
        return

    body = _update_worker_task_body(**kwargs)
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                url=config.GRAPHQL_ENDPOINT,
                json=body,
                headers={'Authorization': f'Bearer {self.api_token}'},
            )
            response.raise_for_status()
            response_json = response.json()
            # GraphQL reports resolver failures in the body of a 2xx response
            if 'errors' in response_json:
                logger.error(
                    f'GraphQL error updating task status: {response_json["errors"]}',
                )
    except httpx.HTTPStatusError as e:
        # Only status errors carry a response to include in the log
        logger.error(f'HTTP error updating task status: {e}')
        logger.error(f'Response: {e.response.text}')
    except httpx.HTTPError as e:
        # Transport-level failures (connect, timeout, ...) have no response;
        # reading e.response here would raise and escape this method.
        logger.error(f'HTTP error updating task status: {e}')
    except Exception as e:
        logger.error(f'Unexpected error updating task status: {e}')
116+
117+
async def post_send(self, message: TaskiqMessage):
    """Report a newly queued task once it has been sent to the broker.

    Sends the task's id, arguments, name and queue label together with the
    current UTC time as its queued timestamp.
    """
    queued_at = datetime.now(UTC).isoformat()
    report = {
        'id': message.task_id,
        'args': message.args,
        'kwargs': message.kwargs,
        'taskName': message.task_name,
        # None when no queue_name label has been set on the message
        'worker': message.labels.get('queue_name'),
        'queuedAt': queued_at,
    }
    await self._update_task_status(**report)
128+
129+
async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
    """Record the start time of a task just before the worker runs it.

    Returns the message unchanged.
    """
    started_at = datetime.now(UTC).isoformat()
    status = {'id': message.task_id, 'startedAt': started_at}
    await self._update_task_status(**status)
    return message
134+
135+
async def post_execute(
    self,
    message: TaskiqMessage,
    result: TaskiqResult[Any],
):
    """Record the outcome of a task after the worker has executed it.

    Sends the finish time, the repr of the task's error (None when it
    succeeded) and the execution time scaled by 1000 and rounded to an integer.
    """
    finished_at = datetime.now(UTC).isoformat()

    if result.error is None:
        error_text = None
    else:
        error_text = repr(result.error)

    elapsed = round(result.execution_time * 1000)

    await self._update_task_status(
        id=message.task_id,
        finishedAt=finished_at,
        error=error_text,
        executionTime=elapsed,
    )

0 commit comments

Comments
 (0)