@@ -21,22 +21,20 @@ def __init__(
2121 self ._info = info
2222 self ._activity_def = activity_def
2323 self ._heartbeat_sender = heartbeat_sender
24- self ._heartbeat_tasks : set [asyncio .Task [None ]] = set ()
24+ self ._heartbeat_tasks : set [asyncio .Future [None ]] = set ()
2525
2626 async def execute (self , payload : Payload ) -> Any :
2727 params = self ._to_params (payload )
2828 try :
2929 with self ._activate ():
3030 return await self ._activity_def .impl_fn (* params )
3131 finally :
32- await self ._cancel_pending_heartbeats ()
32+ await self ._wait_pending_heartbeats ()
3333
34- async def _cancel_pending_heartbeats (self ) -> None :
34+ async def _wait_pending_heartbeats (self ) -> None :
3535 if not self ._heartbeat_tasks :
3636 return
3737 tasks = list (self ._heartbeat_tasks )
38- for task in tasks :
39- task .cancel ()
4038 await asyncio .gather (* tasks , return_exceptions = True )
4139
4240 def _to_params (self , payload : Payload ) -> list [Any ]:
@@ -73,7 +71,10 @@ def __init__(
7371 async def execute (self , payload : Payload ) -> Any :
7472 params = self ._to_params (payload )
7573 self ._loop = asyncio .get_running_loop ()
76- return await self ._loop .run_in_executor (self ._executor , self ._run , params )
74+ try :
75+ return await self ._loop .run_in_executor (self ._executor , self ._run , params )
76+ finally :
77+ await self ._wait_pending_heartbeats ()
7778
7879 def _run (self , args : list [Any ]) -> Any :
7980 with self ._activate ():
@@ -83,6 +84,9 @@ def client(self) -> Client:
8384 raise RuntimeError ("client is only supported in async activities" )
8485
8586 def heartbeat (self , * details : Any ) -> None :
86- asyncio .run_coroutine_threadsafe (
87+ future = asyncio .run_coroutine_threadsafe (
8788 self ._heartbeat_sender .send_heartbeat (* details ), self ._loop
8889 )
90+ wrapped = asyncio .wrap_future (future , loop = self ._loop )
91+ self ._heartbeat_tasks .add (wrapped )
92+ wrapped .add_done_callback (self ._heartbeat_tasks .discard )
0 commit comments