Skip to content

Commit 4320bbd

Browse files
authored
Merge branch 'main' into claude-1
2 parents 501c3ad + 108ad09 commit 4320bbd

24 files changed

+1769
-17
lines changed

.github/workflows/ci_checks.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ jobs:
5151
uses: astral-sh/setup-uv@v1
5252
- name: Install dependencies
5353
run: |
54-
uv sync --extra dev
54+
uv sync --extra dev --extra openai
5555
- name: Run mypy type checker
5656
run: |
57-
uv tool run mypy cadence/
57+
uv run mypy cadence/
5858
5959
test:
6060
name: Unit Tests

cadence/_internal/activity/_activity_executor.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from concurrent.futures import ThreadPoolExecutor
22
from logging import getLogger
33
from traceback import format_exception
4-
from typing import Any, Callable, cast
4+
from typing import Any, Callable, Union, cast
55
from google.protobuf.duration import to_timedelta
66
from google.protobuf.timestamp import to_datetime
77

88
from cadence._internal.activity._context import _Context, _SyncContext
99
from cadence._internal.activity._definition import BaseDefinition, ExecutionStrategy
10+
from cadence._internal.activity._heartbeat import _HeartbeatSender
1011
from cadence.activity import ActivityInfo, ActivityDefinition
1112
from cadence.api.v1.common_pb2 import Failure
1213
from cadence.api.v1.service_worker_pb2 import (
@@ -46,19 +47,33 @@ async def execute(self, task: PollForActivityTaskResponse):
4647
_logger.exception("Activity failed")
4748
await self._report_failure(task, e)
4849

49-
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
50+
def _create_context(
51+
self, task: PollForActivityTaskResponse
52+
) -> Union[_Context, _SyncContext]:
5053
activity_type = task.activity_type.name
5154
try:
5255
activity_def = cast(BaseDefinition, self._registry(activity_type))
5356
except KeyError:
5457
raise KeyError(f"Activity type not found: {activity_type}") from None
5558

5659
info = self._create_info(task)
60+
heartbeat_sender = _HeartbeatSender(
61+
self._client.worker_stub,
62+
self._data_converter,
63+
task.task_token,
64+
self._identity,
65+
)
5766

5867
if activity_def.strategy == ExecutionStrategy.ASYNC:
59-
return _Context(self._client, info, activity_def)
68+
return _Context(self._client, info, activity_def, heartbeat_sender)
6069
else:
61-
return _SyncContext(self._client, info, activity_def, self._thread_pool)
70+
return _SyncContext(
71+
self._client,
72+
info,
73+
activity_def,
74+
self._thread_pool,
75+
heartbeat_sender,
76+
)
6277

6378
async def _report_failure(
6479
self, task: PollForActivityTaskResponse, error: Exception

cadence/_internal/activity/_context.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from cadence import Client
66
from cadence._internal.activity._definition import BaseDefinition
7+
from cadence._internal.activity._heartbeat import _HeartbeatSender
78
from cadence.activity import ActivityInfo, ActivityContext
89
from cadence.api.v1.common_pb2 import Payload
910

@@ -14,15 +15,27 @@ def __init__(
1415
client: Client,
1516
info: ActivityInfo,
1617
activity_def: BaseDefinition[[Any], Any],
18+
heartbeat_sender: _HeartbeatSender,
1719
):
1820
self._client = client
1921
self._info = info
2022
self._activity_def = activity_def
23+
self._heartbeat_sender = heartbeat_sender
24+
self._heartbeat_tasks: set[asyncio.Future[None]] = set()
2125

2226
async def execute(self, payload: Payload) -> Any:
2327
params = self._to_params(payload)
24-
with self._activate():
25-
return await self._activity_def.impl_fn(*params)
28+
try:
29+
with self._activate():
30+
return await self._activity_def.impl_fn(*params)
31+
finally:
32+
await self._wait_pending_heartbeats()
33+
34+
async def _wait_pending_heartbeats(self) -> None:
35+
if not self._heartbeat_tasks:
36+
return
37+
tasks = list(self._heartbeat_tasks)
38+
await asyncio.gather(*tasks, return_exceptions=True)
2639

2740
def _to_params(self, payload: Payload) -> list[Any]:
2841
return self._activity_def.signature.params_from_payload(
@@ -35,6 +48,13 @@ def client(self) -> Client:
3548
def info(self) -> ActivityInfo:
3649
return self._info
3750

51+
def heartbeat(self, *details: Any) -> None:
52+
heartbeat_task = asyncio.create_task(
53+
self._heartbeat_sender.send_heartbeat(*details)
54+
)
55+
self._heartbeat_tasks.add(heartbeat_task)
56+
heartbeat_task.add_done_callback(self._heartbeat_tasks.discard)
57+
3858

3959
class _SyncContext(_Context):
4060
def __init__(
@@ -43,18 +63,30 @@ def __init__(
4363
info: ActivityInfo,
4464
activity_def: BaseDefinition[[Any], Any],
4565
executor: ThreadPoolExecutor,
66+
heartbeat_sender: _HeartbeatSender,
4667
):
47-
super().__init__(client, info, activity_def)
68+
super().__init__(client, info, activity_def, heartbeat_sender)
4869
self._executor = executor
4970

5071
async def execute(self, payload: Payload) -> Any:
5172
params = self._to_params(payload)
52-
loop = asyncio.get_running_loop()
53-
return await loop.run_in_executor(self._executor, self._run, params)
73+
self._loop = asyncio.get_running_loop()
74+
try:
75+
return await self._loop.run_in_executor(self._executor, self._run, params)
76+
finally:
77+
await self._wait_pending_heartbeats()
5478

5579
def _run(self, args: list[Any]) -> Any:
5680
with self._activate():
5781
return self._activity_def.impl_fn(*args)
5882

5983
def client(self) -> Client:
6084
raise RuntimeError("client is only supported in async activities")
85+
86+
def heartbeat(self, *details: Any) -> None:
87+
future = asyncio.run_coroutine_threadsafe(
88+
self._heartbeat_sender.send_heartbeat(*details), self._loop
89+
)
90+
wrapped = asyncio.wrap_future(future, loop=self._loop)
91+
self._heartbeat_tasks.add(wrapped)
92+
wrapped.add_done_callback(self._heartbeat_tasks.discard)

cadence/_internal/activity/_definition.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
ParamSpec,
1515
TypeVar,
1616
Awaitable,
17+
Type,
1718
cast,
19+
overload,
1820
Concatenate,
1921
)
2022

@@ -119,10 +121,13 @@ def __init__(
119121
super().__init__(name, wrapped, ExecutionStrategy.THREAD_POOL, signature)
120122
update_wrapper(self, wrapped)
121123

122-
def __get__(self, instance, owner):
124+
@overload
125+
def __get__(self, instance: None, owner: Type[T]) -> "SyncMethodImpl[T, P, R]": ...
126+
@overload
127+
def __get__(self, instance: T, owner: Type[T]) -> SyncImpl[P, R]: ...
128+
def __get__(self, instance: T | None, owner: Type[T]) -> "SyncImpl[P, R] | Self":
123129
if instance is None:
124130
return self
125-
# If we bound the method to an instance, then drop the self parameter. It's a normal function again
126131
return SyncImpl[P, R](
127132
partial(self._wrapped, instance), self.name, self._signature
128133
)
@@ -181,10 +186,13 @@ def __init__(
181186
else:
182187
self._is_coroutine = _COROUTINE_MARKER
183188

184-
def __get__(self, instance, owner):
189+
@overload
190+
def __get__(self, instance: None, owner: Type[T]) -> "AsyncMethodImpl[T, P, R]": ...
191+
@overload
192+
def __get__(self, instance: T, owner: Type[T]) -> AsyncImpl[P, R]: ...
193+
def __get__(self, instance: T | None, owner: Type[T]) -> "AsyncImpl[P, R] | Self":
185194
if instance is None:
186195
return self
187-
# If we bound the method to an instance, then drop the self parameter. It's a normal function again
188196
return AsyncImpl[P, R](
189197
partial(self._wrapped, instance), self.name, self._signature
190198
)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from logging import getLogger
2+
from typing import Any
3+
4+
from cadence.api.v1.service_worker_pb2 import RecordActivityTaskHeartbeatRequest
5+
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
6+
from cadence.data_converter import DataConverter
7+
8+
_logger = getLogger(__name__)
9+
10+
11+
class _HeartbeatSender:
12+
def __init__(
13+
self,
14+
worker_stub: WorkerAPIStub,
15+
data_converter: DataConverter,
16+
task_token: bytes,
17+
identity: str,
18+
):
19+
self._worker_stub = worker_stub
20+
self._data_converter = data_converter
21+
self._task_token = task_token
22+
self._identity = identity
23+
24+
async def send_heartbeat(self, *details: Any) -> None:
25+
try:
26+
payload = self._data_converter.to_data(list(details))
27+
await self._worker_stub.RecordActivityTaskHeartbeat(
28+
RecordActivityTaskHeartbeatRequest(
29+
task_token=self._task_token,
30+
details=payload,
31+
identity=self._identity,
32+
)
33+
)
34+
except Exception:
35+
_logger.warning("Heartbeat failed", exc_info=True)

cadence/_internal/workflow/workflow_engine.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@
1313
from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
1414
from cadence._internal.workflow.workflow_instance import WorkflowInstance
1515
from cadence.api.v1 import history
16-
from cadence.api.v1.common_pb2 import Failure
16+
from cadence.api.v1.common_pb2 import Failure, WorkflowType
1717
from cadence.api.v1.decision_pb2 import (
1818
Decision,
1919
FailWorkflowExecutionDecisionAttributes,
2020
CompleteWorkflowExecutionDecisionAttributes,
21+
ContinueAsNewWorkflowExecutionDecisionAttributes,
2122
)
2223
from cadence.api.v1.history_pb2 import (
2324
HistoryEvent,
2425
WorkflowExecutionStartedEventAttributes,
2526
)
27+
from cadence.api.v1.tasklist_pb2 import TaskList
28+
from cadence.error import ContinueAsNewError
2629
from cadence.workflow import WorkflowDefinition, WorkflowInfo
2730

2831
logger = logging.getLogger(__name__)
@@ -187,6 +190,25 @@ def _maybe_complete_workflow(self) -> Optional[Decision]:
187190
return None
188191
except (CancelledError, InvalidStateError, FatalDecisionError):
189192
raise
193+
except ContinueAsNewError as e:
194+
# Use execution's workflow type and task list when not overridden
195+
info = self._context.info()
196+
attrs = ContinueAsNewWorkflowExecutionDecisionAttributes(
197+
workflow_type=WorkflowType(name=e.workflow_type or info.workflow_type),
198+
task_list=TaskList(name=e.task_list or info.workflow_task_list),
199+
input=info.data_converter.to_data(list(e.workflow_args)),
200+
)
201+
if e.execution_start_to_close_timeout is not None:
202+
attrs.execution_start_to_close_timeout.FromTimedelta(
203+
e.execution_start_to_close_timeout
204+
)
205+
if e.task_start_to_close_timeout is not None:
206+
attrs.task_start_to_close_timeout.FromTimedelta(
207+
e.task_start_to_close_timeout
208+
)
209+
return Decision(
210+
continue_as_new_workflow_execution_decision_attributes=attrs,
211+
)
190212
except ExceptionGroup as e:
191213
if e.subgroup((InvalidStateError, FatalDecisionError)):
192214
raise

cadence/activity.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def info() -> ActivityInfo:
6363
return ActivityContext.get().info()
6464

6565

66+
def heartbeat(*details: Any) -> None:
67+
"""Send a heartbeat for the current activity."""
68+
ActivityContext.get().heartbeat(*details)
69+
70+
6671
class ActivityContext(ABC):
6772
_var: ContextVar["ActivityContext"] = ContextVar("activity")
6873

@@ -72,6 +77,9 @@ def info(self) -> ActivityInfo: ...
7277
@abstractmethod
7378
def client(self) -> Client: ...
7479

80+
@abstractmethod
81+
def heartbeat(self, *details: Any) -> None: ...
82+
7583
@contextmanager
7684
def _activate(self) -> Iterator[None]:
7785
token = ActivityContext._var.set(self)

cadence/contrib/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)