Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cadence/_internal/activity/_activity_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _create_context(
self._data_converter,
task.task_token,
self._identity,
task.heartbeat_details,
)

if activity_def.strategy == ExecutionStrategy.ASYNC:
Expand Down
5 changes: 4 additions & 1 deletion cadence/_internal/activity/_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any
from typing import Any, Type

from cadence import Client
from cadence._internal.activity._definition import BaseDefinition
Expand Down Expand Up @@ -55,6 +55,9 @@ def heartbeat(self, *details: Any) -> None:
self._heartbeat_tasks.add(heartbeat_task)
heartbeat_task.add_done_callback(self._heartbeat_tasks.discard)

def heartbeat_details(self, *types: Type) -> list[Any]:
return self._heartbeat_sender.get_details(*types)


class _SyncContext(_Context):
def __init__(
Expand Down
9 changes: 8 additions & 1 deletion cadence/_internal/activity/_heartbeat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from logging import getLogger
from typing import Any
from typing import Any, Type

from cadence.api.v1.common_pb2 import Payload
from cadence.api.v1.service_worker_pb2 import RecordActivityTaskHeartbeatRequest
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from cadence.data_converter import DataConverter
Expand All @@ -15,11 +16,16 @@ def __init__(
data_converter: DataConverter,
task_token: bytes,
identity: str,
previous_details: Payload,
):
self._worker_stub = worker_stub
self._data_converter = data_converter
self._task_token = task_token
self._identity = identity
self._previous_details = previous_details

def get_details(self, *types: Type) -> list[Any]:
return self._data_converter.from_data(self._previous_details, list(types))

async def send_heartbeat(self, *details: Any) -> None:
try:
Expand All @@ -31,5 +37,6 @@ async def send_heartbeat(self, *details: Any) -> None:
identity=self._identity,
)
)
self._previous_details = payload
except Exception:
_logger.warning("Heartbeat failed", exc_info=True)
14 changes: 14 additions & 0 deletions cadence/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ def heartbeat(*details: Any) -> None:
ActivityContext.get().heartbeat(*details)


def heartbeat_details(*types: Type) -> list[Any]:
"""Return heartbeat details from the previous attempt.

Pass type hints to decode the values into specific Python types:
step, total = activity.heartbeat_details(int, int)

Without type hints, returns raw JSON-decoded values.
"""
return ActivityContext.get().heartbeat_details(*types)


class ActivityContext(ABC):
_var: ContextVar["ActivityContext"] = ContextVar("activity")

Expand All @@ -80,6 +91,9 @@ def client(self) -> Client: ...
@abstractmethod
def heartbeat(self, *details: Any) -> None: ...

@abstractmethod
def heartbeat_details(self, *types: Type) -> list[Any]: ...

@contextmanager
def _activate(self) -> Iterator[None]:
token = ActivityContext._var.set(self)
Expand Down
3 changes: 3 additions & 0 deletions cadence/data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def from_data(
if not payload.data:
return DefaultDataConverter._convert_into([], type_hints)

if not type_hints:
type_hints = [None]

payload_str = payload.data.decode()

return self._decode_whitespace_delimited(payload_str, type_hints)
Expand Down
146 changes: 145 additions & 1 deletion tests/cadence/_internal/activity/test_activity_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,11 @@ def fake_info(activity_type: str) -> ActivityInfo:
)


def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskResponse:
def fake_task(
activity_type: str,
input_json: str,
heartbeat_details: str = "",
) -> PollForActivityTaskResponse:
return PollForActivityTaskResponse(
task_token=b"task_token",
workflow_domain="workflow_domain",
Expand All @@ -298,6 +302,9 @@ def fake_task(activity_type: str, input_json: str) -> PollForActivityTaskRespons
scheduled_time=from_datetime(datetime(2020, 1, 2, 3)),
started_time=from_datetime(datetime(2020, 1, 2, 4)),
start_to_close_timeout=from_timedelta(timedelta(seconds=2)),
heartbeat_details=Payload(data=heartbeat_details.encode())
if heartbeat_details
else Payload(),
)


Expand Down Expand Up @@ -365,3 +372,140 @@ def activity_fn():
identity="identity",
)
)


async def test_heartbeat_details_recovery_async(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(
return_value=RespondActivityTaskCompletedResponse()
)

reg = Registry()

@reg.activity(name="activity_type")
async def activity_fn():
return activity.heartbeat_details(str, int)

executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)

