Skip to content

Commit a7c04d4

Browse files
authored
Context propagation sample (#120)
1 parent e299047 commit a7c04d4

File tree

11 files changed

+366
-0
lines changed

11 files changed

+366
-0
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Some examples require extra dependencies. See each sample's directory for specif
5353
<!-- Keep this list in alphabetical order -->
5454
* [activity_worker](activity_worker) - Use Python activities from a workflow in another language.
5555
* [cloud_export_to_parquet](cloud_export_to_parquet) - Set up schedule workflow to process exported files on an hourly basis
56+
* [context_propagation](context_propagation) - Context propagation through workflows/activities via interceptor.
5657
* [custom_converter](custom_converter) - Use a custom payload converter to handle custom types.
5758
* [custom_decorator](custom_decorator) - Custom decorator to auto-heartbeat a long-running activity.
5859
* [dsl](dsl) - DSL workflow that executes steps defined in a YAML file.

Diff for: context_propagation/README.md

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Context Propagation Interceptor Sample
2+
3+
This sample shows how to use an interceptor to propagate contextual information through workflows and activities. For
4+
this example, [contextvars](https://docs.python.org/3/library/contextvars.html) holds the contextual information.
5+
6+
To run, first see [README.md](../README.md) for prerequisites. Then, run the following from this directory to start the
7+
worker:
8+
9+
poetry run python worker.py
10+
11+
This will start the worker. Then, in another terminal, run the following to execute the workflow:
12+
13+
poetry run python starter.py
14+
15+
The starter terminal should complete with the hello result and the worker terminal should show the logs with the
16+
propagated user ID contextual information flowing through the workflows/activities.

Diff for: context_propagation/__init__.py

Whitespace-only changes.

Diff for: context_propagation/activities.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from temporalio import activity
2+
3+
from context_propagation import shared
4+
5+
6+
@activity.defn
7+
async def say_hello_activity(name: str) -> str:
8+
activity.logger.info(f"Activity called by user {shared.user_id.get()}")
9+
return f"Hello, {name}"

Diff for: context_propagation/interceptor.py

+184
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from __future__ import annotations
2+
3+
from contextlib import contextmanager
4+
from typing import Any, Mapping, Protocol, Type
5+
6+
import temporalio.activity
7+
import temporalio.api.common.v1
8+
import temporalio.client
9+
import temporalio.converter
10+
import temporalio.worker
11+
import temporalio.workflow
12+
13+
with temporalio.workflow.unsafe.imports_passed_through():
14+
from context_propagation.shared import HEADER_KEY, user_id
15+
16+
17+
class _InputWithHeaders(Protocol):
18+
headers: Mapping[str, temporalio.api.common.v1.Payload]
19+
20+
21+
def set_header_from_context(
22+
input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
23+
) -> None:
24+
user_id_val = user_id.get()
25+
if user_id_val:
26+
input.headers = {
27+
**input.headers,
28+
HEADER_KEY: payload_converter.to_payload(user_id_val),
29+
}
30+
31+
32+
@contextmanager
33+
def context_from_header(
34+
input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter
35+
):
36+
payload = input.headers.get(HEADER_KEY)
37+
token = (
38+
user_id.set(payload_converter.from_payload(payload, str)) if payload else None
39+
)
40+
try:
41+
yield
42+
finally:
43+
if token:
44+
user_id.reset(token)
45+
46+
47+
class ContextPropagationInterceptor(
48+
temporalio.client.Interceptor, temporalio.worker.Interceptor
49+
):
50+
"""Interceptor that can serialize/deserialize contexts."""
51+
52+
def __init__(
53+
self,
54+
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
55+
) -> None:
56+
self._payload_converter = payload_converter
57+
58+
def intercept_client(
59+
self, next: temporalio.client.OutboundInterceptor
60+
) -> temporalio.client.OutboundInterceptor:
61+
return _ContextPropagationClientOutboundInterceptor(
62+
next, self._payload_converter
63+
)
64+
65+
def intercept_activity(
66+
self, next: temporalio.worker.ActivityInboundInterceptor
67+
) -> temporalio.worker.ActivityInboundInterceptor:
68+
return _ContextPropagationActivityInboundInterceptor(next)
69+
70+
def workflow_interceptor_class(
71+
self, input: temporalio.worker.WorkflowInterceptorClassInput
72+
) -> Type[_ContextPropagationWorkflowInboundInterceptor]:
73+
return _ContextPropagationWorkflowInboundInterceptor
74+
75+
76+
class _ContextPropagationClientOutboundInterceptor(
77+
temporalio.client.OutboundInterceptor
78+
):
79+
def __init__(
80+
self,
81+
next: temporalio.client.OutboundInterceptor,
82+
payload_converter: temporalio.converter.PayloadConverter,
83+
) -> None:
84+
super().__init__(next)
85+
self._payload_converter = payload_converter
86+
87+
async def start_workflow(
88+
self, input: temporalio.client.StartWorkflowInput
89+
) -> temporalio.client.WorkflowHandle[Any, Any]:
90+
set_header_from_context(input, self._payload_converter)
91+
return await super().start_workflow(input)
92+
93+
async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any:
94+
set_header_from_context(input, self._payload_converter)
95+
return await super().query_workflow(input)
96+
97+
async def signal_workflow(
98+
self, input: temporalio.client.SignalWorkflowInput
99+
) -> None:
100+
set_header_from_context(input, self._payload_converter)
101+
await super().signal_workflow(input)
102+
103+
async def start_workflow_update(
104+
self, input: temporalio.client.StartWorkflowUpdateInput
105+
) -> temporalio.client.WorkflowUpdateHandle[Any]:
106+
set_header_from_context(input, self._payload_converter)
107+
return await self.next.start_workflow_update(input)
108+
109+
110+
class _ContextPropagationActivityInboundInterceptor(
111+
temporalio.worker.ActivityInboundInterceptor
112+
):
113+
async def execute_activity(
114+
self, input: temporalio.worker.ExecuteActivityInput
115+
) -> Any:
116+
with context_from_header(input, temporalio.activity.payload_converter()):
117+
return await self.next.execute_activity(input)
118+
119+
120+
class _ContextPropagationWorkflowInboundInterceptor(
121+
temporalio.worker.WorkflowInboundInterceptor
122+
):
123+
def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
124+
self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound))
125+
126+
async def execute_workflow(
127+
self, input: temporalio.worker.ExecuteWorkflowInput
128+
) -> Any:
129+
with context_from_header(input, temporalio.workflow.payload_converter()):
130+
return await self.next.execute_workflow(input)
131+
132+
async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
133+
with context_from_header(input, temporalio.workflow.payload_converter()):
134+
return await self.next.handle_signal(input)
135+
136+
async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
137+
with context_from_header(input, temporalio.workflow.payload_converter()):
138+
return await self.next.handle_query(input)
139+
140+
def handle_update_validator(
141+
self, input: temporalio.worker.HandleUpdateInput
142+
) -> None:
143+
with context_from_header(input, temporalio.workflow.payload_converter()):
144+
self.next.handle_update_validator(input)
145+
146+
async def handle_update_handler(
147+
self, input: temporalio.worker.HandleUpdateInput
148+
) -> Any:
149+
with context_from_header(input, temporalio.workflow.payload_converter()):
150+
return await self.next.handle_update_handler(input)
151+
152+
153+
class _ContextPropagationWorkflowOutboundInterceptor(
154+
temporalio.worker.WorkflowOutboundInterceptor
155+
):
156+
async def signal_child_workflow(
157+
self, input: temporalio.worker.SignalChildWorkflowInput
158+
) -> None:
159+
set_header_from_context(input, temporalio.workflow.payload_converter())
160+
return await self.next.signal_child_workflow(input)
161+
162+
async def signal_external_workflow(
163+
self, input: temporalio.worker.SignalExternalWorkflowInput
164+
) -> None:
165+
set_header_from_context(input, temporalio.workflow.payload_converter())
166+
return await self.next.signal_external_workflow(input)
167+
168+
def start_activity(
169+
self, input: temporalio.worker.StartActivityInput
170+
) -> temporalio.workflow.ActivityHandle:
171+
set_header_from_context(input, temporalio.workflow.payload_converter())
172+
return self.next.start_activity(input)
173+
174+
async def start_child_workflow(
175+
self, input: temporalio.worker.StartChildWorkflowInput
176+
) -> temporalio.workflow.ChildWorkflowHandle:
177+
set_header_from_context(input, temporalio.workflow.payload_converter())
178+
return await self.next.start_child_workflow(input)
179+
180+
def start_local_activity(
181+
self, input: temporalio.worker.StartLocalActivityInput
182+
) -> temporalio.workflow.ActivityHandle:
183+
set_header_from_context(input, temporalio.workflow.payload_converter())
184+
return self.next.start_local_activity(input)

