Skip to content

Commit c6306bb

Browse files
authored
Merge branch 'main' into heartbeat
2 parents a7e576c + bb59501 commit c6306bb

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

cadence/_internal/workflow/context.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
77
from cadence.api.v1.common_pb2 import ActivityType
8-
from cadence.api.v1.decision_pb2 import ScheduleActivityTaskDecisionAttributes
8+
from cadence.api.v1.decision_pb2 import (
9+
ScheduleActivityTaskDecisionAttributes,
10+
StartTimerDecisionAttributes,
11+
)
912
from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind
1013
from cadence.data_converter import DataConverter
1114
from cadence.workflow import WorkflowContext, WorkflowInfo, ResultType, ActivityOptions
@@ -97,6 +100,15 @@ async def execute_activity(
97100

98101
return cast(ResultType, result)
99102

103+
async def start_timer(self, duration: timedelta):
104+
if duration.total_seconds() <= 0: # shortcut
105+
return
106+
await self._decision_manager.start_timer(
107+
StartTimerDecisionAttributes(
108+
start_to_fire_timeout=duration,
109+
)
110+
)
111+
100112
def set_replay_mode(self, replay: bool) -> None:
101113
"""Set whether the workflow is currently in replay mode."""
102114
self._replay_mode = replay

cadence/workflow.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ async def execute_activity(
4444
)
4545

4646

47+
async def sleep(duration: timedelta) -> None:
48+
return await WorkflowContext.get().start_timer(duration)
49+
50+
4751
T = TypeVar("T", bound=Callable[..., Any])
4852
C = TypeVar("C")
4953

@@ -277,6 +281,9 @@ async def execute_activity(
277281
**kwargs: Unpack[ActivityOptions],
278282
) -> ResultType: ...
279283

284+
@abstractmethod
285+
async def start_timer(self, duration: timedelta) -> None: ...
286+
280287
@contextmanager
281288
def _activate(self) -> Iterator["WorkflowContext"]:
282289
token = WorkflowContext._var.set(self)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import asyncio
2+
from datetime import timedelta
3+
4+
from cadence import Registry, workflow
5+
from cadence.api.v1.history_pb2 import EventFilterType
6+
from cadence.api.v1.service_workflow_pb2 import (
7+
GetWorkflowExecutionHistoryRequest,
8+
GetWorkflowExecutionHistoryResponse,
9+
)
10+
from tests.integration_tests.helper import CadenceHelper
11+
12+
13+
registry = Registry()
14+
15+
16+
@registry.activity()
17+
async def echo(message: str) -> str:
18+
return message
19+
20+
21+
@registry.workflow()
22+
class TimerWorkflow:
23+
@workflow.run
24+
async def run(self) -> str:
25+
await workflow.sleep(timedelta(seconds=1))
26+
await echo("hello")
27+
return "hello"
28+
29+
30+
@registry.workflow()
31+
class TimerCancelWorkflow:
32+
@workflow.run
33+
async def run(self) -> str:
34+
task = asyncio.create_task(workflow.sleep(timedelta(seconds=1)))
35+
await echo("hello")
36+
task.cancel()
37+
return "hello"
38+
39+
40+
async def test_timer(helper: CadenceHelper):
41+
async with helper.worker(registry) as worker:
42+
execution = await worker.client.start_workflow(
43+
"TimerWorkflow",
44+
task_list=worker.task_list,
45+
execution_start_to_close_timeout=timedelta(seconds=10),
46+
)
47+
48+
# wait for close event
49+
await worker.client.workflow_stub.GetWorkflowExecutionHistory(
50+
GetWorkflowExecutionHistoryRequest(
51+
domain=worker.client.domain,
52+
workflow_execution=execution,
53+
wait_for_new_event=True,
54+
history_event_filter_type=EventFilterType.EVENT_FILTER_TYPE_CLOSE_EVENT,
55+
skip_archival=True,
56+
)
57+
)
58+
59+
response: GetWorkflowExecutionHistoryResponse = (
60+
await worker.client.workflow_stub.GetWorkflowExecutionHistory(
61+
GetWorkflowExecutionHistoryRequest(
62+
domain=worker.client.domain,
63+
workflow_execution=execution,
64+
)
65+
)
66+
)
67+
68+
events = response.history.events
69+
70+
timer_started_events = [
71+
e for e in events if e.HasField("timer_started_event_attributes")
72+
]
73+
timer_fired_events = [
74+
e for e in events if e.HasField("timer_fired_event_attributes")
75+
]
76+
assert len(timer_started_events) == 1
77+
assert len(timer_fired_events) == 1
78+
79+
activity_scheduled_events = [
80+
e for e in events if e.HasField("activity_task_scheduled_event_attributes")
81+
]
82+
assert len(activity_scheduled_events) == 1
83+
84+
timer_started_time = timer_started_events[0].event_time.ToDatetime()
85+
activity_scheduled_time = activity_scheduled_events[0].event_time.ToDatetime()
86+
assert activity_scheduled_time >= timer_started_time + timedelta(seconds=1)
87+
88+
89+
async def test_timer_cancel(helper: CadenceHelper):
90+
async with helper.worker(registry) as worker:
91+
execution = await worker.client.start_workflow(
92+
"TimerCancelWorkflow",
93+
task_list=worker.task_list,
94+
execution_start_to_close_timeout=timedelta(seconds=10),
95+
)
96+
97+
# wait for close event
98+
await worker.client.workflow_stub.GetWorkflowExecutionHistory(
99+
GetWorkflowExecutionHistoryRequest(
100+
domain=worker.client.domain,
101+
workflow_execution=execution,
102+
wait_for_new_event=True,
103+
history_event_filter_type=EventFilterType.EVENT_FILTER_TYPE_CLOSE_EVENT,
104+
skip_archival=True,
105+
)
106+
)
107+
108+
response: GetWorkflowExecutionHistoryResponse = (
109+
await worker.client.workflow_stub.GetWorkflowExecutionHistory(
110+
GetWorkflowExecutionHistoryRequest(
111+
domain=worker.client.domain,
112+
workflow_execution=execution,
113+
)
114+
)
115+
)
116+
117+
events = response.history.events
118+
119+
timer_started_events = [
120+
e for e in events if e.HasField("timer_started_event_attributes")
121+
]
122+
timer_canceled_events = [
123+
e for e in events if e.HasField("timer_canceled_event_attributes")
124+
]
125+
assert len(timer_started_events) == 1
126+
assert len(timer_canceled_events) == 1
127+
128+
activity_scheduled_events = [
129+
e for e in events if e.HasField("activity_task_scheduled_event_attributes")
130+
]
131+
assert len(activity_scheduled_events) == 1
132+
133+
timer_started_time = timer_started_events[0].event_time.ToDatetime()
134+
activity_scheduled_time = activity_scheduled_events[0].event_time.ToDatetime()
135+
assert activity_scheduled_time < timer_started_time + timedelta(seconds=1)

0 commit comments

Comments
 (0)