Skip to content

Commit 0ac92cd

Browse files
committed
Add workflow invocation methods
Signed-off-by: Tim Li <ltim@uber.com>
1 parent c183c31 commit 0ac92cd

File tree

2 files changed

+569
-1
lines changed

2 files changed

+569
-1
lines changed

cadence/client.py

Lines changed: 182 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import os
22
import socket
3-
from typing import TypedDict, Unpack, Any, cast
3+
import uuid
4+
from dataclasses import dataclass
5+
from datetime import timedelta
6+
from typing import TypedDict, Unpack, Any, cast, Union, Optional, Callable
47

58
from grpc import ChannelCredentials, Compression
9+
from google.protobuf.duration_pb2 import Duration
610

711
from cadence._internal.rpc.error import CadenceErrorInterceptor
812
from cadence._internal.rpc.retry import RetryInterceptor
@@ -11,10 +15,51 @@
1115
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
1216
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
1317
from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub
18+
from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse
19+
from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution
20+
from cadence.api.v1.tasklist_pb2 import TaskList
21+
from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy
1422
from cadence.data_converter import DataConverter, DefaultDataConverter
1523
from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter
1624

1725

26+
@dataclass
27+
class WorkflowRun:
28+
"""Represents a workflow run that can be used to get results."""
29+
execution: WorkflowExecution
30+
client: 'Client'
31+
32+
@property
33+
def workflow_id(self) -> str:
34+
"""Get the workflow ID."""
35+
return self.execution.workflow_id
36+
37+
@property
38+
def run_id(self) -> str:
39+
"""Get the run ID."""
40+
return self.execution.run_id
41+
42+
async def get_result(self, result_type: Optional[type] = None) -> Any: # noqa: ARG002
43+
"""Wait for workflow completion and return result."""
44+
# TODO: Implement workflow result retrieval
45+
# This would involve polling GetWorkflowExecutionHistory until completion
46+
# and extracting the result from the final event
47+
raise NotImplementedError("get_result not yet implemented")
48+
49+
50+
@dataclass
51+
class StartWorkflowOptions:
52+
"""Options for starting a workflow execution."""
53+
workflow_id: Optional[str] = None
54+
task_list: str = ""
55+
execution_start_to_close_timeout: Optional[timedelta] = None
56+
task_start_to_close_timeout: Optional[timedelta] = None
57+
workflow_id_reuse_policy: int = WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE
58+
cron_schedule: Optional[str] = None
59+
memo: Optional[dict[str, Any]] = None
60+
search_attributes: Optional[dict[str, Any]] = None
61+
62+
1863
class ClientOptions(TypedDict, total=False):
1964
domain: str
2065
target: str
@@ -88,6 +133,142 @@ async def __aenter__(self) -> 'Client':
88133
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
89134
await self.close()
90135