Diff for: context_propagation/shared.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from contextvars import ContextVar
2+
from typing import Optional
3+
4+
HEADER_KEY = "__my_user_id"
5+
6+
user_id: ContextVar[Optional[str]] = ContextVar("user_id", default=None)

Diff for: context_propagation/starter.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import asyncio
2+
import logging
3+
4+
from temporalio.client import Client
5+
6+
from context_propagation import interceptor, shared, workflows
7+
8+
9+
async def main():
10+
logging.basicConfig(level=logging.INFO)
11+
12+
# Set the user ID
13+
shared.user_id.set("some-user")
14+
15+
# Connect client
16+
client = await Client.connect(
17+
"localhost:7233",
18+
# Use our interceptor
19+
interceptors=[interceptor.ContextPropagationInterceptor()],
20+
)
21+
22+
# Start workflow, send signal, wait for completion, issue query
23+
handle = await client.start_workflow(
24+
workflows.SayHelloWorkflow.run,
25+
"Temporal",
26+
id=f"context-propagation-workflow-id",
27+
task_queue="context-propagation-task-queue",
28+
)
29+
await handle.signal(workflows.SayHelloWorkflow.signal_complete)
30+
result = await handle.result()
31+
logging.info(f"Workflow result: {result}")
32+
33+
34+
if __name__ == "__main__":
35+
asyncio.run(main())

