Skip to content

Commit d8f4468

Browse files
committed
Serialized, non-ordered handling of n messages
1 parent 4a4ab8c commit d8f4468

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed
+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import asyncio
2+
import logging
3+
from datetime import timedelta
4+
from typing import Optional
5+
6+
from temporalio import activity, common, workflow
7+
from temporalio.client import Client, WorkflowHandle
8+
from temporalio.worker import Worker
9+
10+
Arg = str
11+
Result = str
12+
13+
# Problem:
14+
# -------
15+
# - Your workflow receives an unbounded number of updates.
16+
# - Each update must be processed by calling two activities.
17+
# - The next update may not start processing until the previous has completed.
18+
19+
# Solution:
20+
# --------
21+
# Enqueue updates, and process items from the queue in a single coroutine (the main workflow
22+
# coroutine).
23+
24+
# Discussion:
25+
# ----------
26+
# The queue is used because Temporal's async update & signal handlers will interleave if they
27+
# contain multiple yield points. An alternative would be to use standard async handler functions,
28+
# with handling being done with an asyncio.Lock held. The queue approach would be necessary if we
29+
# need to process in an order other than arrival.
30+
31+
32+
class Queue(asyncio.Queue[tuple[Arg, asyncio.Future[Result]]]):
33+
def __init__(self, serialized_queue_state: list[Arg]) -> None:
34+
super().__init__()
35+
for arg in serialized_queue_state:
36+
self.put_nowait((arg, asyncio.Future()))
37+
38+
def serialize(self) -> list[Arg]:
39+
args = []
40+
while True:
41+
try:
42+
args.append(self.get_nowait())
43+
except asyncio.QueueEmpty:
44+
return args
45+
46+
47+
@workflow.defn
48+
class MessageProcessor:
49+
50+
@workflow.run
51+
async def run(self, serialized_queue_state: Optional[list[Arg]] = None):
52+
self.queue = Queue(serialized_queue_state or [])
53+
while True:
54+
arg, fut = await self.queue.get()
55+
fut.set_result(await self.process_task(arg))
56+
if workflow.info().is_continue_as_new_suggested():
57+
# Footgun: If we don't let the event loop tick, then CAN will end the workflow
58+
# before the update handler is notified that the result future has completed.
59+
# See https://github.com/temporalio/features/issues/481
60+
await asyncio.sleep(0) # Let update handler complete
61+
print("CAN")
62+
return workflow.continue_as_new(args=[self.queue.serialize()])
63+
64+
# Note: handler must be async if we are both enqueuing, and returning an update result
65+
# => We could add SDK APIs to manually complete updates.
66+
@workflow.update
67+
async def add_task(self, arg: Arg) -> Result:
68+
# Footgun: handler must wait for workflow initialization
69+
# See https://github.com/temporalio/features/issues/400
70+
await workflow.wait_condition(lambda: hasattr(self, "queue"))
71+
fut = asyncio.Future[Result]()
72+
self.queue.put_nowait((arg, fut)) # Note: update validation gates enqueue
73+
return await fut
74+
75+
async def process_task(self, arg):
76+
t1, t2 = [
77+
await workflow.execute_activity(
78+
get_current_time, start_to_close_timeout=timedelta(seconds=10)
79+
)
80+
for _ in range(2)
81+
]
82+
return f"{arg}-result-{t1}-{t2}"
83+
84+
85+
time = 0
86+
87+
88+
@activity.defn
89+
async def get_current_time() -> int:
90+
global time
91+
time += 1
92+
return time
93+
94+
95+
async def app(wf: WorkflowHandle):
96+
for i in range(20):
97+
print(f"app(): sending update {i}")
98+
result = await wf.execute_update(MessageProcessor.add_task, f"arg {i}")
99+
print(f"app(): {result}")
100+
101+
102+
async def main():
103+
client = await Client.connect("localhost:7233")
104+
105+
async with Worker(
106+
client,
107+
task_queue="tq",
108+
workflows=[MessageProcessor],
109+
activities=[get_current_time],
110+
):
111+
wf = await client.start_workflow(
112+
MessageProcessor.run,
113+
id="wid",
114+
task_queue="tq",
115+
id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING,
116+
)
117+
await asyncio.gather(app(wf), wf.result())
118+
119+
120+
if __name__ == "__main__":
121+
logging.basicConfig(level=logging.INFO)
122+
asyncio.run(main())

0 commit comments

Comments
 (0)