@@ -281,7 +281,11 @@ def fake_info(activity_type: str) -> ActivityInfo:
281281 )
282282
283283
284- def fake_task (activity_type : str , input_json : str ) -> PollForActivityTaskResponse :
284+ def fake_task (
285+ activity_type : str ,
286+ input_json : str ,
287+ heartbeat_details : str = "" ,
288+ ) -> PollForActivityTaskResponse :
285289 return PollForActivityTaskResponse (
286290 task_token = b"task_token" ,
287291 workflow_domain = "workflow_domain" ,
@@ -298,6 +302,9 @@ def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskRespons
298302 scheduled_time = from_datetime (datetime (2020 , 1 , 2 , 3 )),
299303 started_time = from_datetime (datetime (2020 , 1 , 2 , 4 )),
300304 start_to_close_timeout = from_timedelta (timedelta (seconds = 2 )),
305+ heartbeat_details = Payload (data = heartbeat_details .encode ())
306+ if heartbeat_details
307+ else Payload (),
301308 )
302309
303310
@@ -365,3 +372,82 @@ def activity_fn():
365372 identity = "identity" ,
366373 )
367374 )
375+
376+
377+ async def test_heartbeat_details_recovery_async (client ):
378+ worker_stub = client .worker_stub
379+ worker_stub .RespondActivityTaskCompleted = AsyncMock (
380+ return_value = RespondActivityTaskCompletedResponse ()
381+ )
382+
383+ reg = Registry ()
384+
385+ @reg .activity (name = "activity_type" )
386+ async def activity_fn ():
387+ return activity .heartbeat_details (str , int )
388+
389+ executor = ActivityExecutor (client , "task_list" , "identity" , 1 , reg .get_activity )
390+
391+ await executor .execute (
392+ fake_task ("activity_type" , "" , heartbeat_details = '"progress" 42' )
393+ )
394+
395+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (
396+ RespondActivityTaskCompletedRequest (
397+ task_token = b"task_token" ,
398+ result = Payload (data = b'["progress",42]' ),
399+ identity = "identity" ,
400+ )
401+ )
402+
403+
404+ async def test_heartbeat_details_recovery_sync (client ):
405+ worker_stub = client .worker_stub
406+ worker_stub .RespondActivityTaskCompleted = AsyncMock (
407+ return_value = RespondActivityTaskCompletedResponse ()
408+ )
409+
410+ reg = Registry ()
411+
412+ @reg .activity (name = "activity_type" )
413+ def activity_fn ():
414+ return activity .heartbeat_details (str , int )
415+
416+ executor = ActivityExecutor (client , "task_list" , "identity" , 1 , reg .get_activity )
417+
418+ await executor .execute (
419+ fake_task ("activity_type" , "" , heartbeat_details = '"progress" 42' )
420+ )
421+
422+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (
423+ RespondActivityTaskCompletedRequest (
424+ task_token = b"task_token" ,
425+ result = Payload (data = b'["progress",42]' ),
426+ identity = "identity" ,
427+ )
428+ )
429+
430+
431+ async def test_heartbeat_details_empty_when_no_previous_heartbeat (client ):
432+ worker_stub = client .worker_stub
433+ worker_stub .RespondActivityTaskCompleted = AsyncMock (
434+ return_value = RespondActivityTaskCompletedResponse ()
435+ )
436+
437+ reg = Registry ()
438+
439+ @reg .activity (name = "activity_type" )
440+ async def activity_fn ():
441+ return activity .heartbeat_details ()
442+
443+ executor = ActivityExecutor (client , "task_list" , "identity" , 1 , reg .get_activity )
444+
445+ await executor .execute (fake_task ("activity_type" , "" ))
446+
447+ worker_stub .RespondActivityTaskCompleted .assert_called_once_with (
448+ RespondActivityTaskCompletedRequest (
449+ task_token = b"task_token" ,
450+ result = Payload (data = b"[]" ),
451+ identity = "identity" ,
452+ )
453+ )
0 commit comments