44
55from cadence import Client
66from cadence ._internal .activity ._definition import BaseDefinition
7+ from cadence ._internal .activity ._heartbeat import _HeartbeatSender
78from cadence .activity import ActivityInfo , ActivityContext
89from cadence .api .v1 .common_pb2 import Payload
910
@@ -14,15 +15,27 @@ def __init__(
1415 client : Client ,
1516 info : ActivityInfo ,
1617 activity_def : BaseDefinition [[Any ], Any ],
18+ heartbeat_sender : _HeartbeatSender ,
1719 ):
1820 self ._client = client
1921 self ._info = info
2022 self ._activity_def = activity_def
23+ self ._heartbeat_sender = heartbeat_sender
24+ self ._heartbeat_tasks : set [asyncio .Future [None ]] = set ()
2125
2226 async def execute (self , payload : Payload ) -> Any :
2327 params = self ._to_params (payload )
24- with self ._activate ():
25- return await self ._activity_def .impl_fn (* params )
28+ try :
29+ with self ._activate ():
30+ return await self ._activity_def .impl_fn (* params )
31+ finally :
32+ await self ._wait_pending_heartbeats ()
33+
34+ async def _wait_pending_heartbeats (self ) -> None :
35+ if not self ._heartbeat_tasks :
36+ return
37+ tasks = list (self ._heartbeat_tasks )
38+ await asyncio .gather (* tasks , return_exceptions = True )
2639
2740 def _to_params (self , payload : Payload ) -> list [Any ]:
2841 return self ._activity_def .signature .params_from_payload (
@@ -35,6 +48,13 @@ def client(self) -> Client:
3548 def info (self ) -> ActivityInfo :
3649 return self ._info
3750
51+ def heartbeat (self , * details : Any ) -> None :
52+ heartbeat_task = asyncio .create_task (
53+ self ._heartbeat_sender .send_heartbeat (* details )
54+ )
55+ self ._heartbeat_tasks .add (heartbeat_task )
56+ heartbeat_task .add_done_callback (self ._heartbeat_tasks .discard )
57+
3858
3959class _SyncContext (_Context ):
4060 def __init__ (
@@ -43,18 +63,30 @@ def __init__(
4363 info : ActivityInfo ,
4464 activity_def : BaseDefinition [[Any ], Any ],
4565 executor : ThreadPoolExecutor ,
66+ heartbeat_sender : _HeartbeatSender ,
4667 ):
47- super ().__init__ (client , info , activity_def )
68+ super ().__init__ (client , info , activity_def , heartbeat_sender )
4869 self ._executor = executor
4970
5071 async def execute (self , payload : Payload ) -> Any :
5172 params = self ._to_params (payload )
52- loop = asyncio .get_running_loop ()
53- return await loop .run_in_executor (self ._executor , self ._run , params )
73+ self ._loop = asyncio .get_running_loop ()
74+ try :
75+ return await self ._loop .run_in_executor (self ._executor , self ._run , params )
76+ finally :
77+ await self ._wait_pending_heartbeats ()
5478
5579 def _run (self , args : list [Any ]) -> Any :
5680 with self ._activate ():
5781 return self ._activity_def .impl_fn (* args )
5882
5983 def client (self ) -> Client :
6084 raise RuntimeError ("client is only supported in async activities" )
85+
86+ def heartbeat (self , * details : Any ) -> None :
87+ future = asyncio .run_coroutine_threadsafe (
88+ self ._heartbeat_sender .send_heartbeat (* details ), self ._loop
89+ )
90+ wrapped = asyncio .wrap_future (future , loop = self ._loop )
91+ self ._heartbeat_tasks .add (wrapped )
92+ wrapped .add_done_callback (self ._heartbeat_tasks .discard )
0 commit comments