Skip to content

Commit 318a4f0

Browse files
committed
fix: Rewrite Type Handling
Move type handling logic from activity.py to fn_signature.py, providing utilities for understanding a functions parameters and converting calls/payloads to match the function's signature. Use this new logic in Workflows, to support any number of typed parameters. Rewrite the Activity decorators to be fully type-safe. Introduce `@activity.method` and `@activity.impl` to satisfy type safety while allowing for interface types or classes that contain activities. Fix type-safety for Workflows. This only addresses the annotations and doesn't provide a type-safe invocation mechanism like what is supported with Activities.
1 parent 5b44b12 commit 318a4f0

File tree

18 files changed

+1265
-321
lines changed

18 files changed

+1265
-321
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from ._activity_executor import ActivityExecutor
2+
from ._definition import AsyncImpl, AsyncMethodImpl, SyncImpl, SyncMethodImpl
23

34
__all__ = [
45
"ActivityExecutor",
6+
"AsyncImpl",
7+
"AsyncMethodImpl",
8+
"SyncImpl",
9+
"SyncMethodImpl",
510
]

cadence/_internal/activity/_activity_executor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from concurrent.futures import ThreadPoolExecutor
22
from logging import getLogger
33
from traceback import format_exception
4-
from typing import Any, Callable
4+
from typing import Any, Callable, cast
55
from google.protobuf.duration import to_timedelta
66
from google.protobuf.timestamp import to_datetime
77

