1
1
import asyncio
2
+ from collections import OrderedDict
2
3
from dataclasses import dataclass
3
4
from datetime import datetime , timedelta
5
+ from enum import Enum
4
6
import logging
5
- from typing import Optional
7
+ from typing import Awaitable , Callable , Optional
6
8
7
9
from temporalio import common , workflow , activity
8
10
from temporalio .client import Client , WorkflowHandle
@@ -29,22 +31,97 @@ class JobOutput:
29
31
stderr : str
30
32
31
33
34
+ class TaskStatus (Enum ):
35
+ BLOCKED = 1
36
+ UNBLOCKED = 2
37
+
38
+
39
+ @dataclass
40
+ class Task :
41
+ input : Job
42
+ handler : Callable [["JobRunner" , Job ], Awaitable [JobOutput ]]
43
+ status : TaskStatus = TaskStatus .BLOCKED
44
+ output : Optional [JobOutput ] = None
45
+
46
+ @property
47
+ def blocked (self ) -> bool :
48
+ return self .status == TaskStatus .BLOCKED
49
+
50
+
32
51
@workflow .defn
33
52
class JobRunner :
34
53
"""
35
54
Jobs must be executed in order dictated by job dependency graph (see `job.depends_on`) and
36
55
not before `job.after_time`.
37
56
"""
38
57
58
+ def __init__ (self ) -> None :
59
+ self .task_queue = OrderedDict [JobID , Task ]()
60
+ self .completed_tasks = set [JobID ]()
61
+
62
+ def all_handlers_completed (self ):
63
+ # We are considering adding an API like `all_handlers_completed` to SDKs. In this particular
64
+ # case, the user doesn't actually need the new API, since they are forced to track pending
65
+ # tasks in their queue implementation.
66
+ return not self .task_queue
67
+
68
+ # Note some undesirable things:
69
+ # 1. The update handler functions have become generic enqueuers; the "real" handler functions
70
+ # are some other methods that don't have the @workflow.update decorator.
71
+ # 2. The update handler functions have to store a reference to the real handler in the queue.
72
+ # 3. The workflow `run` method is *much* more complicated and bug-prone here, compared to
73
+ # I1:WaitUntilReadyToExecuteHandler
74
+
39
75
@workflow .run
40
76
async def run (self ):
41
- await workflow .wait_condition (
42
- lambda : workflow .info ().is_continue_as_new_suggested ()
43
- )
77
+ """
78
+ Process all tasks in the queue serially, in the main workflow coroutine.
79
+ """
80
+ # Note: there are many mistakes a user will make while trying to implement this workflow.
81
+ while not (
82
+ workflow .info ().is_continue_as_new_suggested ()
83
+ and self .all_handlers_completed ()
84
+ ):
85
+ await workflow .wait_condition (lambda : bool (self .task_queue ))
86
+ for id , task in list (self .task_queue .items ()):
87
+ if task .status == TaskStatus .UNBLOCKED :
88
+ await task .handler (self , task .input )
89
+ del self .task_queue [id ]
90
+ self .completed_tasks .add (id )
91
+ for id , task in self .task_queue .items ():
92
+ if task .status == TaskStatus .BLOCKED and self .ready_to_execute (
93
+ task .input
94
+ ):
95
+ task .status = TaskStatus .UNBLOCKED
44
96
workflow .continue_as_new ()
45
97
98
+ def ready_to_execute (self , job : Job ) -> bool :
99
+ if not set (job .depends_on ) <= self .completed_tasks :
100
+ return False
101
+ if after_time := job .after_time :
102
+ if float (after_time ) > workflow .now ().timestamp ():
103
+ return False
104
+ return True
105
+
106
+ async def _enqueue_job_and_wait_for_result (
107
+ self , job : Job , handler : Callable [["JobRunner" , Job ], Awaitable [JobOutput ]]
108
+ ) -> JobOutput :
109
+ task = Task (job , handler )
110
+ self .task_queue [job .id ] = task
111
+ await workflow .wait_condition (lambda : task .output is not None )
112
+ # Footgun: a user might well think that they can record task completion here, but in fact it
113
+ # deadlocks.
114
+ # self.completed_tasks.add(job.id)
115
+ assert task .output
116
+ return task .output
117
+
46
118
@workflow .update
47
119
async def run_shell_script_job (self , job : Job ) -> JobOutput :
120
+ return await self ._enqueue_job_and_wait_for_result (
121
+ job , JobRunner ._actually_run_shell_script_job
122
+ )
123
+
124
+ async def _actually_run_shell_script_job (self , job : Job ) -> JobOutput :
48
125
if security_errors := await workflow .execute_activity (
49
126
run_shell_script_security_linter ,
50
127
args = [job .run ],
@@ -58,6 +135,11 @@ async def run_shell_script_job(self, job: Job) -> JobOutput:
58
135
59
136
@workflow .update
60
137
async def run_python_job (self , job : Job ) -> JobOutput :
138
+ return await self ._enqueue_job_and_wait_for_result (
139
+ job , JobRunner ._actually_run_python_job
140
+ )
141
+
142
+ async def _actually_run_python_job (self , job : Job ) -> JobOutput :
61
143
if not await workflow .execute_activity (
62
144
check_python_interpreter_version ,
63
145
args = [job .python_interpreter_version ],
0 commit comments