136+
async def _build_start_workflow_request(
137+
self,
138+
workflow: Union[str, Callable],
139+
args: tuple[Any, ...],
140+
options: StartWorkflowOptions
141+
) -> StartWorkflowExecutionRequest:
142+
"""Build a StartWorkflowExecutionRequest from parameters."""
143+
# Generate workflow ID if not provided
144+
workflow_id = options.workflow_id or str(uuid.uuid4())
145+
146+
# Validate required fields
147+
if not options.task_list:
148+
raise ValueError("task_list is required")
149+
150+
# Determine workflow type name
151+
if isinstance(workflow, str):
152+
workflow_type_name = workflow
153+
else:
154+
# For callable, use function name or __name__ attribute
155+
workflow_type_name = getattr(workflow, '__name__', str(workflow))
156+
157+
# Encode input arguments
158+
input_payload = None
159+
if args:
160+
try:
161+
input_payload = await self.data_converter.to_data(list(args))
162+
except Exception as e:
163+
raise ValueError(f"Failed to encode workflow arguments: {e}")
164+
165+
# Convert timedelta to protobuf Duration
166+
execution_timeout = None
167+
if options.execution_start_to_close_timeout:
168+
execution_timeout = Duration()
169+
execution_timeout.FromTimedelta(options.execution_start_to_close_timeout)
170+
171+
task_timeout = None
172+
if options.task_start_to_close_timeout:
173+
task_timeout = Duration()
174+
task_timeout.FromTimedelta(options.task_start_to_close_timeout)
175+
176+
# Build the request
177+
request = StartWorkflowExecutionRequest(
178+
domain=self.domain,
179+
workflow_id=workflow_id,
180+
workflow_type=WorkflowType(name=workflow_type_name),
181+
task_list=TaskList(name=options.task_list),
182+
identity=self.identity,
183+
request_id=str(uuid.uuid4())
184+
)
185+
186+
# Set workflow_id_reuse_policy separately to avoid type issues
187+
request.workflow_id_reuse_policy = options.workflow_id_reuse_policy # type: ignore[assignment]
188+
189+
# Set optional fields
190+
if input_payload:
191+
request.input.CopyFrom(input_payload)
192+
if execution_timeout:
193+
request.execution_start_to_close_timeout.CopyFrom(execution_timeout)
194+
if task_timeout:
195+
request.task_start_to_close_timeout.CopyFrom(task_timeout)
196+
if options.cron_schedule:
197+
request.cron_schedule = options.cron_schedule
198+
199+
return request
200+
201+
async def start_workflow(
202+
self,
203+
workflow: Union[str, Callable],
204+
*args,
205+
**options_kwargs
206+
) -> WorkflowExecution:
207+
"""
208+
Start a workflow execution asynchronously.
209+
210+
Args:
211+
workflow: Workflow function or workflow type name string
212+
*args: Arguments to pass to the workflow
213+
**options_kwargs: StartWorkflowOptions as keyword arguments
214+
215+
Returns:
216+
WorkflowExecution with workflow_id and run_id
217+
218+
Raises:
219+
ValueError: If required parameters are missing or invalid
220+
Exception: If the gRPC call fails
221+
"""
222+
# Convert kwargs to StartWorkflowOptions
223+
options = StartWorkflowOptions(**options_kwargs)
224+
225+
# Build the gRPC request
226+
request = await self._build_start_workflow_request(workflow, args, options)
227+
228+
# Execute the gRPC call
229+
try:
230+
response: StartWorkflowExecutionResponse = await self.workflow_stub.StartWorkflowExecution(request)
231+
232+
# Emit metrics if available
233+
if self.metrics_emitter:
234+
# TODO: Add workflow start metrics similar to Go client
235+
pass
236+
237+
execution = WorkflowExecution()
238+
execution.workflow_id = request.workflow_id
239+
execution.run_id = response.run_id
240+
return execution
241+
except Exception as e:
242+
raise Exception(f"Failed to start workflow: {e}") from e
243+
244+
async def execute_workflow(
245+
self,
246+
workflow: Union[str, Callable],
247+
*args,
248+
**options_kwargs
249+
) -> WorkflowRun:
250+
"""
251+
Start a workflow execution and return a handle to get the result.
252+
253+
Args:
254+
workflow: Workflow function or workflow type name string
255+
*args: Arguments to pass to the workflow
256+
**options_kwargs: StartWorkflowOptions as keyword arguments
257+
258+
Returns:
259+
WorkflowRun that can be used to get the workflow result
260+
261+
Raises:
262+
ValueError: If required parameters are missing or invalid
263+
Exception: If the gRPC call fails
264+
"""
265+
execution = await self.start_workflow(workflow, *args, **options_kwargs)
266+
267+
return WorkflowRun(
268+
execution=execution,
269+
client=self
270+
)
271+
91272
def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
92273
if "target" not in options:
93274
raise ValueError("target must be specified")

0 commit comments

Comments
 (0)