88
from cadence._internal.activity._context import _Context, _SyncContext
9-
from cadence.activity import ActivityInfo, ActivityDefinition, ExecutionStrategy
9+
from cadence._internal.activity._definition import BaseDefinition, ExecutionStrategy
10+
from cadence.activity import ActivityInfo, ActivityDefinition
1011
from cadence.api.v1.common_pb2 import Failure
1112
from cadence.api.v1.service_worker_pb2 import (
1213
PollForActivityTaskResponse,
@@ -42,12 +43,13 @@ async def execute(self, task: PollForActivityTaskResponse):
4243
result = await context.execute(task.input)
4344
await self._report_success(task, result)
4445
except Exception as e:
46+
_logger.exception("Activity failed")
4547
await self._report_failure(task, e)
4648

4749
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
4850
activity_type = task.activity_type.name
4951
try:
50-
activity_def = self._registry(activity_type)
52+
activity_def = cast(BaseDefinition, self._registry(activity_type))
5153
except KeyError:
5254
raise KeyError(f"Activity type not found: {activity_type}") from None
5355

cadence/_internal/activity/_context.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from typing import Any
44

55
from cadence import Client
6-
from cadence.activity import ActivityInfo, ActivityContext, ActivityDefinition
6+
from cadence._internal.activity._definition import BaseDefinition
7+
from cadence.activity import ActivityInfo, ActivityContext
78
from cadence.api.v1.common_pb2 import Payload
89

910

@@ -12,20 +13,21 @@ def __init__(
1213
self,
1314
client: Client,
1415
info: ActivityInfo,
15-
activity_fn: ActivityDefinition[[Any], Any],
16+
activity_def: BaseDefinition[[Any], Any],
1617
):
1718
self._client = client
1819
self._info = info
19-
self._activity_fn = activity_fn
20+
self._activity_def = activity_def
2021

2122
async def execute(self, payload: Payload) -> Any:
2223
params = self._to_params(payload)
2324
with self._activate():
24-
return await self._activity_fn(*params)
25+
return await self._activity_def.impl_fn(*params)
2526

2627
def _to_params(self, payload: Payload) -> list[Any]:
27-
type_hints = [param.type_hint for param in self._activity_fn.params]
28-
return self._client.data_converter.from_data(payload, type_hints)
28+
return self._activity_def.signature.params_from_payload(
29+
self._client.data_converter, payload
30+
)
2931

3032
def client(self) -> Client:
3133
return self._client
@@ -39,10 +41,10 @@ def __init__(
3941
self,
4042
client: Client,
4143
info: ActivityInfo,
42-
activity_fn: ActivityDefinition[[Any], Any],
44+
activity_def: BaseDefinition[[Any], Any],
4345
executor: ThreadPoolExecutor,
4446
):
45-
super().__init__(client, info, activity_fn)
47+
super().__init__(client, info, activity_def)
4648
self._executor = executor
4749

4850
async def execute(self, payload: Payload) -> Any:
@@ -52,7 +54,7 @@ async def execute(self, payload: Payload) -> Any:
5254

5355
def _run(self, args: list[Any]) -> Any:
5456
with self._activate():
55-
return self._activity_fn(*args)
57+
return self._activity_def.impl_fn(*args)
5658

5759
def client(self) -> Client:
5860
raise RuntimeError("client is only supported in async activities")
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import abc
2+
from abc import ABC
3+
from enum import Enum
4+
from functools import update_wrapper, partial
5+
from typing import (
6+
Generic,
7+
Callable,
8+
Unpack,
9+
Self,
10+
ParamSpec,
11+
TypeVar,
12+
Awaitable,
13+
cast,
14+
Concatenate,
15+
)
16+
17+
from cadence._internal.fn_signature import FnSignature
18+
from cadence.workflow import ActivityOptions, WorkflowContext, execute_activity
19+
20+
T = TypeVar("T")
21+
P = ParamSpec("P")
22+
R = TypeVar("R")
23+
24+
25+
class ExecutionStrategy(Enum):
26+
ASYNC = "async"
27+
THREAD_POOL = "thread_pool"
28+
29+
30+
class BaseDefinition(ABC, Generic[P, R]):
31+
def __init__(
32+
self,
33+
name: str,
34+
wrapped: Callable,
35+
strategy: ExecutionStrategy,
36+
signature: FnSignature,
37+
):
38+
self._name = name
39+
self._wrapped = wrapped
40+
self._strategy = strategy
41+
self._signature = signature
42+
self._execution_options = ActivityOptions()
43+
44+
@property
45+
def strategy(self) -> ExecutionStrategy:
46+
return self._strategy
47+
48+
@property
49+
def signature(self) -> FnSignature:
50+
return self._signature
51+
52+
@property
53+
def impl_fn(self) -> Callable:
54+
return self._wrapped
55+
56+
@property
57+
def name(self) -> str:
58+
return self._name
59+
60+
@abc.abstractmethod
61+
def clone(self) -> Self: ...
62+
63+
def rebind(self, fn: Callable) -> Self:
64+
res = self.clone()
65+
res._wrapped = fn
66+
return res
67+
68+
def with_options(self, **kwargs: Unpack[ActivityOptions]) -> Self:
69+
res = self.clone()
70+
new_opts = self._execution_options.copy()
71+
new_opts.update(kwargs)
72+
res._execution_options = new_opts
73+
return res
74+
75+
async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
76+
result_type = cast(type[R], self._signature.return_type)
77+
return await execute_activity(
78+
self._name,
79+
result_type,
80+
*self._signature.params_from_call(args, kwargs),
81+
**self._execution_options,
82+
)
83+
84+
85+
class SyncImpl(BaseDefinition[P, R]):
86+
def __init__(
87+
self,
88+
wrapped: Callable[P, R],
89+
name: str,
90+
signature: FnSignature,
91+
):
92+
super().__init__(name, wrapped, ExecutionStrategy.THREAD_POOL, signature)
93+
update_wrapper(self, wrapped)
94+
95+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
96+
if WorkflowContext.is_set():
97+
raise RuntimeError(
98+
"Attempting to invoke sync function in workflow. Use execute"
99+
)
100+
return self._wrapped(*args, **kwargs) # type: ignore
101+
102+
def clone(self) -> "SyncImpl[P, R]":
103+
return SyncImpl[P, R](self._wrapped, self._name, self._signature)
104+
105+
106+
class SyncMethodImpl(BaseDefinition[P, R], Generic[T, P, R]):
107+
def __init__(
108+
self,
109+
wrapped: Callable[Concatenate[T, P], R],
110+
name: str,
111+
signature: FnSignature,
112+
):
113+
super().__init__(name, wrapped, ExecutionStrategy.THREAD_POOL, signature)
114+
update_wrapper(self, wrapped)
115+
116+
def __get__(self, instance, owner):
117+
if instance is None:
118+
return self
119+
# If we bound the method to an instance, then drop the self parameter. It's a normal function again
120+
return SyncImpl[P, R](
121+
partial(self._wrapped, instance), self.name, self._signature
122+
)
123+
124+
def __call__(self, original_self: T, *args: P.args, **kwargs: P.kwargs) -> R:
125+
if WorkflowContext.is_set():
126+
raise RuntimeError(
127+
"Attempting to invoke sync function in workflow. Use execute"
128+
)
129+
return self._wrapped(original_self, *args, **kwargs) # type: ignore
130+
131+
def clone(self) -> "SyncMethodImpl[T, P, R]":
132+
return SyncMethodImpl[T, P, R](self._wrapped, self._name, self._signature)
133+
134+
135+
class AsyncImpl(BaseDefinition[P, R]):
136+
def __init__(
137+
self,
138+
wrapped: Callable[P, Awaitable[R]],
139+
name: str,
140+
signature: FnSignature,
141+
):
142+
super().__init__(name, wrapped, ExecutionStrategy.ASYNC, signature)
143+
update_wrapper(self, wrapped)
144+
145+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
146+
if WorkflowContext.is_set():
147+
return await self.execute(*args, **kwargs) # type: ignore
148+
return await self._wrapped(*args, **kwargs) # type: ignore
149+
150+
def clone(self) -> "AsyncImpl[P, R]":
151+
return AsyncImpl[P, R](self._wrapped, self._name, self._signature)
152+
153+
154+
class AsyncMethodImpl(BaseDefinition[P, R], Generic[T, P, R]):
155+
def __init__(
156+
self,
157+
wrapped: Callable[Concatenate[T, P], Awaitable[R]],
158+
name: str,
159+
signature: FnSignature,
160+
):
161+
super().__init__(name, wrapped, ExecutionStrategy.ASYNC, signature)
162+
update_wrapper(self, wrapped)
163+
164+
def __get__(self, instance, owner):
165+
if instance is None:
166+
return self
167+
# If we bound the method to an instance, then drop the self parameter. It's a normal function again
168+
return AsyncImpl[P, R](
169+
partial(self._wrapped, instance), self.name, self._signature
170+
)
171+
172+
async def __call__(self, original_self: T, *args: P.args, **kwargs: P.kwargs) -> R:
173+
if WorkflowContext.is_set():
174+
return await self.execute(*args, **kwargs) # type: ignore
175+
return await self._wrapped(original_self, *args, **kwargs) # type: ignore
176+
177+
def clone(self) -> "AsyncMethodImpl[T, P, R]":
178+
return AsyncMethodImpl[T, P, R](self._wrapped, self._name, self._signature)

cadence/_internal/fn_signature.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from dataclasses import dataclass
2+
from inspect import signature, Parameter
3+
from typing import (
4+
Type,
5+
Any,
6+
Callable,
7+
Sequence,
8+
get_type_hints,
9+
)
10+
11+
from cadence.api.v1.common_pb2 import Payload
12+
from cadence.data_converter import DataConverter
13+
14+
15+
@dataclass(frozen=True)
16+
class FnParameter:
17+
name: str
18+
type_hint: Type | None
19+
has_default: bool = False
20+
default_value: Any = None
21+
22+
23+
@dataclass(frozen=True)
24+
class FnSignature:
25+
params: list[FnParameter]
26+
return_type: Type
27+
28+
def params_from_call(
29+
self, args: Sequence[Any], kwargs: dict[str, Any]
30+
) -> list[Any]:
31+
result: list[Any] = []
32+
if len(args) > len(self.params):
33+
raise ValueError(f"Too many positional arguments: {args}")
34+
35+
for value, param_spec in zip(args, self.params):
36+
result.append(value)
37+
38+
i = len(result)
39+
while i < len(self.params):
40+
param = self.params[i]
41+
if param.name not in kwargs and not param.has_default:
42+
raise ValueError(f"Missing parameter: {param.name}")
43+
44+
value = kwargs.pop(param.name, param.default_value)
45+
result.append(value)
46+
i = i + 1
47+
48+
if len(kwargs) > 0:
49+
raise ValueError(f"Unexpected keyword arguments: {kwargs}")
50+
51+
return result
52+
53+
def params_from_payload(
54+
self, data_converter: DataConverter, payload: Payload
55+
) -> list[Any]:
56+
type_hints = [param.type_hint for param in self.params]
57+
return data_converter.from_data(payload, type_hints)
58+
59+
@staticmethod
60+
def of(fn: Callable) -> "FnSignature":
61+
sig = signature(fn)
62+
args = sig.parameters
63+
hints = get_type_hints(fn)
64+
params = []
65+
for name, param in args.items():
66+
# "unbound functions" aren't a thing in the Python spec. We don't have a way to determine whether the function
67+
# is part of a class or is standalone.
68+
# Filter out the self parameter and hope they followed the convention.
69+
if param.name == "self":
70+
continue
71+
default = None
72+
has_default = False
73+
if param.default != Parameter.empty:
74+
default = param.default
75+
has_default = param.default is not None
76+
if param.kind in (
77+
Parameter.POSITIONAL_ONLY,
78+
Parameter.POSITIONAL_OR_KEYWORD,
79+
):
80+
type_hint = hints.get(name, None)
81+
params.append(FnParameter(name, type_hint, has_default, default))
82+
else:
83+
raise ValueError(
84+
f"Parameters must be positional. {name} is {param.kind}, and not valid"
85+
)
86+
87+
# Treat unspecified return type as Any
88+
return_type = hints.get("return", Any)
89+
90+
return FnSignature(params, return_type)

cadence/_internal/workflow/workflow_intance.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from asyncio import CancelledError, InvalidStateError, Task
2-
from typing import Any, Optional
2+
from typing import Optional
33
from cadence._internal.workflow.deterministic_event_loop import DeterministicEventLoop
44
from cadence.api.v1.common_pb2 import Payload
55
from cadence.data_converter import DataConverter
@@ -20,11 +20,13 @@ def __init__(
2020
self._instance = workflow_definition.cls() # construct a new workflow object
2121
self._task: Optional[Task] = None
2222

23-
def start(self, input: Payload):
23+
def start(self, payload: Payload):
2424
if self._task is None:
2525
run_method = self._definition.get_run_method(self._instance)
26-
# TODO handle multiple inputs
27-
workflow_input = self._data_converter.from_data(input, [Any])
26+
# noinspection PyProtectedMember
27+
workflow_input = self._definition._run_signature.params_from_payload(
28+
self._data_converter, payload
29+
)
2830
self._task = self._loop.create_task(run_method(*workflow_input))
2931

3032
def run_once(self):

0 commit comments

Comments
 (0)