await executor.execute(
fake_task("activity_type", "", heartbeat_details='"progress" 42')
)

worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
RespondActivityTaskCompletedRequest(
task_token=b"task_token",
result=Payload(data=b'["progress",42]'),
identity="identity",
)
)


async def test_heartbeat_details_recovery_sync(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(
return_value=RespondActivityTaskCompletedResponse()
)

reg = Registry()

@reg.activity(name="activity_type")
def activity_fn():
return activity.heartbeat_details(str, int)

executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)

await executor.execute(
fake_task("activity_type", "", heartbeat_details='"progress" 42')
)

worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
RespondActivityTaskCompletedRequest(
task_token=b"task_token",
result=Payload(data=b'["progress",42]'),
identity="identity",
)
)


async def test_heartbeat_details_empty_when_no_previous_heartbeat(client):
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskCompleted = AsyncMock(
return_value=RespondActivityTaskCompletedResponse()
)

reg = Registry()

@reg.activity(name="activity_type")
async def activity_fn():
return activity.heartbeat_details()

executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)

await executor.execute(fake_task("activity_type", ""))

worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
RespondActivityTaskCompletedRequest(
task_token=b"task_token",
result=Payload(data=b"[]"),
identity="identity",
)
)


async def test_heartbeat_details_recovery_across_attempts(client):
"""Simulate retry: first attempt has no heartbeat details and fails,
second attempt receives heartbeat details from the server and succeeds."""
worker_stub = client.worker_stub
worker_stub.RespondActivityTaskFailed = AsyncMock(
return_value=RespondActivityTaskFailedResponse()
)
worker_stub.RespondActivityTaskCompleted = AsyncMock(
return_value=RespondActivityTaskCompletedResponse()
)
worker_stub.RecordActivityTaskHeartbeat = AsyncMock(
return_value=RecordActivityTaskHeartbeatResponse()
)

attempt_count = 0

reg = Registry()

@reg.activity(name="activity_type")
async def activity_fn():
nonlocal attempt_count
attempt_count += 1

details = activity.heartbeat_details()
if not details:
activity.heartbeat("step1", 50)
raise RuntimeError("simulated failure on first attempt")

return activity.heartbeat_details(str, int)

executor = ActivityExecutor(client, "task_list", "identity", 1, reg.get_activity)

# First attempt: no heartbeat details, activity heartbeats progress then fails
await executor.execute(fake_task("activity_type", ""))
worker_stub.RespondActivityTaskFailed.assert_called_once()
worker_stub.RecordActivityTaskHeartbeat.assert_called_once_with(
RecordActivityTaskHeartbeatRequest(
task_token=b"task_token",
details=Payload(data=b'"step1" 50'),
identity="identity",
)
)

# Second attempt: server provides heartbeat details from previous attempt
await executor.execute(
fake_task("activity_type", "", heartbeat_details='"step1" 50')
)
worker_stub.RespondActivityTaskCompleted.assert_called_once_with(
RespondActivityTaskCompletedRequest(
task_token=b"task_token",
result=Payload(data=b'["step1",50]'),
identity="identity",
)
)

assert attempt_count == 2
29 changes: 29 additions & 0 deletions tests/cadence/_internal/activity/test_heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def sender(worker_stub, data_converter) -> _HeartbeatSender:
data_converter=data_converter,
task_token=b"task_token",
identity="test-identity",
previous_details=Payload(),
)


Expand Down Expand Up @@ -66,3 +67,31 @@ async def test_heartbeat_no_details(sender, worker_stub):
call = worker_stub.RecordActivityTaskHeartbeat.call_args[0][0]
assert call.task_token == b"task_token"
assert call.identity == "test-identity"


async def test_heartbeat_updates_previous_details(sender, worker_stub):
await sender.send_heartbeat("step1", 10)

details = sender.get_details(str, int)
assert details == ["step1", 10]


async def test_heartbeat_details_not_updated_on_failure(
worker_stub,
data_converter,
):
worker_stub.RecordActivityTaskHeartbeat = AsyncMock(
side_effect=Exception("rpc error")
)
sender = _HeartbeatSender(
worker_stub=worker_stub,
data_converter=data_converter,
task_token=b"task_token",
identity="test-identity",
previous_details=Payload(data=b'"old"'),
)

await sender.send_heartbeat("new_value")

details = sender.get_details(str)
assert details == ["old"]
Loading