Skip to content

Commit 319f689

Browse files
committed
feat: implement heartbeat progress recovery through heartbeat details
Signed-off-by: Tim Li <ltim@uber.com>
1 parent bfe4cd7 commit 319f689

File tree

7 files changed

+117
-4
lines changed

7 files changed

+117
-4
lines changed

cadence/_internal/activity/_activity_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def _create_context(
6262
self._data_converter,
6363
task.task_token,
6464
self._identity,
65+
task.heartbeat_details,
6566
)
6667

6768
if activity_def.strategy == ExecutionStrategy.ASYNC:

cadence/_internal/activity/_context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from concurrent.futures.thread import ThreadPoolExecutor
3-
from typing import Any
3+
from typing import Any, Type
44

55
from cadence import Client
66
from cadence._internal.activity._definition import BaseDefinition
@@ -55,6 +55,9 @@ def heartbeat(self, *details: Any) -> None:
5555
self._heartbeat_tasks.add(heartbeat_task)
5656
heartbeat_task.add_done_callback(self._heartbeat_tasks.discard)
5757

58+
def heartbeat_details(self, *types: Type) -> list[Any]:
59+
return self._heartbeat_sender.get_details(*types)
60+
5861

5962
class _SyncContext(_Context):
6063
def __init__(

cadence/_internal/activity/_heartbeat.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from logging import getLogger
2-
from typing import Any
2+
from typing import Any, Type
33

4+
from cadence.api.v1.common_pb2 import Payload
45
from cadence.api.v1.service_worker_pb2 import RecordActivityTaskHeartbeatRequest
56
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
67
from cadence.data_converter import DataConverter
@@ -15,11 +16,16 @@ def __init__(
1516
data_converter: DataConverter,
1617
task_token: bytes,
1718
identity: str,
19+
previous_details: Payload,
1820
):
1921
self._worker_stub = worker_stub
2022
self._data_converter = data_converter
2123
self._task_token = task_token
2224
self._identity = identity
25+
self._previous_details = previous_details
26+
27+
def get_details(self, *types: Type) -> list[Any]:
28+
return self._data_converter.from_data(self._previous_details, list(types))
2329

2430
async def send_heartbeat(self, *details: Any) -> None:
2531
try:

cadence/activity.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ def heartbeat(*details: Any) -> None:
6868
ActivityContext.get().heartbeat(*details)
6969

7070

71+
def heartbeat_details(*types: Type) -> list[Any]:
72+
"""Return heartbeat details from the previous attempt.
73+
74+
Pass type hints to decode the values into specific Python types:
75+
step, total = activity.heartbeat_details(int, int)
76+
77+
Without type hints, returns raw JSON-decoded values.
78+
"""
79+
return ActivityContext.get().heartbeat_details(*types)
80+
81+
7182
class ActivityContext(ABC):
7283
_var: ContextVar["ActivityContext"] = ContextVar("activity")
7384

@@ -80,6 +91,9 @@ def client(self) -> Client: ...
8091
@abstractmethod
8192
def heartbeat(self, *details: Any) -> None: ...
8293

94+
@abstractmethod
95+
def heartbeat_details(self, *types: Type) -> list[Any]: ...
96+
8397
@contextmanager
8498
def _activate(self) -> Iterator[None]:
8599
token = ActivityContext._var.set(self)

cadence/data_converter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _decode_whitespace_delimited(
3939
) -> List[Any]:
4040
results: List[Any] = []
4141
start, end = 0, len(payload)
42-
while start < end and len(results) < len(type_hints):
42+
while start < end and (not type_hints or len(results) < len(type_hints)):
4343
remaining = payload[start:end]
4444
(value, value_end) = self._decoder.raw_decode(remaining)
4545
start += value_end + 1
@@ -51,6 +51,8 @@ def _decode_whitespace_delimited(
5151
def _convert_into(
5252
values: List[Any], type_hints: Sequence[Type | None]
5353
) -> List[Any]:
54+
if not type_hints:
55+
return list(values)
5456
results: List[Any] = []
5557
for i, type_hint in enumerate(type_hints):
5658
if not type_hint or type_hint is Any:

tests/cadence/_internal/activity/test_activity_executor.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,11 @@ def fake_info(activity_type: str) -> ActivityInfo:
281281
)
282282

283283

284-
def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse:
284+
def fake_task(
285+
activity_type: str,
286+
input_json: str,
287+
heartbeat_details: str = "",
288+
) -> PollForActivityTaskResponse:
285289
return PollForActivityTaskResponse(
286290
task_token=b"task_token",
287291
workflow_domain="workflow_domain",
@@ -298,6 +302,9 @@ def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskRespons
298302
scheduled_time=from_datetime(datetime(2020, 1, 2, 3)),
299303
started_time=from_datetime(datetime(2020, 1, 2, 4)),
300304
start_to_close_timeout=from_timedelta(timedelta(seconds=2)),
305+
heartbeat_details=Payload(data=heartbeat_details.encode())
306+
if heartbeat_details
307+
else Payload(),
301308
)
302309

303310

@@ -365,3 +372,82 @@ def activity_fn():
365372
identity="identity",
366373
)
367374
)
375+
376+
377+
async def test_heartbeat_details_recovery_async(client):
378+
worker_stub = client.worker_stub
379+
worker_stub.RespondActivityTaskCompleted = AsyncMock(
380+
return_value=RespondActivityTaskCompletedResponse()
381+
)
382+
383+
reg = Registry()
384+
385+
@reg.activity(name="activity_type")
386+
async def activity_fn():
387+
return activity.heartbeat_details(str, int)
388+
389+
executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)
390+
391+
await executor.execute(
392+
fake_task("activity_type", "", heartbeat_details='"progress" 42')
393+
)
394+
395+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
396+
RespondActivityTaskCompletedRequest(
397+
task_token=b"task_token",
398+
result=Payload(data=b'["progress",42]'),
399+
identity="identity",
400+
)
401+
)
402+
403+
404+
async def test_heartbeat_details_recovery_sync(client):
405+
worker_stub = client.worker_stub
406+
worker_stub.RespondActivityTaskCompleted = AsyncMock(
407+
return_value=RespondActivityTaskCompletedResponse()
408+
)
409+
410+
reg = Registry()
411+
412+
@reg.activity(name="activity_type")
413+
def activity_fn():
414+
return activity.heartbeat_details(str, int)
415+
416+
executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)
417+
418+
await executor.execute(
419+
fake_task("activity_type", "", heartbeat_details='"progress" 42')
420+
)
421+
422+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
423+
RespondActivityTaskCompletedRequest(
424+
task_token=b"task_token",
425+
result=Payload(data=b'["progress",42]'),
426+
identity="identity",
427+
)
428+
)
429+
430+
431+
async def test_heartbeat_details_empty_when_no_previous_heartbeat(client):
432+
worker_stub = client.worker_stub
433+
worker_stub.RespondActivityTaskCompleted = AsyncMock(
434+
return_value=RespondActivityTaskCompletedResponse()
435+
)
436+
437+
reg = Registry()
438+
439+
@reg.activity(name="activity_type")
440+
async def activity_fn():
441+
return activity.heartbeat_details()
442+
443+
executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)
444+
445+
await executor.execute(fake_task("activity_type", ""))
446+
447+
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
448+
RespondActivityTaskCompletedRequest(
449+
task_token=b"task_token",
450+
result=Payload(data=b"[]"),
451+
identity="identity",
452+
)
453+
)

tests/cadence/_internal/activity/test_heartbeat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def sender(worker_stub, data_converter) -> _HeartbeatSender:
3232
data_converter=data_converter,
3333
task_token=b"task_token",
3434
identity="test-identity",
35+
previous_details=Payload(),
3536
)
3637

3738

0 commit comments

Comments
 (0)