|
1 | 1 | import os |
2 | 2 | import socket |
3 | | -from typing import TypedDict, Unpack, Any, cast |
| 3 | +import uuid |
| 4 | +from dataclasses import dataclass |
| 5 | +from datetime import timedelta |
| 6 | +from typing import TypedDict, Unpack, Any, cast, Union, Optional, Callable |
4 | 7 |
|
5 | 8 | from grpc import ChannelCredentials, Compression |
| 9 | +from google.protobuf.duration_pb2 import Duration |
6 | 10 |
|
7 | 11 | from cadence._internal.rpc.error import CadenceErrorInterceptor |
8 | 12 | from cadence._internal.rpc.retry import RetryInterceptor |
|
11 | 15 | from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub |
12 | 16 | from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel |
13 | 17 | from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub |
| 18 | +from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse |
| 19 | +from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution |
| 20 | +from cadence.api.v1.tasklist_pb2 import TaskList |
| 21 | +from cadence.api.v1.workflow_pb2 import WorkflowIdReusePolicy |
14 | 22 | from cadence.data_converter import DataConverter, DefaultDataConverter |
15 | 23 | from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter |
16 | 24 |
|
17 | 25 |
|
| 26 | +@dataclass |
| 27 | +class WorkflowRun: |
| 28 | + """Represents a workflow run that can be used to get results.""" |
| 29 | + execution: WorkflowExecution |
| 30 | + client: 'Client' |
| 31 | + |
| 32 | + @property |
| 33 | + def workflow_id(self) -> str: |
| 34 | + """Get the workflow ID.""" |
| 35 | + return self.execution.workflow_id |
| 36 | + |
| 37 | + @property |
| 38 | + def run_id(self) -> str: |
| 39 | + """Get the run ID.""" |
| 40 | + return self.execution.run_id |
| 41 | + |
| 42 | + async def get_result(self, result_type: Optional[type] = None) -> Any: # noqa: ARG002 |
| 43 | + """Wait for workflow completion and return result.""" |
| 44 | + # TODO: Implement workflow result retrieval |
| 45 | + # This would involve polling GetWorkflowExecutionHistory until completion |
| 46 | + # and extracting the result from the final event |
| 47 | + raise NotImplementedError("get_result not yet implemented") |
| 48 | + |
| 49 | + |
| 50 | +@dataclass |
| 51 | +class StartWorkflowOptions: |
| 52 | + """Options for starting a workflow execution.""" |
| 53 | + workflow_id: Optional[str] = None |
| 54 | + task_list: str = "" |
| 55 | + execution_start_to_close_timeout: Optional[timedelta] = None |
| 56 | + task_start_to_close_timeout: Optional[timedelta] = None |
| 57 | + workflow_id_reuse_policy: int = WorkflowIdReusePolicy.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE |
| 58 | + cron_schedule: Optional[str] = None |
| 59 | + memo: Optional[dict[str, Any]] = None |
| 60 | + search_attributes: Optional[dict[str, Any]] = None |
| 61 | + |
| 62 | + |
18 | 63 | class ClientOptions(TypedDict, total=False): |
19 | 64 | domain: str |
20 | 65 | target: str |
@@ -88,6 +133,142 @@ async def __aenter__(self) -> 'Client': |
88 | 133 | async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: |
89 | 134 | await self.close() |
90 | 135 |
|
| 136 | + async def _build_start_workflow_request( |
| 137 | + self, |
| 138 | + workflow: Union[str, Callable], |
| 139 | + args: tuple[Any, ...], |
| 140 | + options: StartWorkflowOptions |
| 141 | + ) -> StartWorkflowExecutionRequest: |
| 142 | + """Build a StartWorkflowExecutionRequest from parameters.""" |
| 143 | + # Generate workflow ID if not provided |
| 144 | + workflow_id = options.workflow_id or str(uuid.uuid4()) |
| 145 | + |
| 146 | + # Validate required fields |
| 147 | + if not options.task_list: |
| 148 | + raise ValueError("task_list is required") |
| 149 | + |
| 150 | + # Determine workflow type name |
| 151 | + if isinstance(workflow, str): |
| 152 | + workflow_type_name = workflow |
| 153 | + else: |
| 154 | + # For callable, use function name or __name__ attribute |
| 155 | + workflow_type_name = getattr(workflow, '__name__', str(workflow)) |
| 156 | + |
| 157 | + # Encode input arguments |
| 158 | + input_payload = None |
| 159 | + if args: |
| 160 | + try: |
| 161 | + input_payload = await self.data_converter.to_data(list(args)) |
| 162 | + except Exception as e: |
| 163 | + raise ValueError(f"Failed to encode workflow arguments: {e}") |
| 164 | + |
| 165 | + # Convert timedelta to protobuf Duration |
| 166 | + execution_timeout = None |
| 167 | + if options.execution_start_to_close_timeout: |
| 168 | + execution_timeout = Duration() |
| 169 | + execution_timeout.FromTimedelta(options.execution_start_to_close_timeout) |
| 170 | + |
| 171 | + task_timeout = None |
| 172 | + if options.task_start_to_close_timeout: |
| 173 | + task_timeout = Duration() |
| 174 | + task_timeout.FromTimedelta(options.task_start_to_close_timeout) |
| 175 | + |
| 176 | + # Build the request |
| 177 | + request = StartWorkflowExecutionRequest( |
| 178 | + domain=self.domain, |
| 179 | + workflow_id=workflow_id, |
| 180 | + workflow_type=WorkflowType(name=workflow_type_name), |
| 181 | + task_list=TaskList(name=options.task_list), |
| 182 | + identity=self.identity, |
| 183 | + request_id=str(uuid.uuid4()) |
| 184 | + ) |
| 185 | + |
| 186 | + # Set workflow_id_reuse_policy separately to avoid type issues |
| 187 | + request.workflow_id_reuse_policy = options.workflow_id_reuse_policy # type: ignore[assignment] |
| 188 | + |
| 189 | + # Set optional fields |
| 190 | + if input_payload: |
| 191 | + request.input.CopyFrom(input_payload) |
| 192 | + if execution_timeout: |
| 193 | + request.execution_start_to_close_timeout.CopyFrom(execution_timeout) |
| 194 | + if task_timeout: |
| 195 | + request.task_start_to_close_timeout.CopyFrom(task_timeout) |
| 196 | + if options.cron_schedule: |
| 197 | + request.cron_schedule = options.cron_schedule |
| 198 | + |
| 199 | + return request |
| 200 | + |
| 201 | + async def start_workflow( |
| 202 | + self, |
| 203 | + workflow: Union[str, Callable], |
| 204 | + *args, |
| 205 | + **options_kwargs |
| 206 | + ) -> WorkflowExecution: |
| 207 | + """ |
| 208 | + Start a workflow execution asynchronously. |
| 209 | +
|
| 210 | + Args: |
| 211 | + workflow: Workflow function or workflow type name string |
| 212 | + *args: Arguments to pass to the workflow |
| 213 | + **options_kwargs: StartWorkflowOptions as keyword arguments |
| 214 | +
|
| 215 | + Returns: |
| 216 | + WorkflowExecution with workflow_id and run_id |
| 217 | +
|
| 218 | + Raises: |
| 219 | + ValueError: If required parameters are missing or invalid |
| 220 | + Exception: If the gRPC call fails |
| 221 | + """ |
| 222 | + # Convert kwargs to StartWorkflowOptions |
| 223 | + options = StartWorkflowOptions(**options_kwargs) |
| 224 | + |
| 225 | + # Build the gRPC request |
| 226 | + request = await self._build_start_workflow_request(workflow, args, options) |
| 227 | + |
| 228 | + # Execute the gRPC call |
| 229 | + try: |
| 230 | + response: StartWorkflowExecutionResponse = await self.workflow_stub.StartWorkflowExecution(request) |
| 231 | + |
| 232 | + # Emit metrics if available |
| 233 | + if self.metrics_emitter: |
| 234 | + # TODO: Add workflow start metrics similar to Go client |
| 235 | + pass |
| 236 | + |
| 237 | + execution = WorkflowExecution() |
| 238 | + execution.workflow_id = request.workflow_id |
| 239 | + execution.run_id = response.run_id |
| 240 | + return execution |
| 241 | + except Exception as e: |
| 242 | + raise Exception(f"Failed to start workflow: {e}") from e |
| 243 | + |
| 244 | + async def execute_workflow( |
| 245 | + self, |
| 246 | + workflow: Union[str, Callable], |
| 247 | + *args, |
| 248 | + **options_kwargs |
| 249 | + ) -> WorkflowRun: |
| 250 | + """ |
| 251 | + Start a workflow execution and return a handle to get the result. |
| 252 | +
|
| 253 | + Args: |
| 254 | + workflow: Workflow function or workflow type name string |
| 255 | + *args: Arguments to pass to the workflow |
| 256 | + **options_kwargs: StartWorkflowOptions as keyword arguments |
| 257 | +
|
| 258 | + Returns: |
| 259 | + WorkflowRun that can be used to get the workflow result |
| 260 | +
|
| 261 | + Raises: |
| 262 | + ValueError: If required parameters are missing or invalid |
| 263 | + Exception: If the gRPC call fails |
| 264 | + """ |
| 265 | + execution = await self.start_workflow(workflow, *args, **options_kwargs) |
| 266 | + |
| 267 | + return WorkflowRun( |
| 268 | + execution=execution, |
| 269 | + client=self |
| 270 | + ) |
| 271 | + |
91 | 272 | def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: |
92 | 273 | if "target" not in options: |
93 | 274 | raise ValueError("target must be specified") |
|
0 commit comments