Diff for: context_propagation/worker.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import asyncio
2+
import logging
3+
4+
from temporalio.client import Client
5+
from temporalio.worker import Worker
6+
7+
from context_propagation import activities, interceptor, workflows
8+
9+
interrupt_event = asyncio.Event()
10+
11+
12+
async def main():
13+
logging.basicConfig(level=logging.INFO)
14+
15+
# Connect client
16+
client = await Client.connect(
17+
"localhost:7233",
18+
# Use our interceptor
19+
interceptors=[interceptor.ContextPropagationInterceptor()],
20+
)
21+
22+
# Run a worker for the workflow
23+
async with Worker(
24+
client,
25+
task_queue="context-propagation-task-queue",
26+
activities=[activities.say_hello_activity],
27+
workflows=[workflows.SayHelloWorkflow],
28+
):
29+
# Wait until interrupted
30+
logging.info("Worker started, ctrl+c to exit")
31+
await interrupt_event.wait()
32+
logging.info("Shutting down")
33+
34+
35+
if __name__ == "__main__":
36+
loop = asyncio.new_event_loop()
37+
try:
38+
loop.run_until_complete(main())
39+
except KeyboardInterrupt:
40+
interrupt_event.set()
41+
loop.run_until_complete(loop.shutdown_asyncgens())

Diff for: context_propagation/workflows.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from datetime import timedelta
2+
3+
from temporalio import workflow
4+
5+
with workflow.unsafe.imports_passed_through():
6+
from context_propagation.activities import say_hello_activity
7+
from context_propagation.shared import user_id
8+
9+
10+
@workflow.defn
11+
class SayHelloWorkflow:
12+
def __init__(self) -> None:
13+
self._complete = False
14+
15+
@workflow.run
16+
async def run(self, name: str) -> str:
17+
workflow.logger.info(f"Workflow called by user {user_id.get()}")
18+
19+
# Wait for signal then run activity
20+
await workflow.wait_condition(lambda: self._complete)
21+
return await workflow.execute_activity(
22+
say_hello_activity, name, start_to_close_timeout=timedelta(minutes=5)
23+
)
24+
25+
@workflow.signal
26+
async def signal_complete(self) -> None:
27+
workflow.logger.info(f"Signal called by user {user_id.get()}")
28+
self._complete = True

Diff for: tests/context_propagation/__init__.py

Whitespace-only changes.

Diff for: tests/context_propagation/workflow_test.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import uuid
2+
3+
from temporalio import activity
4+
from temporalio.client import Client
5+
from temporalio.exceptions import ApplicationError
6+
from temporalio.worker import Worker
7+
8+
from context_propagation.interceptor import ContextPropagationInterceptor
9+
from context_propagation.shared import user_id
10+
from context_propagation.workflows import SayHelloWorkflow
11+
12+
13+
async def test_workflow_with_context_propagator(client: Client):
14+
# Mock out the activity to assert the context value
15+
@activity.defn(name="say_hello_activity")
16+
async def say_hello_activity_mock(name: str) -> str:
17+
try:
18+
assert user_id.get() == "test-user"
19+
except Exception as err:
20+
raise ApplicationError("Assertion fail", non_retryable=True) from err
21+
return f"Mock for {name}"
22+
23+
# Replace interceptors in client
24+
new_config = client.config()
25+
new_config["interceptors"] = [ContextPropagationInterceptor()]
26+
client = Client(**new_config)
27+
task_queue = f"tq-{uuid.uuid4()}"
28+
29+
async with Worker(
30+
client,
31+
task_queue=task_queue,
32+
activities=[say_hello_activity_mock],
33+
workflows=[SayHelloWorkflow],
34+
):
35+
# Set the user during start/signal, but unset after
36+
token = user_id.set("test-user")
37+
handle = await client.start_workflow(
38+
SayHelloWorkflow.run,
39+
"some-name",
40+
id=f"wf-{uuid.uuid4()}",
41+
task_queue=task_queue,
42+
)
43+
await handle.signal(SayHelloWorkflow.signal_complete)
44+
user_id.reset(token)
45+
result = await handle.result()
46+
assert result == "Mock for some-name"

0 commit comments

Comments
 (0)