From 2745658225b108d5882d20f821caea0900bcf80f Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 9 Apr 2025 08:45:53 -0700 Subject: [PATCH 01/12] Added unit test for generic dataclass --- tests/worker/test_workflow.py | 667 ++++++++++++++++++---------------- 1 file changed, 350 insertions(+), 317 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 67aadd0f..67b34a31 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -181,7 +181,7 @@ async def test_workflow_multi_param(client: Client): # This test is mostly just here to confirm MyPy type checks the multi-param # overload approach properly async with new_worker( - client, MultiParamWorkflow, activities=[multi_param_activity] + client, MultiParamWorkflow, activities=[multi_param_activity] ) as worker: result = await client.execute_workflow( MultiParamWorkflow.run, @@ -267,7 +267,7 @@ def get_history_info(self) -> HistoryInfo: async def test_workflow_history_info( - client: Client, env: WorkflowEnvironment, continue_as_new_suggest_history_count: int + client: Client, env: WorkflowEnvironment, continue_as_new_suggest_history_count: int ): if env.supports_time_skipping: pytest.skip("Java test server does not support should continue as new") @@ -589,7 +589,7 @@ def new_handler(name: str, *args: Any) -> str: async def test_workflow_signal_qnd_query_handlers_old_dynamic_style(client: Client): async with new_worker( - client, SignalAndQueryHandlersOldDynamicStyleWorkflow + client, SignalAndQueryHandlersOldDynamicStyleWorkflow ) as worker: handle = await client.start_workflow( SignalAndQueryHandlersOldDynamicStyleWorkflow.run, @@ -654,9 +654,9 @@ async def test_workflow_bad_signal_param(client: Client): await handle.signal("some_signal", 123) await handle.signal("some_signal", BadSignalParam(some_str="finish")) assert [ - BadSignalParam(some_str="good"), - BadSignalParam(some_str="finish"), - ] == await handle.result() + BadSignalParam(some_str="good"), + BadSignalParam(some_str="finish"), + ] == await handle.result() @workflow.defn @@ -779,7 +779,7 @@ async def run(self, name: str) -> str: async def test_workflow_simple_activity(client: Client): async with new_worker( - client, SimpleActivityWorkflow, activities=[say_hello] + client, SimpleActivityWorkflow, activities=[say_hello] ) as worker: result = await client.execute_workflow( SimpleActivityWorkflow.run, @@ -801,7 +801,7 @@ async def run(self, name: str) -> str: async def test_workflow_simple_local_activity(client: Client): async with new_worker( - client, SimpleLocalActivityWorkflow, activities=[say_hello] + client, SimpleLocalActivityWorkflow, activities=[say_hello] ) as worker: result = await client.execute_workflow( SimpleLocalActivityWorkflow.run, @@ -903,7 +903,7 @@ async def test_workflow_cancel_activity(client: Client, local: bool): activity_inst = ActivityWaitCancelNotify() async with new_worker( - client, CancelActivityWorkflow, activities=[activity_inst.wait_cancel] + client, CancelActivityWorkflow, activities=[activity_inst.wait_cancel] ) as worker: # Try cancel - confirm error and activity was sent the cancel handle = await client.start_workflow( @@ -1070,7 +1070,7 @@ def started(self) -> bool: @pytest.mark.parametrize("activity", [True, False]) async def test_workflow_uncaught_cancel(client: Client, activity: bool): async with new_worker( - client, UncaughtCancelWorkflow, activities=[wait_forever] + client, UncaughtCancelWorkflow, activities=[wait_forever] ) as worker: # Start workflow waiting on activity 
or child workflow, cancel it, and # confirm the workflow is shown as cancelled @@ -1231,7 +1231,7 @@ async def run(self, args: SignalExternalWorkflowArgs) -> None: async def test_workflow_signal_external(client: Client): async with new_worker( - client, SignalExternalWorkflow, ReturnSignalWorkflow + client, SignalExternalWorkflow, ReturnSignalWorkflow ) as worker: # Start return signal, then signal and check that it got signalled return_signal_handle = await client.start_workflow( @@ -1305,7 +1305,7 @@ async def child(id: str): async def test_workflow_cancel_multi(client: Client): async with new_worker( - client, MultiCancelWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, MultiCancelWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: results = await client.execute_workflow( MultiCancelWorkflow.run, @@ -1400,7 +1400,7 @@ async def wait_and_swallow(self, aw: Awaitable) -> None: async def test_workflow_cancel_unsent(client: Client): workflow_id = f"workflow-{uuid.uuid4()}" async with new_worker( - client, CancelUnsentWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, CancelUnsentWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: await client.execute_workflow( CancelUnsentWorkflow.run, @@ -1419,14 +1419,14 @@ async def test_workflow_cancel_unsent(client: Client): # No activities or children scheduled assert event.event_type is not EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED assert ( - event.event_type - is not EventType.EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED + event.event_type + is not EventType.EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED ) # Make sure timer is just our 0.01 timer if event.event_type is EventType.EVENT_TYPE_TIMER_STARTED: assert ( - event.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() - == 10 + event.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() + == 10 ) found_timer = True assert found_timer @@ -1445,7 +1445,7 @@ async def run(self) -> None: async def test_workflow_activity_timeout(client: Client): async with new_worker( - client, ActivityTimeoutWorkflow, activities=[wait_cancel] + client, ActivityTimeoutWorkflow, activities=[wait_cancel] ) as worker: with pytest.raises(WorkflowFailureError) as err: await client.execute_workflow( @@ -1522,7 +1522,7 @@ def create_instance(self, det: WorkflowInstanceDetails) -> WorkflowInstance: class CustomWorkflowInstance(WorkflowInstance): def __init__( - self, runner: CustomWorkflowRunner, unsandboxed: WorkflowInstance + self, runner: CustomWorkflowRunner, unsandboxed: WorkflowInstance ) -> None: super().__init__() self._runner = runner @@ -1546,13 +1546,13 @@ async def test_workflow_with_custom_runner(client: Client): assert result == "Hello, Temporal!" 
# Confirm first activation and last non-eviction-reply completion assert ( - runner._pairs[0][0].jobs[0].initialize_workflow.workflow_type == "HelloWorkflow" + runner._pairs[0][0].jobs[0].initialize_workflow.workflow_type == "HelloWorkflow" ) assert ( - runner._pairs[-2][-1] - .successful.commands[0] - .complete_workflow_execution.result.data - == b'"Hello, Temporal!"' + runner._pairs[-2][-1] + .successful.commands[0] + .complete_workflow_execution.result.data + == b'"Hello, Temporal!"' ) @@ -1604,7 +1604,7 @@ async def test_workflow_continue_as_new(client: Client, env: WorkflowEnvironment def search_attributes_to_serializable( - attrs: Union[SearchAttributes, TypedSearchAttributes], + attrs: Union[SearchAttributes, TypedSearchAttributes], ) -> Mapping[str, Any]: if isinstance(attrs, TypedSearchAttributes): return { @@ -1785,7 +1785,7 @@ async def describe_attributes_untyped(handle: WorkflowHandle) -> SearchAttribute } async def describe_attributes_typed( - handle: WorkflowHandle, + handle: WorkflowHandle, ) -> TypedSearchAttributes: # Remove any not our prefix attrs = (await handle.describe()).typed_search_attributes @@ -1940,13 +1940,13 @@ def find_log(self, starts_with: str) -> Optional[logging.LogRecord]: async def test_workflow_logging(client: Client, env: WorkflowEnvironment): workflow.logger.full_workflow_info_on_extra = True with LogCapturer().logs_captured( - workflow.logger.base_logger, activity.logger.base_logger + workflow.logger.base_logger, activity.logger.base_logger ) as capturer: # Log two signals and kill worker before completing. Need to disable # workflow cache since we restart the worker and don't want to pay the # sticky queue penalty. async with new_worker( - client, LoggingWorkflow, max_cached_workflows=0 + client, LoggingWorkflow, max_cached_workflows=0 ) as worker: handle = await client.start_workflow( LoggingWorkflow.run, @@ -1973,30 +1973,30 @@ async def test_workflow_logging(client: Client, env: WorkflowEnvironment): # Also make sure it has some workflow info and correct funcName record = capturer.find_log("Signal: signal 1") assert ( - record - and record.__dict__["temporal_workflow"]["workflow_type"] - == "LoggingWorkflow" - and record.funcName == "my_signal" + record + and record.__dict__["temporal_workflow"]["workflow_type"] + == "LoggingWorkflow" + and record.funcName == "my_signal" ) # Since we enabled full info, make sure it's there assert isinstance(record.__dict__["workflow_info"], workflow.Info) # Check the log emitted by the update execution. 
record = capturer.find_log("Update: update 1") assert ( - record - and record.__dict__["temporal_workflow"]["update_id"] == "update-1" - and record.__dict__["temporal_workflow"]["update_name"] == "my_update" - and "'update_id': 'update-1'" in record.message - and "'update_name': 'my_update'" in record.message + record + and record.__dict__["temporal_workflow"]["update_id"] == "update-1" + and record.__dict__["temporal_workflow"]["update_name"] == "my_update" + and "'update_id': 'update-1'" in record.message + and "'update_name': 'my_update'" in record.message ) # Clear queue and start a new one with more signals capturer.log_queue.queue.clear() async with new_worker( - client, - LoggingWorkflow, - task_queue=worker.task_queue, - max_cached_workflows=0, + client, + LoggingWorkflow, + task_queue=worker.task_queue, + max_cached_workflows=0, ) as worker: # Send signals and updates await handle.signal(LoggingWorkflow.my_signal, "signal 3") @@ -2043,10 +2043,10 @@ async def run(self) -> None: async def test_workflow_logging_task_fail(client: Client): with LogCapturer().logs_captured( - activity.logger.base_logger, temporalio.worker._workflow_instance.logger + activity.logger.base_logger, temporalio.worker._workflow_instance.logger ) as capturer: async with new_worker( - client, TaskFailOnceWorkflow, activities=[task_fail_once_activity] + client, TaskFailOnceWorkflow, activities=[task_fail_once_activity] ) as worker: await client.execute_workflow( TaskFailOnceWorkflow.run, @@ -2058,16 +2058,16 @@ async def test_workflow_logging_task_fail(client: Client): assert wf_task_record assert "Intentional workflow task failure" in wf_task_record.message assert ( - getattr(wf_task_record, "temporal_workflow")["workflow_type"] - == "TaskFailOnceWorkflow" + getattr(wf_task_record, "temporal_workflow")["workflow_type"] + == "TaskFailOnceWorkflow" ) act_task_record = capturer.find_log("Completing activity as failed") assert act_task_record assert "Intentional activity task failure" in act_task_record.message assert ( - getattr(act_task_record, "temporal_activity")["activity_type"] - == "task_fail_once_activity" + getattr(act_task_record, "temporal_activity")["activity_type"] + == "task_fail_once_activity" ) @@ -2102,7 +2102,7 @@ def status(self) -> str: async def test_workflow_stack_trace(client: Client): async with new_worker( - client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: handle = await client.start_workflow( StackTraceWorkflow.run, @@ -2156,7 +2156,7 @@ async def test_workflow_enhanced_stack_trace(client: Client): """ async with new_worker( - client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: handle = await client.start_workflow( StackTraceWorkflow.run, @@ -2195,9 +2195,9 @@ async def status() -> str: async def test_workflow_external_enhanced_stack_trace(client: Client): async with new_worker( - client, - ExternalStackTraceWorkflow, - activities=[external_wait_cancel], + client, + ExternalStackTraceWorkflow, + activities=[external_wait_cancel], ) as worker: handle = await client.start_workflow( ExternalStackTraceWorkflow.run, @@ -2227,8 +2227,8 @@ async def status() -> str: assert fn is not None assert ( - 'status[0] = "waiting" # external coroutine test' - in trace.sources[fn].content + 'status[0] = "waiting" # external coroutine test' + in trace.sources[fn].content ) assert 
trace.sdk.version == __version__ @@ -2243,12 +2243,31 @@ def assert_expected(self) -> None: assert self.field1 == "some value" +T = typing.TypeVar('T') + + +@dataclass +class MyGenericDataClass(typing.Generic[T]): + field1: T + + def assert_expected(self) -> None: + # Part of the assertion is that this is the right type, which is + # confirmed just by calling the method. We also check the field. + assert str(self.field1) == "some value2" + + @activity.defn async def data_class_typed_activity(param: MyDataClass) -> MyDataClass: param.assert_expected() return param +@activity.defn +async def generic_data_class_typed_activity(param: MyGenericDataClass[str]) -> MyGenericDataClass[str]: + param.assert_expected() + return param + + @runtime_checkable @workflow.defn(name="DataClassTypedWorkflow") class DataClassTypedWorkflowProto(Protocol): @@ -2306,6 +2325,19 @@ async def run(self, param: MyDataClass) -> MyDataClass: start_to_close_timeout=timedelta(seconds=30), ) param.assert_expected() + param = await workflow.execute_activity( + generic_data_class_typed_activity, + param, + start_to_close_timeout=timedelta(seconds=30), + ) + param.assert_expected() + param = await workflow.execute_local_activity( + generic_data_class_typed_activity, + param, + start_to_close_timeout=timedelta(seconds=30), + ) + param.assert_expected() + child_handle = await workflow.start_child_workflow( DataClassTypedWorkflow.run, param, @@ -2348,7 +2380,8 @@ async def test_workflow_dataclass_typed(client: Client, env: WorkflowEnvironment "Java test server: https://github.com/temporalio/sdk-core/issues/390" ) async with new_worker( - client, DataClassTypedWorkflow, activities=[data_class_typed_activity] + client, DataClassTypedWorkflow, + activities=[data_class_typed_activity, generic_data_class_typed_activity] ) as worker: val = MyDataClass(field1="some value") handle = await client.start_workflow( @@ -2373,7 +2406,7 @@ async def test_workflow_separate_protocol(client: Client): # This test is to confirm that protocols can be used as "interfaces" for # when the workflow impl is absent async with new_worker( - client, DataClassTypedWorkflow, activities=[data_class_typed_activity] + client, DataClassTypedWorkflow, activities=[data_class_typed_activity] ) as worker: assert isinstance(DataClassTypedWorkflow(), DataClassTypedWorkflowProto) val = MyDataClass(field1="some value") @@ -2395,7 +2428,7 @@ async def test_workflow_separate_abstract(client: Client): # This test is to confirm that abstract classes can be used as "interfaces" # for when the workflow impl is absent async with new_worker( - client, DataClassTypedWorkflow, activities=[data_class_typed_activity] + client, DataClassTypedWorkflow, activities=[data_class_typed_activity] ) as worker: assert issubclass(DataClassTypedWorkflow, DataClassTypedWorkflowAbstract) val = MyDataClass(field1="some value") @@ -2455,7 +2488,7 @@ async def test_workflow_child_already_started(client: Client, env: WorkflowEnvir "Java test server: https://github.com/temporalio/sdk-java/issues/1220" ) async with new_worker( - client, ChildAlreadyStartedWorkflow, LongSleepWorkflow + client, ChildAlreadyStartedWorkflow, LongSleepWorkflow ) as worker: with pytest.raises(WorkflowFailureError) as err: await client.execute_workflow( @@ -2503,10 +2536,10 @@ async def run(self) -> None: async def test_workflow_typed_config(client: Client): async with new_worker( - client, - TypedConfigWorkflow, - FailUntilAttemptWorkflow, - activities=[fail_until_attempt_activity], + client, + TypedConfigWorkflow, + 
FailUntilAttemptWorkflow, + activities=[fail_until_attempt_activity], ) as worker: await client.execute_workflow( TypedConfigWorkflow.run, @@ -2549,7 +2582,7 @@ async def run(self) -> None: async def test_workflow_local_activity_backoff(client: Client): workflow_id = f"workflow-{uuid.uuid4()}" async with new_worker( - client, LocalActivityBackoffWorkflow, activities=[fail_until_attempt_activity] + client, LocalActivityBackoffWorkflow, activities=[fail_until_attempt_activity] ) as worker: await client.execute_workflow( LocalActivityBackoffWorkflow.run, @@ -2592,7 +2625,7 @@ async def run(self) -> None: async def test_workflow_deadlock(client: Client): # Disable safe eviction so the worker can complete async with new_worker( - client, DeadlockedWorkflow, disable_safe_workflow_eviction=True + client, DeadlockedWorkflow, disable_safe_workflow_eviction=True ) as worker: if worker._workflow_worker: worker._workflow_worker._deadlock_timeout_seconds = 1 @@ -2742,7 +2775,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Run a simple pre-patch workflow. Need to disable workflow cache since we # restart the worker and don't want to pay the sticky queue penalty. async with new_worker( - client, PrePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, PrePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): pre_patch_handle = await execute() assert "pre-patch" == await query_result(pre_patch_handle) @@ -2750,7 +2783,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Confirm patched workflow gives old result for pre-patched but new result # for patched async with new_worker( - client, PatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, PatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): patch_handle = await execute() assert "post-patch" == await query_result(patch_handle) @@ -2758,7 +2791,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Confirm what works during deprecated async with new_worker( - client, DeprecatePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, DeprecatePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): deprecate_patch_handle = await execute() assert "post-patch" == await query_result(deprecate_patch_handle) @@ -2766,7 +2799,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Confirm what works when deprecation gone async with new_worker( - client, PostPatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, PostPatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): post_patch_handle = await execute() assert "post-patch" == await query_result(post_patch_handle) @@ -2819,10 +2852,10 @@ async def test_workflow_patch_memoized(client: Client): # to pay the sticky queue penalty. 
task_queue = f"tq-{uuid.uuid4()}" async with Worker( - client, - task_queue=task_queue, - workflows=[PatchMemoizedWorkflowUnpatched], - max_cached_workflows=0, + client, + task_queue=task_queue, + workflows=[PatchMemoizedWorkflowUnpatched], + max_cached_workflows=0, ): pre_patch_handle = await client.start_workflow( PatchMemoizedWorkflowUnpatched.run, @@ -2840,10 +2873,10 @@ async def waiting_signal() -> bool: # Now start the worker again, but this time with a patched workflow async with Worker( - client, - task_queue=task_queue, - workflows=[PatchMemoizedWorkflowPatched], - max_cached_workflows=0, + client, + task_queue=task_queue, + workflows=[PatchMemoizedWorkflowPatched], + max_cached_workflows=0, ): # Start a new workflow post patch post_patch_handle = await client.start_workflow( @@ -2859,10 +2892,10 @@ async def waiting_signal() -> bool: # Confirm expected values assert ["some-value"] == await pre_patch_handle.result() assert [ - "pre-patch", - "some-value", - "post-patch", - ] == await post_patch_handle.result() + "pre-patch", + "some-value", + "post-patch", + ] == await post_patch_handle.result() @workflow.defn @@ -2882,7 +2915,7 @@ def result(self) -> str: async def test_workflow_uuid(client: Client): task_queue = str(uuid.uuid4()) async with new_worker( - client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 ): # Get two handle UUID results. Need to disable workflow cache since we # restart the worker and don't want to pay the sticky queue penalty. @@ -2908,7 +2941,7 @@ async def test_workflow_uuid(client: Client): # Now confirm those results are the same even on a new worker async with new_worker( - client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 ): assert handle1_query_result == await handle1.query(UUIDWorkflow.result) assert handle2_query_result == await handle2.query(UUIDWorkflow.result) @@ -2937,7 +2970,7 @@ async def run(self, to_add: MyDataClass) -> MyDataClass: async def test_workflow_activity_callable_class(client: Client): activity_instance = CallableClassActivity("in worker") async with new_worker( - client, ActivityCallableClassWorkflow, activities=[activity_instance] + client, ActivityCallableClassWorkflow, activities=[activity_instance] ) as worker: result = await client.execute_workflow( ActivityCallableClassWorkflow.run, @@ -2987,9 +3020,9 @@ async def run(self, to_add: MyDataClass) -> MyDataClass: async def test_workflow_activity_method(client: Client): activity_instance = MethodActivity("in worker") async with new_worker( - client, - ActivityMethodWorkflow, - activities=[activity_instance.add, activity_instance.add_multi], + client, + ActivityMethodWorkflow, + activities=[activity_instance.add, activity_instance.add_multi], ) as worker: result = await client.execute_workflow( ActivityMethodWorkflow.run, @@ -3030,8 +3063,8 @@ def waiting(self) -> bool: async def test_workflow_wait_condition_timeout(client: Client): async with new_worker( - client, - WaitConditionTimeoutWorkflow, + client, + WaitConditionTimeoutWorkflow, ) as worker: handle = await client.start_workflow( WaitConditionTimeoutWorkflow.run, @@ -3063,8 +3096,8 @@ def some_query(self) -> str: async def test_workflow_query_rpc_timeout(client: Client): # Run workflow under worker and confirm query works async with new_worker( - client, - HelloWorkflowWithQuery, + client, + HelloWorkflowWithQuery, ) as worker: handle = await 
client.start_workflow( HelloWorkflowWithQuery.run, @@ -3081,9 +3114,9 @@ async def test_workflow_query_rpc_timeout(client: Client): HelloWorkflowWithQuery.some_query, rpc_timeout=timedelta(seconds=1) ) assert ( - err.value.status == RPCStatusCode.CANCELLED - and "timeout" in str(err.value).lower() - ) or err.value.status == RPCStatusCode.DEADLINE_EXCEEDED + err.value.status == RPCStatusCode.CANCELLED + and "timeout" in str(err.value).lower() + ) or err.value.status == RPCStatusCode.DEADLINE_EXCEEDED @dataclass @@ -3222,7 +3255,7 @@ def cancel_timer(self) -> None: async def test_workflow_cancel_signal_and_timer_fired_in_same_task( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): # This test only works when we support time skipping if not env.supports_time_skipping: @@ -3236,7 +3269,7 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task( # Start worker for 30 mins. Need to disable workflow cache since we # restart the worker and don't want to pay the sticky queue penalty. async with new_worker( - client, CancelSignalAndTimerFiredInSameTaskWorkflow, max_cached_workflows=0 + client, CancelSignalAndTimerFiredInSameTaskWorkflow, max_cached_workflows=0 ) as worker: task_queue = worker.task_queue handle = await client.start_workflow( @@ -3257,10 +3290,10 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task( # Start worker again and wait for workflow completion async with new_worker( - client, - CancelSignalAndTimerFiredInSameTaskWorkflow, - task_queue=task_queue, - max_cached_workflows=0, + client, + CancelSignalAndTimerFiredInSameTaskWorkflow, + task_queue=task_queue, + max_cached_workflows=0, ): # This used to not complete because a signal cancelling the timer was # not respected by the timer fire @@ -3292,7 +3325,7 @@ async def run(self) -> NoReturn: class CustomFailureConverter(DefaultFailureConverterWithEncodedAttributes): # We'll override from failure to convert back to our type def from_failure( - self, failure: Failure, payload_converter: PayloadConverter + self, failure: Failure, payload_converter: PayloadConverter ) -> BaseException: err = super().from_failure(failure, payload_converter) if isinstance(err, ApplicationError) and err.type == "MyCustomError": @@ -3314,7 +3347,7 @@ async def test_workflow_custom_failure_converter(client: Client): # Run workflow and confirm error async with new_worker( - client, CustomErrorWorkflow, activities=[custom_error_activity] + client, CustomErrorWorkflow, activities=[custom_error_activity] ) as worker: handle = await client.start_workflow( CustomErrorWorkflow.run, @@ -3361,11 +3394,11 @@ class OptionalParam: class OptionalParamWorkflow: @workflow.run async def run( - self, some_param: Optional[OptionalParam] = OptionalParam(some_string="default") + self, some_param: Optional[OptionalParam] = OptionalParam(some_string="default") ) -> Optional[OptionalParam]: assert some_param is None or ( - isinstance(some_param, OptionalParam) - and some_param.some_string in ["default", "foo"] + isinstance(some_param, OptionalParam) + and some_param.some_string in ["default", "foo"] ) return some_param @@ -3404,20 +3437,20 @@ class ExceptionRaisingPayloadConverter(DefaultPayloadConverter): def to_payloads(self, values: Sequence[Any]) -> List[Payload]: if any( - value == ExceptionRaisingPayloadConverter.bad_outbound_str - for value in values + value == ExceptionRaisingPayloadConverter.bad_outbound_str + for value in values ): raise ApplicationError("Intentional outbound converter failure") 
return super().to_payloads(values) def from_payloads( - self, payloads: Sequence[Payload], type_hints: Optional[List] = None + self, payloads: Sequence[Payload], type_hints: Optional[List] = None ) -> List[Any]: # Check if any payloads contain the bad data for payload in payloads: if ( - ExceptionRaisingPayloadConverter.bad_inbound_str.encode() - in payload.data + ExceptionRaisingPayloadConverter.bad_inbound_str.encode() + in payload.data ): raise ApplicationError("Intentional inbound converter failure") return super().from_payloads(payloads, type_hints) @@ -3533,7 +3566,7 @@ def some_query(self) -> ManualResultType: async def test_manual_result_type(client: Client): async with new_worker( - client, ManualResultTypeWorkflow, activities=[manual_result_type_activity] + client, ManualResultTypeWorkflow, activities=[manual_result_type_activity] ) as worker: # Workflow without result type and with res1 = await client.execute_workflow( @@ -3627,11 +3660,11 @@ async def test_cache_eviction_tear_down(client: Client): # chooses, but now we expect eviction to properly tear down tasks and # therefore we cancel them async with new_worker( - client, - CacheEvictionTearDownWorkflow, - WaitForeverWorkflow, - activities=[wait_forever_activity], - max_cached_workflows=0, + client, + CacheEvictionTearDownWorkflow, + WaitForeverWorkflow, + activities=[wait_forever_activity], + max_cached_workflows=0, ) as worker: # Put a hook to catch unraisable exceptions old_hook = sys.unraisablehook @@ -3697,7 +3730,7 @@ async def test_workflow_eviction_exception(client: Client): # Run workflow with no cache (forces eviction every step) async with new_worker( - client, EvictionCaptureExceptionWorkflow, max_cached_workflows=0 + client, EvictionCaptureExceptionWorkflow, max_cached_workflows=0 ) as worker: await client.execute_workflow( EvictionCaptureExceptionWorkflow.run, @@ -3709,8 +3742,8 @@ async def test_workflow_eviction_exception(client: Client): assert len(captured_eviction_exceptions) == 1 assert captured_eviction_exceptions[0].is_replaying assert ( - type(captured_eviction_exceptions[0].exception).__name__ - == "_WorkflowBeingEvictedError" + type(captured_eviction_exceptions[0].exception).__name__ + == "_WorkflowBeingEvictedError" ) @@ -3809,11 +3842,11 @@ async def assert_bad_query(bad_thing: str) -> None: # typing.Self only in 3.11+ if sys.version_info >= (3, 11): - @dataclass class AnnotatedWithSelfParam: some_str: str + @workflow.defn class WorkflowAnnotatedWithSelf: @workflow.run @@ -3821,6 +3854,7 @@ async def run(self: typing.Self, some_arg: AnnotatedWithSelfParam) -> str: assert isinstance(some_arg, AnnotatedWithSelfParam) return some_arg.some_str + async def test_workflow_annotated_with_self(client: Client): async with new_worker(client, WorkflowAnnotatedWithSelf) as worker: assert "foo" == await client.execute_workflow( @@ -3861,7 +3895,7 @@ async def test_workflow_custom_metrics(client: Client): # Run worker with default runtime which is noop meter just to confirm it # doesn't fail async with new_worker( - client, CustomMetricsWorkflow, activities=[custom_metrics_activity] + client, CustomMetricsWorkflow, activities=[custom_metrics_activity] ) as worker: await client.execute_workflow( CustomMetricsWorkflow.run, @@ -3890,7 +3924,7 @@ async def test_workflow_custom_metrics(client: Client): ) async with new_worker( - client, CustomMetricsWorkflow, activities=[custom_metrics_activity] + client, CustomMetricsWorkflow, activities=[custom_metrics_activity] ) as worker: # Record a gauge at runtime level gauge = 
runtime.metric_meter.with_additional_attributes( @@ -3912,7 +3946,7 @@ async def test_workflow_custom_metrics(client: Client): # Intentionally naive metric checker def matches_metric_line( - line: str, name: str, at_least_labels: Mapping[str, str], value: int + line: str, name: str, at_least_labels: Mapping[str, str], value: int ) -> bool: # Must have metric name if not line.startswith(name + "{"): @@ -3924,7 +3958,7 @@ def matches_metric_line( return line.endswith(f" {value}") def assert_metric_exists( - name: str, at_least_labels: Mapping[str, str], value: int + name: str, at_least_labels: Mapping[str, str], value: int ) -> None: assert any( matches_metric_line(line, name, at_least_labels, value) @@ -4045,7 +4079,7 @@ async def test_workflow_buffered_metrics(client: Client): runtime=runtime, ) async with new_worker( - client, CustomMetricsWorkflow, activities=[custom_metrics_activity] + client, CustomMetricsWorkflow, activities=[custom_metrics_activity] ) as worker: await client.execute_workflow( CustomMetricsWorkflow.run, @@ -4268,7 +4302,7 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir "Java test server: https://github.com/temporalio/sdk-java/issues/1903" ) async with new_worker( - client, UpdateHandlersWorkflow, activities=[say_hello] + client, UpdateHandlersWorkflow, activities=[say_hello] ) as worker: wf_id = f"update-handlers-workflow-{uuid.uuid4()}" handle = await client.start_workflow( @@ -4309,7 +4343,7 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir async def test_workflow_update_handlers_unhappy( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4360,8 +4394,8 @@ async def test_workflow_update_handlers_unhappy( await handle.execute_update("last_event", args=[121, "badarg"]) assert isinstance(err.value.cause, ApplicationError) assert ( - "last_event_validator() takes 2 positional arguments but 3 were given" - in err.value.cause.message + "last_event_validator() takes 2 positional arguments but 3 were given" + in err.value.cause.message ) # Un-deserializeable nonsense @@ -4394,7 +4428,7 @@ async def test_workflow_update_task_fails(client: Client, env: WorkflowEnvironme ) # Need to not sandbox so behavior can change based on globals async with new_worker( - client, UpdateHandlersWorkflow, workflow_runner=UnsandboxedWorkflowRunner() + client, UpdateHandlersWorkflow, workflow_runner=UnsandboxedWorkflowRunner() ) as worker: handle = await client.start_workflow( UpdateHandlersWorkflow.run, @@ -4432,7 +4466,7 @@ async def update(self) -> None: async def test_workflow_update_respects_first_execution_run_id( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4443,7 +4477,6 @@ async def test_workflow_update_respects_first_execution_run_id( # r1. 
workflow_id = f"update-respects-first-execution-run-id-{uuid.uuid4()}" async with new_worker(client, UpdateRespectsFirstExecutionRunIdWorkflow) as worker: - async def start_workflow(workflow_id: str) -> WorkflowHandle: return await client.start_workflow( UpdateRespectsFirstExecutionRunIdWorkflow.run, @@ -4488,7 +4521,7 @@ def got_update(self) -> str: async def test_workflow_update_before_worker_start( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4526,10 +4559,10 @@ async def test_workflow_update_before_worker_start( # Start no-cache worker on the task queue async with new_worker( - client, - ImmediatelyCompleteUpdateAndWorkflow, - task_queue=task_queue, - max_cached_workflows=0, + client, + ImmediatelyCompleteUpdateAndWorkflow, + task_queue=task_queue, + max_cached_workflows=0, ): # Confirm workflow completed as expected assert "workflow-done" == await handle.result() @@ -4562,7 +4595,7 @@ async def signal(self) -> None: async def test_workflow_update_separate_handle( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4610,7 +4643,7 @@ async def do_update(self, sleep: float) -> None: async def test_workflow_update_timeout_or_cancel( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4752,7 +4785,7 @@ async def test_workflow_timeout_support(client: Client, approach: str): if sys.version_info < (3, 11): pytest.skip("Timeout only in >= 3.11") async with new_worker( - client, TimeoutSupportWorkflow, activities=[wait_cancel] + client, TimeoutSupportWorkflow, activities=[wait_cancel] ) as worker: # Run and confirm activity gets cancelled handle = await client.start_workflow( @@ -4770,8 +4803,8 @@ async def test_workflow_timeout_support(client: Client, approach: str): async for e in handle.fetch_history_events(): if e.HasField("timer_started_event_attributes"): assert ( - e.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() - == 200 + e.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() + == 200 ) found_timer = True break @@ -4801,18 +4834,18 @@ async def finish(self): async def test_workflow_current_build_id_appropriately_set( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip("Java test server does not support worker versioning") task_queue = str(uuid.uuid4()) async with new_worker( - client, - BuildIDInfoWorkflow, - activities=[say_hello], - build_id="1.0", - task_queue=task_queue, + client, + BuildIDInfoWorkflow, + activities=[say_hello], + build_id="1.0", + task_queue=task_queue, ) as worker: handle = await client.start_workflow( BuildIDInfoWorkflow.run, @@ -4831,11 +4864,11 @@ async def test_workflow_current_build_id_appropriately_set( ) async with new_worker( - client, - BuildIDInfoWorkflow, - activities=[say_hello], - build_id="1.1", - task_queue=task_queue, + client, + BuildIDInfoWorkflow, + activities=[say_hello], + build_id="1.1", + task_queue=task_queue, ) as worker: bid = await handle.query(BuildIDInfoWorkflow.get_build_id) assert bid == "1.0" @@ -4907,7 +4940,7 @@ async def run(self, scenario: FailureTypesScenario) -> None: async def test_workflow_failure_types_configured( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ 
-4916,14 +4949,14 @@ async def test_workflow_failure_types_configured( # Asserter for a single scenario async def assert_scenario( - workflow: Type[FailureTypesWorkflowBase], - *, - expect_task_fail: bool, - fail_message_contains: str, - worker_level_failure_exception_type: Optional[Type[Exception]] = None, - workflow_scenario: Optional[FailureTypesScenario] = None, - signal_scenario: Optional[FailureTypesScenario] = None, - update_scenario: Optional[FailureTypesScenario] = None, + workflow: Type[FailureTypesWorkflowBase], + *, + expect_task_fail: bool, + fail_message_contains: str, + worker_level_failure_exception_type: Optional[Type[Exception]] = None, + workflow_scenario: Optional[FailureTypesScenario] = None, + signal_scenario: Optional[FailureTypesScenario] = None, + update_scenario: Optional[FailureTypesScenario] = None, ) -> None: logging.debug( "Asserting scenario %s", @@ -4938,12 +4971,12 @@ async def assert_scenario( }, ) async with new_worker( - client, - workflow, - max_cached_workflows=0, - workflow_failure_exception_types=[worker_level_failure_exception_type] - if worker_level_failure_exception_type - else [], + client, + workflow, + max_cached_workflows=0, + workflow_failure_exception_types=[worker_level_failure_exception_type] + if worker_level_failure_exception_type + else [], ) as worker: # Start workflow handle = await client.start_workflow( @@ -4969,9 +5002,9 @@ async def assert_scenario( async def has_expected_task_fail() -> bool: async for e in handle.fetch_history_events(): if ( - e.HasField("workflow_task_failed_event_attributes") - and fail_message_contains - in e.workflow_task_failed_event_attributes.failure.message + e.HasField("workflow_task_failed_event_attributes") + and fail_message_contains + in e.workflow_task_failed_event_attributes.failure.message ): return True return False @@ -4984,9 +5017,9 @@ async def has_expected_task_fail() -> bool: # Update does not throw on non-determinism, the workflow # does instead if ( - update_handle - and update_scenario - == FailureTypesScenario.THROW_CUSTOM_EXCEPTION + update_handle + and update_scenario + == FailureTypesScenario.THROW_CUSTOM_EXCEPTION ): await update_handle.result() else: @@ -4996,11 +5029,11 @@ async def has_expected_task_fail() -> bool: # Run a scenario async def run_scenario( - workflow: Type[FailureTypesWorkflowBase], - scenario: FailureTypesScenario, - *, - expect_task_fail: bool = False, - worker_level_failure_exception_type: Optional[Type[Exception]] = None, + workflow: Type[FailureTypesWorkflowBase], + scenario: FailureTypesScenario, + *, + expect_task_fail: bool = False, + worker_level_failure_exception_type: Optional[Type[Exception]] = None, ) -> None: # Run for workflow, signal, and update fail_message_contains = ( @@ -5141,11 +5174,11 @@ async def any_task_completed(handle: WorkflowHandle) -> bool: # Now start the worker on the first env async with Worker( - client, - task_queue=task_queue, - workflows=[TickingWorkflow], - max_cached_workflows=0, - max_concurrent_workflow_task_polls=1, + client, + task_queue=task_queue, + workflows=[TickingWorkflow], + max_cached_workflows=0, + max_concurrent_workflow_task_polls=1, ) as worker: # Confirm the first ticking workflow has completed a task but not # the second @@ -5192,10 +5225,10 @@ async def run(self) -> List[str]: async def test_workflow_as_completed_utility(client: Client): # Disable cache to force replay async with new_worker( - client, - AsCompletedWorkflow, - activities=[return_name_activity], - max_cached_workflows=0, + client, + 
AsCompletedWorkflow, + activities=[return_name_activity], + max_cached_workflows=0, ) as worker: # This would fail if we used asyncio.as_completed in the workflow result = await client.execute_workflow( @@ -5234,7 +5267,7 @@ async def new_activity_name(index: int) -> str: async def test_workflow_wait_utility(client: Client): # Disable cache to force replay async with new_worker( - client, WaitWorkflow, activities=[return_name_activity], max_cached_workflows=0 + client, WaitWorkflow, activities=[return_name_activity], max_cached_workflows=0 ) as worker: # This would fail if we used asyncio.wait in the workflow result = await client.execute_workflow( @@ -5456,9 +5489,9 @@ async def _workflow_task_failed(self, workflow_id: str) -> bool: return False async def _get_workflow_result_and_warning( - self, - wait_all_handlers_finished: bool, - unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, + self, + wait_all_handlers_finished: bool, + unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, ) -> Tuple[bool, bool]: with pytest.WarningsRecorder() as warnings: wf_result = await self._get_workflow_result( @@ -5471,10 +5504,10 @@ async def _get_workflow_result_and_warning( return wf_result, unfinished_handler_warning_emitted async def _get_workflow_result( - self, - wait_all_handlers_finished: bool, - unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, - handle_future: Optional[asyncio.Future[WorkflowHandle]] = None, + self, + wait_all_handlers_finished: bool, + unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, + handle_future: Optional[asyncio.Future[WorkflowHandle]] = None, ) -> bool: handle = await self.client.start_workflow( UnfinishedHandlersWarningsWorkflow.run, @@ -5516,30 +5549,30 @@ def __init__(self) -> None: @workflow.run async def run( - self, - workflow_termination_type: Literal[ - "-cancellation-", - "-failure-", - "-continue-as-new-", - "-fail-post-continue-as-new-run-", - ], - handler_registration: Literal["-late-registered-", "-not-late-registered-"], - handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], - handler_waiting: Literal[ - "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" - ], + self, + workflow_termination_type: Literal[ + "-cancellation-", + "-failure-", + "-continue-as-new-", + "-fail-post-continue-as-new-run-", + ], + handler_registration: Literal["-late-registered-", "-not-late-registered-"], + handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], + handler_waiting: Literal[ + "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" + ], ) -> NoReturn: if handler_registration == "-late-registered-": if handler_dynamism == "-dynamic-": async def my_late_registered_dynamic_update( - name: str, args: Sequence[RawValue] + name: str, args: Sequence[RawValue] ) -> str: await workflow.wait_condition(lambda: self.handlers_may_finish) return "my-late-registered-dynamic-update-result" async def my_late_registered_dynamic_signal( - name: str, args: Sequence[RawValue] + name: str, args: Sequence[RawValue] ) -> None: await workflow.wait_condition(lambda: self.handlers_may_finish) @@ -5616,17 +5649,17 @@ async def my_dynamic_signal(self, name: str, args: Sequence[RawValue]) -> None: "workflow_termination_type", ["-cancellation-", "-failure-", "-continue-as-new-"] ) async def test_unfinished_handler_on_workflow_termination( - client: Client, - env: WorkflowEnvironment, - handler_type: Literal["-signal-", "-update-"], - handler_registration: Literal["-late-registered-", 
"-not-late-registered-"], - handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], - handler_waiting: Literal[ - "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" - ], - workflow_termination_type: Literal[ - "-cancellation-", "-failure-", "-continue-as-new-" - ], + client: Client, + env: WorkflowEnvironment, + handler_type: Literal["-signal-", "-update-"], + handler_registration: Literal["-late-registered-", "-not-late-registered-"], + handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], + handler_waiting: Literal[ + "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" + ], + workflow_termination_type: Literal[ + "-cancellation-", "-failure-", "-continue-as-new-" + ], ): skip_unfinished_handler_tests_in_older_python() if handler_type == "-update-" and env.supports_time_skipping: @@ -5657,10 +5690,10 @@ class _UnfinishedHandlersOnWorkflowTerminationTest: ] async def test_warning_is_issued_on_exit_with_unfinished_handler( - self, + self, ): assert await self._run_workflow_and_get_warning() == ( - self.handler_waiting == "-no-wait-all-handlers-finish-" + self.handler_waiting == "-no-wait-all-handlers-finish-" ) async def _run_workflow_and_get_warning(self) -> bool: @@ -5714,9 +5747,9 @@ async def _run_workflow_and_get_warning(self) -> bool: await handle.signal(signal_method) # type: ignore async with new_worker( - self.client, - UnfinishedHandlersOnWorkflowTerminationWorkflow, - task_queue=task_queue, + self.client, + UnfinishedHandlersOnWorkflowTerminationWorkflow, + task_queue=task_queue, ): with pytest.WarningsRecorder() as warnings: if self.handler_type == "-update-": @@ -5729,7 +5762,7 @@ async def _run_workflow_and_get_warning(self) -> bool: update_err = err_info.value assert isinstance(update_err.cause, ApplicationError) assert ( - update_err.cause.type == "AcceptedUpdateCompletedWorkflow" + update_err.cause.type == "AcceptedUpdateCompletedWorkflow" ) with pytest.raises(WorkflowFailureError) as err: @@ -5744,8 +5777,8 @@ async def _run_workflow_and_get_warning(self) -> bool: ) if self.workflow_termination_type == "-continue-as-new-": assert ( - str(err.value.cause) - == "Deliberately failing post-ContinueAsNew run" + str(err.value.cause) + == "Deliberately failing post-ContinueAsNew run" ) unfinished_handler_warning_emitted = any( @@ -5845,8 +5878,8 @@ async def my_update(self) -> str: async def test_update_completion_is_honored_when_after_workflow_return_1( - client: Client, - env: WorkflowEnvironment, + client: Client, + env: WorkflowEnvironment, ): if env.supports_time_skipping: pytest.skip( @@ -5868,9 +5901,9 @@ async def test_update_completion_is_honored_when_after_workflow_return_1( await workflow_update_exists(client, wf_handle.id, update_id) async with Worker( - client, - task_queue=task_queue, - workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturn1Workflow], + client, + task_queue=task_queue, + workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturn1Workflow], ): assert await wf_handle.result() == "workflow-result" assert await update_result_task == "update-result" @@ -5901,17 +5934,17 @@ async def my_update(self) -> str: async def test_update_completion_is_honored_when_after_workflow_return_2( - client: Client, - env: WorkflowEnvironment, + client: Client, + env: WorkflowEnvironment, ): if env.supports_time_skipping: pytest.skip( "Java test server: https://github.com/temporalio/sdk-java/issues/1903" ) async with Worker( - client, - task_queue="tq", - workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturnWorkflow2], + client, + 
task_queue="tq", + workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturnWorkflow2], ) as worker: handle = await client.start_workflow( UpdateCompletionIsHonoredWhenAfterWorkflowReturnWorkflow2.run, @@ -5988,7 +6021,7 @@ async def test_first_of_two_signal_completion_commands_is_honored(client: Client async def test_workflow_return_is_honored_when_it_precedes_signal_completion_command( - client: Client, + client: Client, ): await _do_first_completion_command_is_honored_test( client, main_workflow_returns_before_signal_completions=True @@ -5996,7 +6029,7 @@ async def test_workflow_return_is_honored_when_it_precedes_signal_completion_com async def _do_first_completion_command_is_honored_test( - client: Client, main_workflow_returns_before_signal_completions: bool + client: Client, main_workflow_returns_before_signal_completions: bool ): workflow_cls: Union[ Type[FirstCompletionCommandIsHonoredPingPongWorkflow], @@ -6007,9 +6040,9 @@ async def _do_first_completion_command_is_honored_test( else FirstCompletionCommandIsHonoredWorkflow ) async with Worker( - client, - task_queue="tq", - workflows=[workflow_cls], + client, + task_queue="tq", + workflows=[workflow_cls], ) as worker: handle = await client.start_workflow( workflow_cls.run, @@ -6029,8 +6062,8 @@ async def _do_first_completion_command_is_honored_test( assert str(err.cause).startswith("Client should see this error") else: assert ( - main_workflow_returns_before_signal_completions - and result == "workflow-result" + main_workflow_returns_before_signal_completions + and result == "workflow-result" ) @@ -6055,7 +6088,7 @@ async def my_signal(self): async def test_timer_started_after_workflow_completion(client: Client): async with new_worker( - client, TimerStartedAfterWorkflowCompletionWorkflow + client, TimerStartedAfterWorkflowCompletionWorkflow ) as worker: handle = await client.start_workflow( TimerStartedAfterWorkflowCompletionWorkflow.run, @@ -6090,7 +6123,7 @@ async def run(self) -> None: async def test_activity_retry_delay(client: Client): async with new_worker( - client, ActivitiesWithRetryDelayWorkflow, activities=[activity_with_retry_delay] + client, ActivitiesWithRetryDelayWorkflow, activities=[activity_with_retry_delay] ) as worker: try: await client.execute_workflow( @@ -6102,11 +6135,11 @@ async def test_activity_retry_delay(client: Client): assert isinstance(err.cause, ActivityError) assert isinstance(err.cause.cause, ApplicationError) assert ( - str(err.cause.cause) == ActivitiesWithRetryDelayWorkflow.error_message + str(err.cause.cause) == ActivitiesWithRetryDelayWorkflow.error_message ) assert ( - err.cause.cause.next_retry_delay - == ActivitiesWithRetryDelayWorkflow.next_retry_delay + err.cause.cause.next_retry_delay + == ActivitiesWithRetryDelayWorkflow.next_retry_delay ) @@ -6169,7 +6202,7 @@ async def run(self, _: str) -> str: ], ) async def test_update_in_first_wft_sees_workflow_init( - client: Client, client_cls: Type, worker_cls: Type + client: Client, client_cls: Type, worker_cls: Type ): """ Test how @workflow.init affects what an update in the first WFT sees. 
@@ -6270,7 +6303,7 @@ async def test_user_metadata_is_set(client: Client, env: WorkflowEnvironment): "Java test server: https://github.com/temporalio/sdk-java/issues/2219" ) async with new_worker( - client, UserMetadataWorkflow, activities=[say_hello] + client, UserMetadataWorkflow, activities=[say_hello] ) as worker: handle = await client.start_workflow( UserMetadataWorkflow.run, @@ -6400,7 +6433,7 @@ async def make_timers(self, start: int, end: int): async def test_concurrent_sleeps_use_proper_options( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -6445,10 +6478,10 @@ class BadFailureConverterError(Exception): class BadFailureConverter(DefaultFailureConverter): def to_failure( - self, - exception: BaseException, - payload_converter: PayloadConverter, - failure: Failure, + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: Failure, ) -> None: if isinstance(exception, BadFailureConverterError): raise RuntimeError("Intentional failure conversion error") @@ -6482,7 +6515,7 @@ async def test_bad_failure_converter(client: Client): ) client = Client(**config) async with new_worker( - client, BadFailureConverterWorkflow, activities=[bad_failure_converter_activity] + client, BadFailureConverterWorkflow, activities=[bad_failure_converter_activity] ) as worker: # Check activity with pytest.raises(WorkflowFailureError) as err: @@ -6495,8 +6528,8 @@ async def test_bad_failure_converter(client: Client): assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, ApplicationError) assert ( - err.value.cause.cause.message - == "Failed building exception result: Intentional failure conversion error" + err.value.cause.cause.message + == "Failed building exception result: Intentional failure conversion error" ) # Check workflow @@ -6595,10 +6628,10 @@ async def test_async_loop_ordering(client: Client, env: WorkflowEnvironment): await handle.signal(SignalsActivitiesTimersUpdatesTracingWorkflow.dosig, "before") async with new_worker( - client, - SignalsActivitiesTimersUpdatesTracingWorkflow, - activities=[say_hello], - task_queue=task_queue, + client, + SignalsActivitiesTimersUpdatesTracingWorkflow, + activities=[say_hello], + task_queue=task_queue, ): await asyncio.sleep(0.2) await handle.signal(SignalsActivitiesTimersUpdatesTracingWorkflow.dosig, "1") @@ -6655,18 +6688,18 @@ async def test_alternate_async_loop_ordering(client: Client, env: WorkflowEnviro ) async with new_worker( - client, - ActivityAndSignalsWhileWorkflowDown, - activities=[say_hello], - task_queue=task_queue, + client, + ActivityAndSignalsWhileWorkflowDown, + activities=[say_hello], + task_queue=task_queue, ): # This sleep exists to make sure the first WFT is processed await asyncio.sleep(0.2) async with new_worker( - client, - activities=[say_hello], - task_queue=activity_tq, + client, + activities=[say_hello], + task_queue=activity_tq, ): # Make sure the activity starts being processed before sending signals await asyncio.sleep(1) @@ -6674,10 +6707,10 @@ async def test_alternate_async_loop_ordering(client: Client, env: WorkflowEnviro await handle.signal(ActivityAndSignalsWhileWorkflowDown.dosig, "2") async with new_worker( - client, - ActivityAndSignalsWhileWorkflowDown, - activities=[say_hello], - task_queue=task_queue, + client, + ActivityAndSignalsWhileWorkflowDown, + activities=[say_hello], + task_queue=task_queue, ): await handle.result() @@ -6726,8 +6759,8 @@ def init(self, 
params: UseLockOrSemaphoreWorkflowParameters): @workflow.run async def run( - self, - params: Optional[UseLockOrSemaphoreWorkflowParameters], + self, + params: Optional[UseLockOrSemaphoreWorkflowParameters], ) -> LockOrSemaphoreWorkflowConcurrencySummary: # TODO: Use workflow init method when it exists. assert params @@ -6784,8 +6817,8 @@ def __init__(self) -> None: @workflow.run async def run( - self, - _: Optional[UseLockOrSemaphoreWorkflowParameters] = None, + self, + _: Optional[UseLockOrSemaphoreWorkflowParameters] = None, ) -> LockOrSemaphoreWorkflowConcurrencySummary: await workflow.wait_condition(lambda: self.workflow_may_exit) return LockOrSemaphoreWorkflowConcurrencySummary( @@ -6807,14 +6840,14 @@ async def finish(self): async def _do_workflow_coroutines_lock_or_semaphore_test( - client: Client, - params: UseLockOrSemaphoreWorkflowParameters, - expectation: LockOrSemaphoreWorkflowConcurrencySummary, + client: Client, + params: UseLockOrSemaphoreWorkflowParameters, + expectation: LockOrSemaphoreWorkflowConcurrencySummary, ): async with new_worker( - client, - CoroutinesUseLockOrSemaphoreWorkflow, - activities=[noop_activity_for_lock_or_semaphore_tests], + client, + CoroutinesUseLockOrSemaphoreWorkflow, + activities=[noop_activity_for_lock_or_semaphore_tests], ) as worker: summary = await client.execute_workflow( CoroutinesUseLockOrSemaphoreWorkflow.run, @@ -6826,11 +6859,11 @@ async def _do_workflow_coroutines_lock_or_semaphore_test( async def _do_update_handler_lock_or_semaphore_test( - client: Client, - env: WorkflowEnvironment, - params: UseLockOrSemaphoreWorkflowParameters, - n_updates: int, - expectation: LockOrSemaphoreWorkflowConcurrencySummary, + client: Client, + env: WorkflowEnvironment, + params: UseLockOrSemaphoreWorkflowParameters, + n_updates: int, + expectation: LockOrSemaphoreWorkflowConcurrencySummary, ): if env.supports_time_skipping: pytest.skip( @@ -6855,10 +6888,10 @@ async def _do_update_handler_lock_or_semaphore_test( for i in range(n_updates) ] async with new_worker( - client, - HandlerCoroutinesUseLockOrSemaphoreWorkflow, - activities=[noop_activity_for_lock_or_semaphore_tests], - task_queue=task_queue, + client, + HandlerCoroutinesUseLockOrSemaphoreWorkflow, + activities=[noop_activity_for_lock_or_semaphore_tests], + task_queue=task_queue, ): for update_task in admitted_updates: await update_task @@ -6879,7 +6912,7 @@ async def test_workflow_coroutines_can_use_lock(client: Client): async def test_update_handler_can_use_lock_to_serialize_handler_executions( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -6905,7 +6938,7 @@ async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Cli async def test_update_handler_lock_acquisition_respects_timeout( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -6931,7 +6964,7 @@ async def test_workflow_coroutines_can_use_semaphore(client: Client): async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -6946,7 +6979,7 @@ async def test_update_handler_can_use_semaphore_to_control_handler_execution_con async def test_workflow_coroutine_semaphore_acquisition_respects_timeout( - client: Client, + client: Client, ): await 
_do_workflow_coroutines_lock_or_semaphore_test( client, @@ -6962,7 +6995,7 @@ async def test_workflow_coroutine_semaphore_acquisition_respects_timeout( async def test_update_handler_semaphore_acquisition_respects_timeout( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -7119,11 +7152,11 @@ async def test_workflow_deadlock_fill_up_slots(client: Client): # This worker used to not be able to shutdown because we hung evictions on # deadlock. async with new_worker( - client, - DeadlockFillUpBlockWorkflow, - DeadlockFillUpSimpleWorkflow, - # Start the worker with CPU count + 10 task slots - max_concurrent_workflow_tasks=cpu_count + 10, + client, + DeadlockFillUpBlockWorkflow, + DeadlockFillUpSimpleWorkflow, + # Start the worker with CPU count + 10 task slots + max_concurrent_workflow_tasks=cpu_count + 10, ) as worker: # For this test we're going to start cpu_count + 5 workflows that # deadlock. In previous SDK versions we defaulted to CPU count @@ -7223,7 +7256,7 @@ async def check_logs(): while True: log_record = log_queue.get(block=False) if log_record.message.startswith( - f"Timed out running eviction job for run ID {handle.result_run_id}" + f"Timed out running eviction job for run ID {handle.result_run_id}" ): return except queue.Empty: @@ -7251,7 +7284,7 @@ async def check_priority_activity(should_have_priorty: int) -> str: class WorkflowUsingPriorities: @workflow.run async def run( - self, expected_priority: Optional[int], stop_after_check: bool + self, expected_priority: Optional[int], stop_after_check: bool ) -> str: assert workflow.info().priority.priority_key == expected_priority if stop_after_check: @@ -7283,7 +7316,7 @@ async def test_workflow_priorities(client: Client, env: WorkflowEnvironment): ) async with new_worker( - client, WorkflowUsingPriorities, HelloWorkflow, activities=[say_hello] + client, WorkflowUsingPriorities, HelloWorkflow, activities=[say_hello] ) as worker: handle = await client.start_workflow( WorkflowUsingPriorities.run, @@ -7298,27 +7331,27 @@ async def test_workflow_priorities(client: Client, env: WorkflowEnvironment): async for e in handle.fetch_history_events(): if e.HasField("workflow_execution_started_event_attributes"): assert ( - e.workflow_execution_started_event_attributes.priority.priority_key - == 1 + e.workflow_execution_started_event_attributes.priority.priority_key + == 1 ) elif e.HasField( - "start_child_workflow_execution_initiated_event_attributes" + "start_child_workflow_execution_initiated_event_attributes" ): if first_child: assert ( - e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key - == 4 + e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key + == 4 ) first_child = False else: assert ( - e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key - == 2 + e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key + == 2 ) elif e.HasField("activity_task_scheduled_event_attributes"): assert ( - e.activity_task_scheduled_event_attributes.priority.priority_key - == 5 + e.activity_task_scheduled_event_attributes.priority.priority_key + == 5 ) # Verify a workflow started without priorities sees None for the key @@ -7361,7 +7394,7 @@ async def test_expose_root_execution(client: Client, env: WorkflowEnvironment): "Java test server needs release with: https://github.com/temporalio/sdk-java/pull/2441" ) async with new_worker( - client, 
ExposeRootWorkflow, ExposeRootChildWorkflow + client, ExposeRootWorkflow, ExposeRootChildWorkflow ) as worker: parent_wf_id = f"workflow-{uuid.uuid4()}" child_wf_id = parent_wf_id + "_child" From e9a38d9dfdd70b41118c5cccf971392c9737396b Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 9 Apr 2025 11:07:37 -0700 Subject: [PATCH 02/12] Added support for serialization/deserialization of generic dataclasses. Note this still doesn't support fields of generic types. --- temporalio/converter.py | 19 ++++++++++++------- tests/worker/test_workflow.py | 16 +++++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 37e7641d..18c7efee 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1496,35 +1496,40 @@ def value_to_type( return ret_dict # Dataclass - if dataclasses.is_dataclass(hint): + h = hint if dataclasses.is_dataclass(hint) else None + # This allows for generic dataclasses to be passed in. + # Note that the field of a generic parameter type is still not deserializable. + # Such fields must be marked with dataclasses.field(metadata={"skip": True}, default=...). + h = origin if h is None and dataclasses.is_dataclass(origin) else h + if h is not None: if not isinstance(value, dict): raise TypeError( - f"Cannot convert to dataclass {hint}, value is {type(value)} not dict" + f"Cannot convert to dataclass {h}, value is {type(value)} not dict" ) # Obtain dataclass fields and check that all dict fields are there and # that no required fields are missing. Unknown fields are silently # ignored. - fields = dataclasses.fields(hint) - field_hints = get_type_hints(hint) + fields = dataclasses.fields(h) + field_hints = get_type_hints(h) field_values = {} for field in fields: field_value = value.get(field.name, dataclasses.MISSING) # We do not check whether field is required here. Rather, we let the # attempted instantiation of the dataclass raise if a field is # missing - if field_value is not dataclasses.MISSING: + if field_value is not dataclasses.MISSING and not field.metadata.get("skip", False) : try: field_values[field.name] = value_to_type( field_hints[field.name], field_value, custom_converters ) except Exception as err: raise TypeError( - f"Failed converting field {field.name} on dataclass {hint}" + f"Failed converting field {field.name} on dataclass {h}" ) from err # Simply instantiate the dataclass. This will fail as expected when # missing required fields. # TODO(cretz): Want way to convert snake case to camel case? 
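
The convention introduced by the hunk above is that a dataclass field typed as a bare type variable cannot be deserialized, so it must opt out with dataclasses.field(metadata={"skip": True}, default=...) and fall back to its default, while the generic alias itself (e.g. MyGenericDataClass[str]) is resolved to its origin dataclass by the `h = origin if h is None and dataclasses.is_dataclass(origin) else h` line. Below is a minimal, illustrative sketch of that convention with a round-trip through the default data converter, assuming this patch is applied and that DataConverter.default exposes the async encode/decode methods from temporalio.converter; Wrapper, label, payload, and roundtrip are made-up names for illustration, not part of this change.

import asyncio
import dataclasses
import typing

from temporalio.converter import DataConverter

T = typing.TypeVar("T")


@dataclasses.dataclass
class Wrapper(typing.Generic[T]):
    # Concrete field: deserialized normally by value_to_type.
    label: str
    # Field of generic parameter type: not deserializable, so it is marked
    # with skip metadata and must carry a default to fall back to.
    payload: T = dataclasses.field(metadata={"skip": True}, default=None)


async def roundtrip() -> None:
    payloads = await DataConverter.default.encode([Wrapper[str](label="hello")])
    # Decoding against the generic alias Wrapper[str] resolves the alias to the
    # Wrapper dataclass via its origin and rebuilds it from the concrete fields.
    (decoded,) = await DataConverter.default.decode(payloads, [Wrapper[str]])
    assert decoded.label == "hello"  # concrete field restored
    assert decoded.payload is None  # skipped field falls back to its default


asyncio.run(roundtrip())
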
- return hint(**field_values) + return h(**field_values) # Pydantic model instance # Pydantic users should use Pydantic v2 with diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 67b34a31..946dec44 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -2248,7 +2248,8 @@ def assert_expected(self) -> None: @dataclass class MyGenericDataClass(typing.Generic[T]): - field1: T + field1: str + field2: T = dataclasses.field(metadata={"skip": True}, default=None) def assert_expected(self) -> None: # Part of the assertion is that this is the right type, which is @@ -2325,18 +2326,19 @@ async def run(self, param: MyDataClass) -> MyDataClass: start_to_close_timeout=timedelta(seconds=30), ) param.assert_expected() - param = await workflow.execute_activity( + generic_param = MyGenericDataClass[str]("some value2") + generic_param = await workflow.execute_activity( generic_data_class_typed_activity, - param, + generic_param, start_to_close_timeout=timedelta(seconds=30), ) - param.assert_expected() - param = await workflow.execute_local_activity( + generic_param.assert_expected() + generic_param = await workflow.execute_local_activity( generic_data_class_typed_activity, - param, + generic_param, start_to_close_timeout=timedelta(seconds=30), ) - param.assert_expected() + generic_param.assert_expected() child_handle = await workflow.start_child_workflow( DataClassTypedWorkflow.run, From e6315b95f00947a22213fa0146b2b74f65734351 Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 9 Apr 2025 11:16:18 -0700 Subject: [PATCH 03/12] reformatted --- temporalio/converter.py | 4 +- tests/worker/test_workflow.py | 643 +++++++++++++++++----------------- 2 files changed, 326 insertions(+), 321 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 18c7efee..384dda5f 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1517,7 +1517,9 @@ def value_to_type( # We do not check whether field is required here. 
Rather, we let the # attempted instantiation of the dataclass raise if a field is # missing - if field_value is not dataclasses.MISSING and not field.metadata.get("skip", False) : + if field_value is not dataclasses.MISSING and not field.metadata.get( + "skip", False + ): try: field_values[field.name] = value_to_type( field_hints[field.name], field_value, custom_converters diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 946dec44..08cba4a4 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -181,7 +181,7 @@ async def test_workflow_multi_param(client: Client): # This test is mostly just here to confirm MyPy type checks the multi-param # overload approach properly async with new_worker( - client, MultiParamWorkflow, activities=[multi_param_activity] + client, MultiParamWorkflow, activities=[multi_param_activity] ) as worker: result = await client.execute_workflow( MultiParamWorkflow.run, @@ -267,7 +267,7 @@ def get_history_info(self) -> HistoryInfo: async def test_workflow_history_info( - client: Client, env: WorkflowEnvironment, continue_as_new_suggest_history_count: int + client: Client, env: WorkflowEnvironment, continue_as_new_suggest_history_count: int ): if env.supports_time_skipping: pytest.skip("Java test server does not support should continue as new") @@ -589,7 +589,7 @@ def new_handler(name: str, *args: Any) -> str: async def test_workflow_signal_qnd_query_handlers_old_dynamic_style(client: Client): async with new_worker( - client, SignalAndQueryHandlersOldDynamicStyleWorkflow + client, SignalAndQueryHandlersOldDynamicStyleWorkflow ) as worker: handle = await client.start_workflow( SignalAndQueryHandlersOldDynamicStyleWorkflow.run, @@ -654,9 +654,9 @@ async def test_workflow_bad_signal_param(client: Client): await handle.signal("some_signal", 123) await handle.signal("some_signal", BadSignalParam(some_str="finish")) assert [ - BadSignalParam(some_str="good"), - BadSignalParam(some_str="finish"), - ] == await handle.result() + BadSignalParam(some_str="good"), + BadSignalParam(some_str="finish"), + ] == await handle.result() @workflow.defn @@ -779,7 +779,7 @@ async def run(self, name: str) -> str: async def test_workflow_simple_activity(client: Client): async with new_worker( - client, SimpleActivityWorkflow, activities=[say_hello] + client, SimpleActivityWorkflow, activities=[say_hello] ) as worker: result = await client.execute_workflow( SimpleActivityWorkflow.run, @@ -801,7 +801,7 @@ async def run(self, name: str) -> str: async def test_workflow_simple_local_activity(client: Client): async with new_worker( - client, SimpleLocalActivityWorkflow, activities=[say_hello] + client, SimpleLocalActivityWorkflow, activities=[say_hello] ) as worker: result = await client.execute_workflow( SimpleLocalActivityWorkflow.run, @@ -903,7 +903,7 @@ async def test_workflow_cancel_activity(client: Client, local: bool): activity_inst = ActivityWaitCancelNotify() async with new_worker( - client, CancelActivityWorkflow, activities=[activity_inst.wait_cancel] + client, CancelActivityWorkflow, activities=[activity_inst.wait_cancel] ) as worker: # Try cancel - confirm error and activity was sent the cancel handle = await client.start_workflow( @@ -1070,7 +1070,7 @@ def started(self) -> bool: @pytest.mark.parametrize("activity", [True, False]) async def test_workflow_uncaught_cancel(client: Client, activity: bool): async with new_worker( - client, UncaughtCancelWorkflow, activities=[wait_forever] + client, UncaughtCancelWorkflow, 
activities=[wait_forever] ) as worker: # Start workflow waiting on activity or child workflow, cancel it, and # confirm the workflow is shown as cancelled @@ -1231,7 +1231,7 @@ async def run(self, args: SignalExternalWorkflowArgs) -> None: async def test_workflow_signal_external(client: Client): async with new_worker( - client, SignalExternalWorkflow, ReturnSignalWorkflow + client, SignalExternalWorkflow, ReturnSignalWorkflow ) as worker: # Start return signal, then signal and check that it got signalled return_signal_handle = await client.start_workflow( @@ -1305,7 +1305,7 @@ async def child(id: str): async def test_workflow_cancel_multi(client: Client): async with new_worker( - client, MultiCancelWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, MultiCancelWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: results = await client.execute_workflow( MultiCancelWorkflow.run, @@ -1400,7 +1400,7 @@ async def wait_and_swallow(self, aw: Awaitable) -> None: async def test_workflow_cancel_unsent(client: Client): workflow_id = f"workflow-{uuid.uuid4()}" async with new_worker( - client, CancelUnsentWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, CancelUnsentWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: await client.execute_workflow( CancelUnsentWorkflow.run, @@ -1419,14 +1419,14 @@ async def test_workflow_cancel_unsent(client: Client): # No activities or children scheduled assert event.event_type is not EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED assert ( - event.event_type - is not EventType.EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED + event.event_type + is not EventType.EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED ) # Make sure timer is just our 0.01 timer if event.event_type is EventType.EVENT_TYPE_TIMER_STARTED: assert ( - event.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() - == 10 + event.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() + == 10 ) found_timer = True assert found_timer @@ -1445,7 +1445,7 @@ async def run(self) -> None: async def test_workflow_activity_timeout(client: Client): async with new_worker( - client, ActivityTimeoutWorkflow, activities=[wait_cancel] + client, ActivityTimeoutWorkflow, activities=[wait_cancel] ) as worker: with pytest.raises(WorkflowFailureError) as err: await client.execute_workflow( @@ -1522,7 +1522,7 @@ def create_instance(self, det: WorkflowInstanceDetails) -> WorkflowInstance: class CustomWorkflowInstance(WorkflowInstance): def __init__( - self, runner: CustomWorkflowRunner, unsandboxed: WorkflowInstance + self, runner: CustomWorkflowRunner, unsandboxed: WorkflowInstance ) -> None: super().__init__() self._runner = runner @@ -1546,13 +1546,13 @@ async def test_workflow_with_custom_runner(client: Client): assert result == "Hello, Temporal!" 
# Confirm first activation and last non-eviction-reply completion assert ( - runner._pairs[0][0].jobs[0].initialize_workflow.workflow_type == "HelloWorkflow" + runner._pairs[0][0].jobs[0].initialize_workflow.workflow_type == "HelloWorkflow" ) assert ( - runner._pairs[-2][-1] - .successful.commands[0] - .complete_workflow_execution.result.data - == b'"Hello, Temporal!"' + runner._pairs[-2][-1] + .successful.commands[0] + .complete_workflow_execution.result.data + == b'"Hello, Temporal!"' ) @@ -1604,7 +1604,7 @@ async def test_workflow_continue_as_new(client: Client, env: WorkflowEnvironment def search_attributes_to_serializable( - attrs: Union[SearchAttributes, TypedSearchAttributes], + attrs: Union[SearchAttributes, TypedSearchAttributes], ) -> Mapping[str, Any]: if isinstance(attrs, TypedSearchAttributes): return { @@ -1785,7 +1785,7 @@ async def describe_attributes_untyped(handle: WorkflowHandle) -> SearchAttribute } async def describe_attributes_typed( - handle: WorkflowHandle, + handle: WorkflowHandle, ) -> TypedSearchAttributes: # Remove any not our prefix attrs = (await handle.describe()).typed_search_attributes @@ -1940,13 +1940,13 @@ def find_log(self, starts_with: str) -> Optional[logging.LogRecord]: async def test_workflow_logging(client: Client, env: WorkflowEnvironment): workflow.logger.full_workflow_info_on_extra = True with LogCapturer().logs_captured( - workflow.logger.base_logger, activity.logger.base_logger + workflow.logger.base_logger, activity.logger.base_logger ) as capturer: # Log two signals and kill worker before completing. Need to disable # workflow cache since we restart the worker and don't want to pay the # sticky queue penalty. async with new_worker( - client, LoggingWorkflow, max_cached_workflows=0 + client, LoggingWorkflow, max_cached_workflows=0 ) as worker: handle = await client.start_workflow( LoggingWorkflow.run, @@ -1973,30 +1973,30 @@ async def test_workflow_logging(client: Client, env: WorkflowEnvironment): # Also make sure it has some workflow info and correct funcName record = capturer.find_log("Signal: signal 1") assert ( - record - and record.__dict__["temporal_workflow"]["workflow_type"] - == "LoggingWorkflow" - and record.funcName == "my_signal" + record + and record.__dict__["temporal_workflow"]["workflow_type"] + == "LoggingWorkflow" + and record.funcName == "my_signal" ) # Since we enabled full info, make sure it's there assert isinstance(record.__dict__["workflow_info"], workflow.Info) # Check the log emitted by the update execution. 
record = capturer.find_log("Update: update 1") assert ( - record - and record.__dict__["temporal_workflow"]["update_id"] == "update-1" - and record.__dict__["temporal_workflow"]["update_name"] == "my_update" - and "'update_id': 'update-1'" in record.message - and "'update_name': 'my_update'" in record.message + record + and record.__dict__["temporal_workflow"]["update_id"] == "update-1" + and record.__dict__["temporal_workflow"]["update_name"] == "my_update" + and "'update_id': 'update-1'" in record.message + and "'update_name': 'my_update'" in record.message ) # Clear queue and start a new one with more signals capturer.log_queue.queue.clear() async with new_worker( - client, - LoggingWorkflow, - task_queue=worker.task_queue, - max_cached_workflows=0, + client, + LoggingWorkflow, + task_queue=worker.task_queue, + max_cached_workflows=0, ) as worker: # Send signals and updates await handle.signal(LoggingWorkflow.my_signal, "signal 3") @@ -2043,10 +2043,10 @@ async def run(self) -> None: async def test_workflow_logging_task_fail(client: Client): with LogCapturer().logs_captured( - activity.logger.base_logger, temporalio.worker._workflow_instance.logger + activity.logger.base_logger, temporalio.worker._workflow_instance.logger ) as capturer: async with new_worker( - client, TaskFailOnceWorkflow, activities=[task_fail_once_activity] + client, TaskFailOnceWorkflow, activities=[task_fail_once_activity] ) as worker: await client.execute_workflow( TaskFailOnceWorkflow.run, @@ -2058,16 +2058,16 @@ async def test_workflow_logging_task_fail(client: Client): assert wf_task_record assert "Intentional workflow task failure" in wf_task_record.message assert ( - getattr(wf_task_record, "temporal_workflow")["workflow_type"] - == "TaskFailOnceWorkflow" + getattr(wf_task_record, "temporal_workflow")["workflow_type"] + == "TaskFailOnceWorkflow" ) act_task_record = capturer.find_log("Completing activity as failed") assert act_task_record assert "Intentional activity task failure" in act_task_record.message assert ( - getattr(act_task_record, "temporal_activity")["activity_type"] - == "task_fail_once_activity" + getattr(act_task_record, "temporal_activity")["activity_type"] + == "task_fail_once_activity" ) @@ -2102,7 +2102,7 @@ def status(self) -> str: async def test_workflow_stack_trace(client: Client): async with new_worker( - client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: handle = await client.start_workflow( StackTraceWorkflow.run, @@ -2156,7 +2156,7 @@ async def test_workflow_enhanced_stack_trace(client: Client): """ async with new_worker( - client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] + client, StackTraceWorkflow, LongSleepWorkflow, activities=[wait_cancel] ) as worker: handle = await client.start_workflow( StackTraceWorkflow.run, @@ -2195,9 +2195,9 @@ async def status() -> str: async def test_workflow_external_enhanced_stack_trace(client: Client): async with new_worker( - client, - ExternalStackTraceWorkflow, - activities=[external_wait_cancel], + client, + ExternalStackTraceWorkflow, + activities=[external_wait_cancel], ) as worker: handle = await client.start_workflow( ExternalStackTraceWorkflow.run, @@ -2227,8 +2227,8 @@ async def status() -> str: assert fn is not None assert ( - 'status[0] = "waiting" # external coroutine test' - in trace.sources[fn].content + 'status[0] = "waiting" # external coroutine test' + in trace.sources[fn].content ) assert 
trace.sdk.version == __version__ @@ -2243,7 +2243,7 @@ def assert_expected(self) -> None: assert self.field1 == "some value" -T = typing.TypeVar('T') +T = typing.TypeVar("T") @dataclass @@ -2264,7 +2264,9 @@ async def data_class_typed_activity(param: MyDataClass) -> MyDataClass: @activity.defn -async def generic_data_class_typed_activity(param: MyGenericDataClass[str]) -> MyGenericDataClass[str]: +async def generic_data_class_typed_activity( + param: MyGenericDataClass[str], +) -> MyGenericDataClass[str]: param.assert_expected() return param @@ -2382,8 +2384,9 @@ async def test_workflow_dataclass_typed(client: Client, env: WorkflowEnvironment "Java test server: https://github.com/temporalio/sdk-core/issues/390" ) async with new_worker( - client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_data_class_typed_activity] + client, + DataClassTypedWorkflow, + activities=[data_class_typed_activity, generic_data_class_typed_activity], ) as worker: val = MyDataClass(field1="some value") handle = await client.start_workflow( @@ -2408,7 +2411,7 @@ async def test_workflow_separate_protocol(client: Client): # This test is to confirm that protocols can be used as "interfaces" for # when the workflow impl is absent async with new_worker( - client, DataClassTypedWorkflow, activities=[data_class_typed_activity] + client, DataClassTypedWorkflow, activities=[data_class_typed_activity] ) as worker: assert isinstance(DataClassTypedWorkflow(), DataClassTypedWorkflowProto) val = MyDataClass(field1="some value") @@ -2430,7 +2433,7 @@ async def test_workflow_separate_abstract(client: Client): # This test is to confirm that abstract classes can be used as "interfaces" # for when the workflow impl is absent async with new_worker( - client, DataClassTypedWorkflow, activities=[data_class_typed_activity] + client, DataClassTypedWorkflow, activities=[data_class_typed_activity] ) as worker: assert issubclass(DataClassTypedWorkflow, DataClassTypedWorkflowAbstract) val = MyDataClass(field1="some value") @@ -2490,7 +2493,7 @@ async def test_workflow_child_already_started(client: Client, env: WorkflowEnvir "Java test server: https://github.com/temporalio/sdk-java/issues/1220" ) async with new_worker( - client, ChildAlreadyStartedWorkflow, LongSleepWorkflow + client, ChildAlreadyStartedWorkflow, LongSleepWorkflow ) as worker: with pytest.raises(WorkflowFailureError) as err: await client.execute_workflow( @@ -2538,10 +2541,10 @@ async def run(self) -> None: async def test_workflow_typed_config(client: Client): async with new_worker( - client, - TypedConfigWorkflow, - FailUntilAttemptWorkflow, - activities=[fail_until_attempt_activity], + client, + TypedConfigWorkflow, + FailUntilAttemptWorkflow, + activities=[fail_until_attempt_activity], ) as worker: await client.execute_workflow( TypedConfigWorkflow.run, @@ -2584,7 +2587,7 @@ async def run(self) -> None: async def test_workflow_local_activity_backoff(client: Client): workflow_id = f"workflow-{uuid.uuid4()}" async with new_worker( - client, LocalActivityBackoffWorkflow, activities=[fail_until_attempt_activity] + client, LocalActivityBackoffWorkflow, activities=[fail_until_attempt_activity] ) as worker: await client.execute_workflow( LocalActivityBackoffWorkflow.run, @@ -2627,7 +2630,7 @@ async def run(self) -> None: async def test_workflow_deadlock(client: Client): # Disable safe eviction so the worker can complete async with new_worker( - client, DeadlockedWorkflow, disable_safe_workflow_eviction=True + client, DeadlockedWorkflow, 
disable_safe_workflow_eviction=True ) as worker: if worker._workflow_worker: worker._workflow_worker._deadlock_timeout_seconds = 1 @@ -2777,7 +2780,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Run a simple pre-patch workflow. Need to disable workflow cache since we # restart the worker and don't want to pay the sticky queue penalty. async with new_worker( - client, PrePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, PrePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): pre_patch_handle = await execute() assert "pre-patch" == await query_result(pre_patch_handle) @@ -2785,7 +2788,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Confirm patched workflow gives old result for pre-patched but new result # for patched async with new_worker( - client, PatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, PatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): patch_handle = await execute() assert "post-patch" == await query_result(patch_handle) @@ -2793,7 +2796,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Confirm what works during deprecated async with new_worker( - client, DeprecatePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, DeprecatePatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): deprecate_patch_handle = await execute() assert "post-patch" == await query_result(deprecate_patch_handle) @@ -2801,7 +2804,7 @@ async def query_result(handle: WorkflowHandle) -> str: # Confirm what works when deprecation gone async with new_worker( - client, PostPatchWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, PostPatchWorkflow, task_queue=task_queue, max_cached_workflows=0 ): post_patch_handle = await execute() assert "post-patch" == await query_result(post_patch_handle) @@ -2854,10 +2857,10 @@ async def test_workflow_patch_memoized(client: Client): # to pay the sticky queue penalty. task_queue = f"tq-{uuid.uuid4()}" async with Worker( - client, - task_queue=task_queue, - workflows=[PatchMemoizedWorkflowUnpatched], - max_cached_workflows=0, + client, + task_queue=task_queue, + workflows=[PatchMemoizedWorkflowUnpatched], + max_cached_workflows=0, ): pre_patch_handle = await client.start_workflow( PatchMemoizedWorkflowUnpatched.run, @@ -2875,10 +2878,10 @@ async def waiting_signal() -> bool: # Now start the worker again, but this time with a patched workflow async with Worker( - client, - task_queue=task_queue, - workflows=[PatchMemoizedWorkflowPatched], - max_cached_workflows=0, + client, + task_queue=task_queue, + workflows=[PatchMemoizedWorkflowPatched], + max_cached_workflows=0, ): # Start a new workflow post patch post_patch_handle = await client.start_workflow( @@ -2894,10 +2897,10 @@ async def waiting_signal() -> bool: # Confirm expected values assert ["some-value"] == await pre_patch_handle.result() assert [ - "pre-patch", - "some-value", - "post-patch", - ] == await post_patch_handle.result() + "pre-patch", + "some-value", + "post-patch", + ] == await post_patch_handle.result() @workflow.defn @@ -2917,7 +2920,7 @@ def result(self) -> str: async def test_workflow_uuid(client: Client): task_queue = str(uuid.uuid4()) async with new_worker( - client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 ): # Get two handle UUID results. Need to disable workflow cache since we # restart the worker and don't want to pay the sticky queue penalty. 
@@ -2943,7 +2946,7 @@ async def test_workflow_uuid(client: Client): # Now confirm those results are the same even on a new worker async with new_worker( - client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 + client, UUIDWorkflow, task_queue=task_queue, max_cached_workflows=0 ): assert handle1_query_result == await handle1.query(UUIDWorkflow.result) assert handle2_query_result == await handle2.query(UUIDWorkflow.result) @@ -2972,7 +2975,7 @@ async def run(self, to_add: MyDataClass) -> MyDataClass: async def test_workflow_activity_callable_class(client: Client): activity_instance = CallableClassActivity("in worker") async with new_worker( - client, ActivityCallableClassWorkflow, activities=[activity_instance] + client, ActivityCallableClassWorkflow, activities=[activity_instance] ) as worker: result = await client.execute_workflow( ActivityCallableClassWorkflow.run, @@ -3022,9 +3025,9 @@ async def run(self, to_add: MyDataClass) -> MyDataClass: async def test_workflow_activity_method(client: Client): activity_instance = MethodActivity("in worker") async with new_worker( - client, - ActivityMethodWorkflow, - activities=[activity_instance.add, activity_instance.add_multi], + client, + ActivityMethodWorkflow, + activities=[activity_instance.add, activity_instance.add_multi], ) as worker: result = await client.execute_workflow( ActivityMethodWorkflow.run, @@ -3065,8 +3068,8 @@ def waiting(self) -> bool: async def test_workflow_wait_condition_timeout(client: Client): async with new_worker( - client, - WaitConditionTimeoutWorkflow, + client, + WaitConditionTimeoutWorkflow, ) as worker: handle = await client.start_workflow( WaitConditionTimeoutWorkflow.run, @@ -3098,8 +3101,8 @@ def some_query(self) -> str: async def test_workflow_query_rpc_timeout(client: Client): # Run workflow under worker and confirm query works async with new_worker( - client, - HelloWorkflowWithQuery, + client, + HelloWorkflowWithQuery, ) as worker: handle = await client.start_workflow( HelloWorkflowWithQuery.run, @@ -3116,9 +3119,9 @@ async def test_workflow_query_rpc_timeout(client: Client): HelloWorkflowWithQuery.some_query, rpc_timeout=timedelta(seconds=1) ) assert ( - err.value.status == RPCStatusCode.CANCELLED - and "timeout" in str(err.value).lower() - ) or err.value.status == RPCStatusCode.DEADLINE_EXCEEDED + err.value.status == RPCStatusCode.CANCELLED + and "timeout" in str(err.value).lower() + ) or err.value.status == RPCStatusCode.DEADLINE_EXCEEDED @dataclass @@ -3257,7 +3260,7 @@ def cancel_timer(self) -> None: async def test_workflow_cancel_signal_and_timer_fired_in_same_task( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): # This test only works when we support time skipping if not env.supports_time_skipping: @@ -3271,7 +3274,7 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task( # Start worker for 30 mins. Need to disable workflow cache since we # restart the worker and don't want to pay the sticky queue penalty. 
async with new_worker( - client, CancelSignalAndTimerFiredInSameTaskWorkflow, max_cached_workflows=0 + client, CancelSignalAndTimerFiredInSameTaskWorkflow, max_cached_workflows=0 ) as worker: task_queue = worker.task_queue handle = await client.start_workflow( @@ -3292,10 +3295,10 @@ async def test_workflow_cancel_signal_and_timer_fired_in_same_task( # Start worker again and wait for workflow completion async with new_worker( - client, - CancelSignalAndTimerFiredInSameTaskWorkflow, - task_queue=task_queue, - max_cached_workflows=0, + client, + CancelSignalAndTimerFiredInSameTaskWorkflow, + task_queue=task_queue, + max_cached_workflows=0, ): # This used to not complete because a signal cancelling the timer was # not respected by the timer fire @@ -3327,7 +3330,7 @@ async def run(self) -> NoReturn: class CustomFailureConverter(DefaultFailureConverterWithEncodedAttributes): # We'll override from failure to convert back to our type def from_failure( - self, failure: Failure, payload_converter: PayloadConverter + self, failure: Failure, payload_converter: PayloadConverter ) -> BaseException: err = super().from_failure(failure, payload_converter) if isinstance(err, ApplicationError) and err.type == "MyCustomError": @@ -3349,7 +3352,7 @@ async def test_workflow_custom_failure_converter(client: Client): # Run workflow and confirm error async with new_worker( - client, CustomErrorWorkflow, activities=[custom_error_activity] + client, CustomErrorWorkflow, activities=[custom_error_activity] ) as worker: handle = await client.start_workflow( CustomErrorWorkflow.run, @@ -3396,11 +3399,11 @@ class OptionalParam: class OptionalParamWorkflow: @workflow.run async def run( - self, some_param: Optional[OptionalParam] = OptionalParam(some_string="default") + self, some_param: Optional[OptionalParam] = OptionalParam(some_string="default") ) -> Optional[OptionalParam]: assert some_param is None or ( - isinstance(some_param, OptionalParam) - and some_param.some_string in ["default", "foo"] + isinstance(some_param, OptionalParam) + and some_param.some_string in ["default", "foo"] ) return some_param @@ -3439,20 +3442,20 @@ class ExceptionRaisingPayloadConverter(DefaultPayloadConverter): def to_payloads(self, values: Sequence[Any]) -> List[Payload]: if any( - value == ExceptionRaisingPayloadConverter.bad_outbound_str - for value in values + value == ExceptionRaisingPayloadConverter.bad_outbound_str + for value in values ): raise ApplicationError("Intentional outbound converter failure") return super().to_payloads(values) def from_payloads( - self, payloads: Sequence[Payload], type_hints: Optional[List] = None + self, payloads: Sequence[Payload], type_hints: Optional[List] = None ) -> List[Any]: # Check if any payloads contain the bad data for payload in payloads: if ( - ExceptionRaisingPayloadConverter.bad_inbound_str.encode() - in payload.data + ExceptionRaisingPayloadConverter.bad_inbound_str.encode() + in payload.data ): raise ApplicationError("Intentional inbound converter failure") return super().from_payloads(payloads, type_hints) @@ -3568,7 +3571,7 @@ def some_query(self) -> ManualResultType: async def test_manual_result_type(client: Client): async with new_worker( - client, ManualResultTypeWorkflow, activities=[manual_result_type_activity] + client, ManualResultTypeWorkflow, activities=[manual_result_type_activity] ) as worker: # Workflow without result type and with res1 = await client.execute_workflow( @@ -3662,11 +3665,11 @@ async def test_cache_eviction_tear_down(client: Client): # chooses, but now we 
expect eviction to properly tear down tasks and # therefore we cancel them async with new_worker( - client, - CacheEvictionTearDownWorkflow, - WaitForeverWorkflow, - activities=[wait_forever_activity], - max_cached_workflows=0, + client, + CacheEvictionTearDownWorkflow, + WaitForeverWorkflow, + activities=[wait_forever_activity], + max_cached_workflows=0, ) as worker: # Put a hook to catch unraisable exceptions old_hook = sys.unraisablehook @@ -3732,7 +3735,7 @@ async def test_workflow_eviction_exception(client: Client): # Run workflow with no cache (forces eviction every step) async with new_worker( - client, EvictionCaptureExceptionWorkflow, max_cached_workflows=0 + client, EvictionCaptureExceptionWorkflow, max_cached_workflows=0 ) as worker: await client.execute_workflow( EvictionCaptureExceptionWorkflow.run, @@ -3744,8 +3747,8 @@ async def test_workflow_eviction_exception(client: Client): assert len(captured_eviction_exceptions) == 1 assert captured_eviction_exceptions[0].is_replaying assert ( - type(captured_eviction_exceptions[0].exception).__name__ - == "_WorkflowBeingEvictedError" + type(captured_eviction_exceptions[0].exception).__name__ + == "_WorkflowBeingEvictedError" ) @@ -3844,11 +3847,11 @@ async def assert_bad_query(bad_thing: str) -> None: # typing.Self only in 3.11+ if sys.version_info >= (3, 11): + @dataclass class AnnotatedWithSelfParam: some_str: str - @workflow.defn class WorkflowAnnotatedWithSelf: @workflow.run @@ -3856,7 +3859,6 @@ async def run(self: typing.Self, some_arg: AnnotatedWithSelfParam) -> str: assert isinstance(some_arg, AnnotatedWithSelfParam) return some_arg.some_str - async def test_workflow_annotated_with_self(client: Client): async with new_worker(client, WorkflowAnnotatedWithSelf) as worker: assert "foo" == await client.execute_workflow( @@ -3897,7 +3899,7 @@ async def test_workflow_custom_metrics(client: Client): # Run worker with default runtime which is noop meter just to confirm it # doesn't fail async with new_worker( - client, CustomMetricsWorkflow, activities=[custom_metrics_activity] + client, CustomMetricsWorkflow, activities=[custom_metrics_activity] ) as worker: await client.execute_workflow( CustomMetricsWorkflow.run, @@ -3926,7 +3928,7 @@ async def test_workflow_custom_metrics(client: Client): ) async with new_worker( - client, CustomMetricsWorkflow, activities=[custom_metrics_activity] + client, CustomMetricsWorkflow, activities=[custom_metrics_activity] ) as worker: # Record a gauge at runtime level gauge = runtime.metric_meter.with_additional_attributes( @@ -3948,7 +3950,7 @@ async def test_workflow_custom_metrics(client: Client): # Intentionally naive metric checker def matches_metric_line( - line: str, name: str, at_least_labels: Mapping[str, str], value: int + line: str, name: str, at_least_labels: Mapping[str, str], value: int ) -> bool: # Must have metric name if not line.startswith(name + "{"): @@ -3960,7 +3962,7 @@ def matches_metric_line( return line.endswith(f" {value}") def assert_metric_exists( - name: str, at_least_labels: Mapping[str, str], value: int + name: str, at_least_labels: Mapping[str, str], value: int ) -> None: assert any( matches_metric_line(line, name, at_least_labels, value) @@ -4081,7 +4083,7 @@ async def test_workflow_buffered_metrics(client: Client): runtime=runtime, ) async with new_worker( - client, CustomMetricsWorkflow, activities=[custom_metrics_activity] + client, CustomMetricsWorkflow, activities=[custom_metrics_activity] ) as worker: await client.execute_workflow( CustomMetricsWorkflow.run, @@ 
-4304,7 +4306,7 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir "Java test server: https://github.com/temporalio/sdk-java/issues/1903" ) async with new_worker( - client, UpdateHandlersWorkflow, activities=[say_hello] + client, UpdateHandlersWorkflow, activities=[say_hello] ) as worker: wf_id = f"update-handlers-workflow-{uuid.uuid4()}" handle = await client.start_workflow( @@ -4345,7 +4347,7 @@ async def test_workflow_update_handlers_happy(client: Client, env: WorkflowEnvir async def test_workflow_update_handlers_unhappy( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4396,8 +4398,8 @@ async def test_workflow_update_handlers_unhappy( await handle.execute_update("last_event", args=[121, "badarg"]) assert isinstance(err.value.cause, ApplicationError) assert ( - "last_event_validator() takes 2 positional arguments but 3 were given" - in err.value.cause.message + "last_event_validator() takes 2 positional arguments but 3 were given" + in err.value.cause.message ) # Un-deserializeable nonsense @@ -4430,7 +4432,7 @@ async def test_workflow_update_task_fails(client: Client, env: WorkflowEnvironme ) # Need to not sandbox so behavior can change based on globals async with new_worker( - client, UpdateHandlersWorkflow, workflow_runner=UnsandboxedWorkflowRunner() + client, UpdateHandlersWorkflow, workflow_runner=UnsandboxedWorkflowRunner() ) as worker: handle = await client.start_workflow( UpdateHandlersWorkflow.run, @@ -4468,7 +4470,7 @@ async def update(self) -> None: async def test_workflow_update_respects_first_execution_run_id( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4479,6 +4481,7 @@ async def test_workflow_update_respects_first_execution_run_id( # r1. 
workflow_id = f"update-respects-first-execution-run-id-{uuid.uuid4()}" async with new_worker(client, UpdateRespectsFirstExecutionRunIdWorkflow) as worker: + async def start_workflow(workflow_id: str) -> WorkflowHandle: return await client.start_workflow( UpdateRespectsFirstExecutionRunIdWorkflow.run, @@ -4523,7 +4526,7 @@ def got_update(self) -> str: async def test_workflow_update_before_worker_start( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4561,10 +4564,10 @@ async def test_workflow_update_before_worker_start( # Start no-cache worker on the task queue async with new_worker( - client, - ImmediatelyCompleteUpdateAndWorkflow, - task_queue=task_queue, - max_cached_workflows=0, + client, + ImmediatelyCompleteUpdateAndWorkflow, + task_queue=task_queue, + max_cached_workflows=0, ): # Confirm workflow completed as expected assert "workflow-done" == await handle.result() @@ -4597,7 +4600,7 @@ async def signal(self) -> None: async def test_workflow_update_separate_handle( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4645,7 +4648,7 @@ async def do_update(self, sleep: float) -> None: async def test_workflow_update_timeout_or_cancel( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -4787,7 +4790,7 @@ async def test_workflow_timeout_support(client: Client, approach: str): if sys.version_info < (3, 11): pytest.skip("Timeout only in >= 3.11") async with new_worker( - client, TimeoutSupportWorkflow, activities=[wait_cancel] + client, TimeoutSupportWorkflow, activities=[wait_cancel] ) as worker: # Run and confirm activity gets cancelled handle = await client.start_workflow( @@ -4805,8 +4808,8 @@ async def test_workflow_timeout_support(client: Client, approach: str): async for e in handle.fetch_history_events(): if e.HasField("timer_started_event_attributes"): assert ( - e.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() - == 200 + e.timer_started_event_attributes.start_to_fire_timeout.ToMilliseconds() + == 200 ) found_timer = True break @@ -4836,18 +4839,18 @@ async def finish(self): async def test_workflow_current_build_id_appropriately_set( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip("Java test server does not support worker versioning") task_queue = str(uuid.uuid4()) async with new_worker( - client, - BuildIDInfoWorkflow, - activities=[say_hello], - build_id="1.0", - task_queue=task_queue, + client, + BuildIDInfoWorkflow, + activities=[say_hello], + build_id="1.0", + task_queue=task_queue, ) as worker: handle = await client.start_workflow( BuildIDInfoWorkflow.run, @@ -4866,11 +4869,11 @@ async def test_workflow_current_build_id_appropriately_set( ) async with new_worker( - client, - BuildIDInfoWorkflow, - activities=[say_hello], - build_id="1.1", - task_queue=task_queue, + client, + BuildIDInfoWorkflow, + activities=[say_hello], + build_id="1.1", + task_queue=task_queue, ) as worker: bid = await handle.query(BuildIDInfoWorkflow.get_build_id) assert bid == "1.0" @@ -4942,7 +4945,7 @@ async def run(self, scenario: FailureTypesScenario) -> None: async def test_workflow_failure_types_configured( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ 
-4951,14 +4954,14 @@ async def test_workflow_failure_types_configured( # Asserter for a single scenario async def assert_scenario( - workflow: Type[FailureTypesWorkflowBase], - *, - expect_task_fail: bool, - fail_message_contains: str, - worker_level_failure_exception_type: Optional[Type[Exception]] = None, - workflow_scenario: Optional[FailureTypesScenario] = None, - signal_scenario: Optional[FailureTypesScenario] = None, - update_scenario: Optional[FailureTypesScenario] = None, + workflow: Type[FailureTypesWorkflowBase], + *, + expect_task_fail: bool, + fail_message_contains: str, + worker_level_failure_exception_type: Optional[Type[Exception]] = None, + workflow_scenario: Optional[FailureTypesScenario] = None, + signal_scenario: Optional[FailureTypesScenario] = None, + update_scenario: Optional[FailureTypesScenario] = None, ) -> None: logging.debug( "Asserting scenario %s", @@ -4973,12 +4976,12 @@ async def assert_scenario( }, ) async with new_worker( - client, - workflow, - max_cached_workflows=0, - workflow_failure_exception_types=[worker_level_failure_exception_type] - if worker_level_failure_exception_type - else [], + client, + workflow, + max_cached_workflows=0, + workflow_failure_exception_types=[worker_level_failure_exception_type] + if worker_level_failure_exception_type + else [], ) as worker: # Start workflow handle = await client.start_workflow( @@ -5004,9 +5007,9 @@ async def assert_scenario( async def has_expected_task_fail() -> bool: async for e in handle.fetch_history_events(): if ( - e.HasField("workflow_task_failed_event_attributes") - and fail_message_contains - in e.workflow_task_failed_event_attributes.failure.message + e.HasField("workflow_task_failed_event_attributes") + and fail_message_contains + in e.workflow_task_failed_event_attributes.failure.message ): return True return False @@ -5019,9 +5022,9 @@ async def has_expected_task_fail() -> bool: # Update does not throw on non-determinism, the workflow # does instead if ( - update_handle - and update_scenario - == FailureTypesScenario.THROW_CUSTOM_EXCEPTION + update_handle + and update_scenario + == FailureTypesScenario.THROW_CUSTOM_EXCEPTION ): await update_handle.result() else: @@ -5031,11 +5034,11 @@ async def has_expected_task_fail() -> bool: # Run a scenario async def run_scenario( - workflow: Type[FailureTypesWorkflowBase], - scenario: FailureTypesScenario, - *, - expect_task_fail: bool = False, - worker_level_failure_exception_type: Optional[Type[Exception]] = None, + workflow: Type[FailureTypesWorkflowBase], + scenario: FailureTypesScenario, + *, + expect_task_fail: bool = False, + worker_level_failure_exception_type: Optional[Type[Exception]] = None, ) -> None: # Run for workflow, signal, and update fail_message_contains = ( @@ -5176,11 +5179,11 @@ async def any_task_completed(handle: WorkflowHandle) -> bool: # Now start the worker on the first env async with Worker( - client, - task_queue=task_queue, - workflows=[TickingWorkflow], - max_cached_workflows=0, - max_concurrent_workflow_task_polls=1, + client, + task_queue=task_queue, + workflows=[TickingWorkflow], + max_cached_workflows=0, + max_concurrent_workflow_task_polls=1, ) as worker: # Confirm the first ticking workflow has completed a task but not # the second @@ -5227,10 +5230,10 @@ async def run(self) -> List[str]: async def test_workflow_as_completed_utility(client: Client): # Disable cache to force replay async with new_worker( - client, - AsCompletedWorkflow, - activities=[return_name_activity], - max_cached_workflows=0, + client, + 
AsCompletedWorkflow, + activities=[return_name_activity], + max_cached_workflows=0, ) as worker: # This would fail if we used asyncio.as_completed in the workflow result = await client.execute_workflow( @@ -5269,7 +5272,7 @@ async def new_activity_name(index: int) -> str: async def test_workflow_wait_utility(client: Client): # Disable cache to force replay async with new_worker( - client, WaitWorkflow, activities=[return_name_activity], max_cached_workflows=0 + client, WaitWorkflow, activities=[return_name_activity], max_cached_workflows=0 ) as worker: # This would fail if we used asyncio.wait in the workflow result = await client.execute_workflow( @@ -5491,9 +5494,9 @@ async def _workflow_task_failed(self, workflow_id: str) -> bool: return False async def _get_workflow_result_and_warning( - self, - wait_all_handlers_finished: bool, - unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, + self, + wait_all_handlers_finished: bool, + unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, ) -> Tuple[bool, bool]: with pytest.WarningsRecorder() as warnings: wf_result = await self._get_workflow_result( @@ -5506,10 +5509,10 @@ async def _get_workflow_result_and_warning( return wf_result, unfinished_handler_warning_emitted async def _get_workflow_result( - self, - wait_all_handlers_finished: bool, - unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, - handle_future: Optional[asyncio.Future[WorkflowHandle]] = None, + self, + wait_all_handlers_finished: bool, + unfinished_policy: Optional[workflow.HandlerUnfinishedPolicy] = None, + handle_future: Optional[asyncio.Future[WorkflowHandle]] = None, ) -> bool: handle = await self.client.start_workflow( UnfinishedHandlersWarningsWorkflow.run, @@ -5551,30 +5554,30 @@ def __init__(self) -> None: @workflow.run async def run( - self, - workflow_termination_type: Literal[ - "-cancellation-", - "-failure-", - "-continue-as-new-", - "-fail-post-continue-as-new-run-", - ], - handler_registration: Literal["-late-registered-", "-not-late-registered-"], - handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], - handler_waiting: Literal[ - "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" - ], + self, + workflow_termination_type: Literal[ + "-cancellation-", + "-failure-", + "-continue-as-new-", + "-fail-post-continue-as-new-run-", + ], + handler_registration: Literal["-late-registered-", "-not-late-registered-"], + handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], + handler_waiting: Literal[ + "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" + ], ) -> NoReturn: if handler_registration == "-late-registered-": if handler_dynamism == "-dynamic-": async def my_late_registered_dynamic_update( - name: str, args: Sequence[RawValue] + name: str, args: Sequence[RawValue] ) -> str: await workflow.wait_condition(lambda: self.handlers_may_finish) return "my-late-registered-dynamic-update-result" async def my_late_registered_dynamic_signal( - name: str, args: Sequence[RawValue] + name: str, args: Sequence[RawValue] ) -> None: await workflow.wait_condition(lambda: self.handlers_may_finish) @@ -5651,17 +5654,17 @@ async def my_dynamic_signal(self, name: str, args: Sequence[RawValue]) -> None: "workflow_termination_type", ["-cancellation-", "-failure-", "-continue-as-new-"] ) async def test_unfinished_handler_on_workflow_termination( - client: Client, - env: WorkflowEnvironment, - handler_type: Literal["-signal-", "-update-"], - handler_registration: Literal["-late-registered-", 
"-not-late-registered-"], - handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], - handler_waiting: Literal[ - "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" - ], - workflow_termination_type: Literal[ - "-cancellation-", "-failure-", "-continue-as-new-" - ], + client: Client, + env: WorkflowEnvironment, + handler_type: Literal["-signal-", "-update-"], + handler_registration: Literal["-late-registered-", "-not-late-registered-"], + handler_dynamism: Literal["-dynamic-", "-not-dynamic-"], + handler_waiting: Literal[ + "-wait-all-handlers-finish-", "-no-wait-all-handlers-finish-" + ], + workflow_termination_type: Literal[ + "-cancellation-", "-failure-", "-continue-as-new-" + ], ): skip_unfinished_handler_tests_in_older_python() if handler_type == "-update-" and env.supports_time_skipping: @@ -5692,10 +5695,10 @@ class _UnfinishedHandlersOnWorkflowTerminationTest: ] async def test_warning_is_issued_on_exit_with_unfinished_handler( - self, + self, ): assert await self._run_workflow_and_get_warning() == ( - self.handler_waiting == "-no-wait-all-handlers-finish-" + self.handler_waiting == "-no-wait-all-handlers-finish-" ) async def _run_workflow_and_get_warning(self) -> bool: @@ -5749,9 +5752,9 @@ async def _run_workflow_and_get_warning(self) -> bool: await handle.signal(signal_method) # type: ignore async with new_worker( - self.client, - UnfinishedHandlersOnWorkflowTerminationWorkflow, - task_queue=task_queue, + self.client, + UnfinishedHandlersOnWorkflowTerminationWorkflow, + task_queue=task_queue, ): with pytest.WarningsRecorder() as warnings: if self.handler_type == "-update-": @@ -5764,7 +5767,7 @@ async def _run_workflow_and_get_warning(self) -> bool: update_err = err_info.value assert isinstance(update_err.cause, ApplicationError) assert ( - update_err.cause.type == "AcceptedUpdateCompletedWorkflow" + update_err.cause.type == "AcceptedUpdateCompletedWorkflow" ) with pytest.raises(WorkflowFailureError) as err: @@ -5779,8 +5782,8 @@ async def _run_workflow_and_get_warning(self) -> bool: ) if self.workflow_termination_type == "-continue-as-new-": assert ( - str(err.value.cause) - == "Deliberately failing post-ContinueAsNew run" + str(err.value.cause) + == "Deliberately failing post-ContinueAsNew run" ) unfinished_handler_warning_emitted = any( @@ -5880,8 +5883,8 @@ async def my_update(self) -> str: async def test_update_completion_is_honored_when_after_workflow_return_1( - client: Client, - env: WorkflowEnvironment, + client: Client, + env: WorkflowEnvironment, ): if env.supports_time_skipping: pytest.skip( @@ -5903,9 +5906,9 @@ async def test_update_completion_is_honored_when_after_workflow_return_1( await workflow_update_exists(client, wf_handle.id, update_id) async with Worker( - client, - task_queue=task_queue, - workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturn1Workflow], + client, + task_queue=task_queue, + workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturn1Workflow], ): assert await wf_handle.result() == "workflow-result" assert await update_result_task == "update-result" @@ -5936,17 +5939,17 @@ async def my_update(self) -> str: async def test_update_completion_is_honored_when_after_workflow_return_2( - client: Client, - env: WorkflowEnvironment, + client: Client, + env: WorkflowEnvironment, ): if env.supports_time_skipping: pytest.skip( "Java test server: https://github.com/temporalio/sdk-java/issues/1903" ) async with Worker( - client, - task_queue="tq", - workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturnWorkflow2], + client, + 
task_queue="tq", + workflows=[UpdateCompletionIsHonoredWhenAfterWorkflowReturnWorkflow2], ) as worker: handle = await client.start_workflow( UpdateCompletionIsHonoredWhenAfterWorkflowReturnWorkflow2.run, @@ -6023,7 +6026,7 @@ async def test_first_of_two_signal_completion_commands_is_honored(client: Client async def test_workflow_return_is_honored_when_it_precedes_signal_completion_command( - client: Client, + client: Client, ): await _do_first_completion_command_is_honored_test( client, main_workflow_returns_before_signal_completions=True @@ -6031,7 +6034,7 @@ async def test_workflow_return_is_honored_when_it_precedes_signal_completion_com async def _do_first_completion_command_is_honored_test( - client: Client, main_workflow_returns_before_signal_completions: bool + client: Client, main_workflow_returns_before_signal_completions: bool ): workflow_cls: Union[ Type[FirstCompletionCommandIsHonoredPingPongWorkflow], @@ -6042,9 +6045,9 @@ async def _do_first_completion_command_is_honored_test( else FirstCompletionCommandIsHonoredWorkflow ) async with Worker( - client, - task_queue="tq", - workflows=[workflow_cls], + client, + task_queue="tq", + workflows=[workflow_cls], ) as worker: handle = await client.start_workflow( workflow_cls.run, @@ -6064,8 +6067,8 @@ async def _do_first_completion_command_is_honored_test( assert str(err.cause).startswith("Client should see this error") else: assert ( - main_workflow_returns_before_signal_completions - and result == "workflow-result" + main_workflow_returns_before_signal_completions + and result == "workflow-result" ) @@ -6090,7 +6093,7 @@ async def my_signal(self): async def test_timer_started_after_workflow_completion(client: Client): async with new_worker( - client, TimerStartedAfterWorkflowCompletionWorkflow + client, TimerStartedAfterWorkflowCompletionWorkflow ) as worker: handle = await client.start_workflow( TimerStartedAfterWorkflowCompletionWorkflow.run, @@ -6125,7 +6128,7 @@ async def run(self) -> None: async def test_activity_retry_delay(client: Client): async with new_worker( - client, ActivitiesWithRetryDelayWorkflow, activities=[activity_with_retry_delay] + client, ActivitiesWithRetryDelayWorkflow, activities=[activity_with_retry_delay] ) as worker: try: await client.execute_workflow( @@ -6137,11 +6140,11 @@ async def test_activity_retry_delay(client: Client): assert isinstance(err.cause, ActivityError) assert isinstance(err.cause.cause, ApplicationError) assert ( - str(err.cause.cause) == ActivitiesWithRetryDelayWorkflow.error_message + str(err.cause.cause) == ActivitiesWithRetryDelayWorkflow.error_message ) assert ( - err.cause.cause.next_retry_delay - == ActivitiesWithRetryDelayWorkflow.next_retry_delay + err.cause.cause.next_retry_delay + == ActivitiesWithRetryDelayWorkflow.next_retry_delay ) @@ -6204,7 +6207,7 @@ async def run(self, _: str) -> str: ], ) async def test_update_in_first_wft_sees_workflow_init( - client: Client, client_cls: Type, worker_cls: Type + client: Client, client_cls: Type, worker_cls: Type ): """ Test how @workflow.init affects what an update in the first WFT sees. 
@@ -6305,7 +6308,7 @@ async def test_user_metadata_is_set(client: Client, env: WorkflowEnvironment): "Java test server: https://github.com/temporalio/sdk-java/issues/2219" ) async with new_worker( - client, UserMetadataWorkflow, activities=[say_hello] + client, UserMetadataWorkflow, activities=[say_hello] ) as worker: handle = await client.start_workflow( UserMetadataWorkflow.run, @@ -6435,7 +6438,7 @@ async def make_timers(self, start: int, end: int): async def test_concurrent_sleeps_use_proper_options( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: pytest.skip( @@ -6480,10 +6483,10 @@ class BadFailureConverterError(Exception): class BadFailureConverter(DefaultFailureConverter): def to_failure( - self, - exception: BaseException, - payload_converter: PayloadConverter, - failure: Failure, + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: Failure, ) -> None: if isinstance(exception, BadFailureConverterError): raise RuntimeError("Intentional failure conversion error") @@ -6517,7 +6520,7 @@ async def test_bad_failure_converter(client: Client): ) client = Client(**config) async with new_worker( - client, BadFailureConverterWorkflow, activities=[bad_failure_converter_activity] + client, BadFailureConverterWorkflow, activities=[bad_failure_converter_activity] ) as worker: # Check activity with pytest.raises(WorkflowFailureError) as err: @@ -6530,8 +6533,8 @@ async def test_bad_failure_converter(client: Client): assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, ApplicationError) assert ( - err.value.cause.cause.message - == "Failed building exception result: Intentional failure conversion error" + err.value.cause.cause.message + == "Failed building exception result: Intentional failure conversion error" ) # Check workflow @@ -6630,10 +6633,10 @@ async def test_async_loop_ordering(client: Client, env: WorkflowEnvironment): await handle.signal(SignalsActivitiesTimersUpdatesTracingWorkflow.dosig, "before") async with new_worker( - client, - SignalsActivitiesTimersUpdatesTracingWorkflow, - activities=[say_hello], - task_queue=task_queue, + client, + SignalsActivitiesTimersUpdatesTracingWorkflow, + activities=[say_hello], + task_queue=task_queue, ): await asyncio.sleep(0.2) await handle.signal(SignalsActivitiesTimersUpdatesTracingWorkflow.dosig, "1") @@ -6690,18 +6693,18 @@ async def test_alternate_async_loop_ordering(client: Client, env: WorkflowEnviro ) async with new_worker( - client, - ActivityAndSignalsWhileWorkflowDown, - activities=[say_hello], - task_queue=task_queue, + client, + ActivityAndSignalsWhileWorkflowDown, + activities=[say_hello], + task_queue=task_queue, ): # This sleep exists to make sure the first WFT is processed await asyncio.sleep(0.2) async with new_worker( - client, - activities=[say_hello], - task_queue=activity_tq, + client, + activities=[say_hello], + task_queue=activity_tq, ): # Make sure the activity starts being processed before sending signals await asyncio.sleep(1) @@ -6709,10 +6712,10 @@ async def test_alternate_async_loop_ordering(client: Client, env: WorkflowEnviro await handle.signal(ActivityAndSignalsWhileWorkflowDown.dosig, "2") async with new_worker( - client, - ActivityAndSignalsWhileWorkflowDown, - activities=[say_hello], - task_queue=task_queue, + client, + ActivityAndSignalsWhileWorkflowDown, + activities=[say_hello], + task_queue=task_queue, ): await handle.result() @@ -6761,8 +6764,8 @@ def init(self, 
params: UseLockOrSemaphoreWorkflowParameters): @workflow.run async def run( - self, - params: Optional[UseLockOrSemaphoreWorkflowParameters], + self, + params: Optional[UseLockOrSemaphoreWorkflowParameters], ) -> LockOrSemaphoreWorkflowConcurrencySummary: # TODO: Use workflow init method when it exists. assert params @@ -6819,8 +6822,8 @@ def __init__(self) -> None: @workflow.run async def run( - self, - _: Optional[UseLockOrSemaphoreWorkflowParameters] = None, + self, + _: Optional[UseLockOrSemaphoreWorkflowParameters] = None, ) -> LockOrSemaphoreWorkflowConcurrencySummary: await workflow.wait_condition(lambda: self.workflow_may_exit) return LockOrSemaphoreWorkflowConcurrencySummary( @@ -6842,14 +6845,14 @@ async def finish(self): async def _do_workflow_coroutines_lock_or_semaphore_test( - client: Client, - params: UseLockOrSemaphoreWorkflowParameters, - expectation: LockOrSemaphoreWorkflowConcurrencySummary, + client: Client, + params: UseLockOrSemaphoreWorkflowParameters, + expectation: LockOrSemaphoreWorkflowConcurrencySummary, ): async with new_worker( - client, - CoroutinesUseLockOrSemaphoreWorkflow, - activities=[noop_activity_for_lock_or_semaphore_tests], + client, + CoroutinesUseLockOrSemaphoreWorkflow, + activities=[noop_activity_for_lock_or_semaphore_tests], ) as worker: summary = await client.execute_workflow( CoroutinesUseLockOrSemaphoreWorkflow.run, @@ -6861,11 +6864,11 @@ async def _do_workflow_coroutines_lock_or_semaphore_test( async def _do_update_handler_lock_or_semaphore_test( - client: Client, - env: WorkflowEnvironment, - params: UseLockOrSemaphoreWorkflowParameters, - n_updates: int, - expectation: LockOrSemaphoreWorkflowConcurrencySummary, + client: Client, + env: WorkflowEnvironment, + params: UseLockOrSemaphoreWorkflowParameters, + n_updates: int, + expectation: LockOrSemaphoreWorkflowConcurrencySummary, ): if env.supports_time_skipping: pytest.skip( @@ -6890,10 +6893,10 @@ async def _do_update_handler_lock_or_semaphore_test( for i in range(n_updates) ] async with new_worker( - client, - HandlerCoroutinesUseLockOrSemaphoreWorkflow, - activities=[noop_activity_for_lock_or_semaphore_tests], - task_queue=task_queue, + client, + HandlerCoroutinesUseLockOrSemaphoreWorkflow, + activities=[noop_activity_for_lock_or_semaphore_tests], + task_queue=task_queue, ): for update_task in admitted_updates: await update_task @@ -6914,7 +6917,7 @@ async def test_workflow_coroutines_can_use_lock(client: Client): async def test_update_handler_can_use_lock_to_serialize_handler_executions( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -6940,7 +6943,7 @@ async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Cli async def test_update_handler_lock_acquisition_respects_timeout( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -6966,7 +6969,7 @@ async def test_workflow_coroutines_can_use_semaphore(client: Client): async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -6981,7 +6984,7 @@ async def test_update_handler_can_use_semaphore_to_control_handler_execution_con async def test_workflow_coroutine_semaphore_acquisition_respects_timeout( - client: Client, + client: Client, ): await 
_do_workflow_coroutines_lock_or_semaphore_test( client, @@ -6997,7 +7000,7 @@ async def test_workflow_coroutine_semaphore_acquisition_respects_timeout( async def test_update_handler_semaphore_acquisition_respects_timeout( - client: Client, env: WorkflowEnvironment + client: Client, env: WorkflowEnvironment ): await _do_update_handler_lock_or_semaphore_test( client, @@ -7154,11 +7157,11 @@ async def test_workflow_deadlock_fill_up_slots(client: Client): # This worker used to not be able to shutdown because we hung evictions on # deadlock. async with new_worker( - client, - DeadlockFillUpBlockWorkflow, - DeadlockFillUpSimpleWorkflow, - # Start the worker with CPU count + 10 task slots - max_concurrent_workflow_tasks=cpu_count + 10, + client, + DeadlockFillUpBlockWorkflow, + DeadlockFillUpSimpleWorkflow, + # Start the worker with CPU count + 10 task slots + max_concurrent_workflow_tasks=cpu_count + 10, ) as worker: # For this test we're going to start cpu_count + 5 workflows that # deadlock. In previous SDK versions we defaulted to CPU count @@ -7258,7 +7261,7 @@ async def check_logs(): while True: log_record = log_queue.get(block=False) if log_record.message.startswith( - f"Timed out running eviction job for run ID {handle.result_run_id}" + f"Timed out running eviction job for run ID {handle.result_run_id}" ): return except queue.Empty: @@ -7286,7 +7289,7 @@ async def check_priority_activity(should_have_priorty: int) -> str: class WorkflowUsingPriorities: @workflow.run async def run( - self, expected_priority: Optional[int], stop_after_check: bool + self, expected_priority: Optional[int], stop_after_check: bool ) -> str: assert workflow.info().priority.priority_key == expected_priority if stop_after_check: @@ -7318,7 +7321,7 @@ async def test_workflow_priorities(client: Client, env: WorkflowEnvironment): ) async with new_worker( - client, WorkflowUsingPriorities, HelloWorkflow, activities=[say_hello] + client, WorkflowUsingPriorities, HelloWorkflow, activities=[say_hello] ) as worker: handle = await client.start_workflow( WorkflowUsingPriorities.run, @@ -7333,27 +7336,27 @@ async def test_workflow_priorities(client: Client, env: WorkflowEnvironment): async for e in handle.fetch_history_events(): if e.HasField("workflow_execution_started_event_attributes"): assert ( - e.workflow_execution_started_event_attributes.priority.priority_key - == 1 + e.workflow_execution_started_event_attributes.priority.priority_key + == 1 ) elif e.HasField( - "start_child_workflow_execution_initiated_event_attributes" + "start_child_workflow_execution_initiated_event_attributes" ): if first_child: assert ( - e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key - == 4 + e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key + == 4 ) first_child = False else: assert ( - e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key - == 2 + e.start_child_workflow_execution_initiated_event_attributes.priority.priority_key + == 2 ) elif e.HasField("activity_task_scheduled_event_attributes"): assert ( - e.activity_task_scheduled_event_attributes.priority.priority_key - == 5 + e.activity_task_scheduled_event_attributes.priority.priority_key + == 5 ) # Verify a workflow started without priorities sees None for the key @@ -7396,7 +7399,7 @@ async def test_expose_root_execution(client: Client, env: WorkflowEnvironment): "Java test server needs release with: https://github.com/temporalio/sdk-java/pull/2441" ) async with new_worker( - client, 
ExposeRootWorkflow, ExposeRootChildWorkflow + client, ExposeRootWorkflow, ExposeRootChildWorkflow ) as worker: parent_wf_id = f"workflow-{uuid.uuid4()}" child_wf_id = parent_wf_id + "_child" From 896ab3011f4ffa3eeb8fb18f758082a301b14d7a Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Mon, 14 Apr 2025 17:24:56 -0700 Subject: [PATCH 04/12] Added support for temporal_from_json method. --- temporalio/converter.py | 236 +++++++++++++++++++++------------------- 1 file changed, 122 insertions(+), 114 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 384dda5f..2bc08c59 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -68,7 +68,7 @@ class PayloadConverter(ABC): @abstractmethod def to_payloads( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> List[temporalio.api.common.v1.Payload]: """Encode values into payloads. @@ -90,9 +90,9 @@ def to_payloads( @abstractmethod def from_payloads( - self, - payloads: Sequence[temporalio.api.common.v1.Payload], - type_hints: Optional[List[Type]] = None, + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """Decode payloads into values. @@ -116,7 +116,7 @@ def from_payloads( raise NotImplementedError def to_payloads_wrapper( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> temporalio.api.common.v1.Payloads: """:py:meth:`to_payloads` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. @@ -124,7 +124,7 @@ def to_payloads_wrapper( return temporalio.api.common.v1.Payloads(payloads=self.to_payloads(values)) def from_payloads_wrapper( - self, payloads: Optional[temporalio.api.common.v1.Payloads] + self, payloads: Optional[temporalio.api.common.v1.Payloads] ) -> List[Any]: """:py:meth:`from_payloads` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. @@ -152,15 +152,15 @@ def from_payload(self, payload: temporalio.api.common.v1.Payload) -> Any: ... @overload def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Type[temporalio.types.AnyType], + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Type[temporalio.types.AnyType], ) -> temporalio.types.AnyType: ... def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """Convert a single payload to a value. @@ -205,9 +205,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: @abstractmethod def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """Decode a single payload to a Python value or raise exception. @@ -249,7 +249,7 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None: self.converters = {c.encoding.encode(): c for c in converters} def to_payloads( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> List[temporalio.api.common.v1.Payload]: """Encode values trying each converter. @@ -279,9 +279,9 @@ def to_payloads( return payloads def from_payloads( - self, - payloads: Sequence[temporalio.api.common.v1.Payload], - type_hints: Optional[List[Type]] = None, + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """Decode values trying each converter. 
@@ -346,9 +346,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" if len(payload.data) > 0: @@ -373,9 +373,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" return payload.data @@ -405,8 +405,8 @@ def encoding(self) -> str: def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: """See base class.""" if ( - isinstance(value, google.protobuf.message.Message) - and value.DESCRIPTOR is not None + isinstance(value, google.protobuf.message.Message) + and value.DESCRIPTOR is not None ): # We have to convert to dict then to JSON because MessageToJson does # not have a compact option removing spaces and newlines @@ -425,9 +425,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" message_type = payload.metadata.get("messageType", b"").decode() @@ -455,8 +455,8 @@ def encoding(self) -> str: def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: """See base class.""" if ( - isinstance(value, google.protobuf.message.Message) - and value.DESCRIPTOR is not None + isinstance(value, google.protobuf.message.Message) + and value.DESCRIPTOR is not None ): return temporalio.api.common.v1.Payload( metadata={ @@ -468,9 +468,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" message_type = payload.metadata.get("messageType", b"").decode() @@ -531,12 +531,12 @@ class JSONPlainPayloadConverter(EncodingPayloadConverter): _encoding: str def __init__( - self, - *, - encoder: Optional[Type[json.JSONEncoder]] = AdvancedJSONEncoder, - decoder: Optional[Type[json.JSONDecoder]] = None, - encoding: str = "json/plain", - custom_type_converters: Sequence[JSONTypeConverter] = [], + self, + *, + encoder: Optional[Type[json.JSONEncoder]] = AdvancedJSONEncoder, + decoder: Optional[Type[json.JSONDecoder]] = None, + encoding: str = "json/plain", + custom_type_converters: Sequence[JSONTypeConverter] = [], ) -> None: """Initialize a JSON data converter. 
@@ -575,9 +575,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: ) def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" try: @@ -604,7 +604,7 @@ class JSONTypeConverter(ABC): @abstractmethod def to_typed_value( - self, hint: Type, value: Any + self, hint: Type, value: Any ) -> Union[Optional[Any], _JSONTypeConverterUnhandled]: """Convert the given value to a type based on the given hint. @@ -628,7 +628,7 @@ class PayloadCodec(ABC): @abstractmethod async def encode( - self, payloads: Sequence[temporalio.api.common.v1.Payload] + self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> List[temporalio.api.common.v1.Payload]: """Encode the given payloads. @@ -644,7 +644,7 @@ async def encode( @abstractmethod async def decode( - self, payloads: Sequence[temporalio.api.common.v1.Payload] + self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> List[temporalio.api.common.v1.Payload]: """Decode the given payloads. @@ -689,9 +689,9 @@ async def decode_failure(self, failure: temporalio.api.failure.v1.Failure) -> No await self._apply_to_failure_payloads(failure, self.decode_wrapper) async def _apply_to_failure_payloads( - self, - failure: temporalio.api.failure.v1.Failure, - cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], + self, + failure: temporalio.api.failure.v1.Failure, + cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], ) -> None: if failure.HasField("encoded_attributes"): # Wrap in payloads and merge back @@ -701,19 +701,19 @@ async def _apply_to_failure_payloads( await cb(payloads) failure.encoded_attributes.CopyFrom(payloads.payloads[0]) if failure.HasField( - "application_failure_info" + "application_failure_info" ) and failure.application_failure_info.HasField("details"): await cb(failure.application_failure_info.details) elif failure.HasField( - "timeout_failure_info" + "timeout_failure_info" ) and failure.timeout_failure_info.HasField("last_heartbeat_details"): await cb(failure.timeout_failure_info.last_heartbeat_details) elif failure.HasField( - "canceled_failure_info" + "canceled_failure_info" ) and failure.canceled_failure_info.HasField("details"): await cb(failure.canceled_failure_info.details) elif failure.HasField( - "reset_workflow_failure_info" + "reset_workflow_failure_info" ) and failure.reset_workflow_failure_info.HasField("last_heartbeat_details"): await cb(failure.reset_workflow_failure_info.last_heartbeat_details) if failure.HasField("cause"): @@ -734,10 +734,10 @@ class FailureConverter(ABC): @abstractmethod def to_failure( - self, - exception: BaseException, - payload_converter: PayloadConverter, - failure: temporalio.api.failure.v1.Failure, + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, ) -> None: """Convert the given exception to a Temporal failure. @@ -752,9 +752,9 @@ def to_failure( @abstractmethod def from_failure( - self, - failure: temporalio.api.failure.v1.Failure, - payload_converter: PayloadConverter, + self, + failure: temporalio.api.failure.v1.Failure, + payload_converter: PayloadConverter, ) -> BaseException: """Convert the given Temporal failure to an exception. 
@@ -789,10 +789,10 @@ def __init__(self, *, encode_common_attributes: bool = False) -> None: self._encode_common_attributes = encode_common_attributes def to_failure( - self, - exception: BaseException, - payload_converter: PayloadConverter, - failure: temporalio.api.failure.v1.Failure, + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, ) -> None: """See base class.""" # If already a failure error, use that @@ -818,10 +818,10 @@ def to_failure( failure.stack_trace = "" def _error_to_failure( - self, - error: temporalio.exceptions.FailureError, - payload_converter: PayloadConverter, - failure: temporalio.api.failure.v1.Failure, + self, + error: temporalio.exceptions.FailureError, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, ) -> None: # If there is an underlying proto already, just use that if error.failure: @@ -903,9 +903,9 @@ def _error_to_failure( ) def from_failure( - self, - failure: temporalio.api.failure.v1.Failure, - payload_converter: PayloadConverter, + self, + failure: temporalio.api.failure.v1.Failure, + payload_converter: PayloadConverter, ) -> BaseException: """See base class.""" # If encoded attributes are present and have the fields we expect, @@ -1042,7 +1042,7 @@ def __post_init__(self) -> None: # noqa: D105 object.__setattr__(self, "failure_converter", self.failure_converter_class()) async def encode( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> List[temporalio.api.common.v1.Payload]: """Encode values into payloads. @@ -1062,9 +1062,9 @@ async def encode( return payloads async def decode( - self, - payloads: Sequence[temporalio.api.common.v1.Payload], - type_hints: Optional[List[Type]] = None, + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """Decode payloads into values. @@ -1081,7 +1081,7 @@ async def decode( return self.payload_converter.from_payloads(payloads, type_hints) async def encode_wrapper( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> temporalio.api.common.v1.Payloads: """:py:meth:`encode` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. @@ -1089,9 +1089,9 @@ async def encode_wrapper( return temporalio.api.common.v1.Payloads(payloads=(await self.encode(values))) async def decode_wrapper( - self, - payloads: Optional[temporalio.api.common.v1.Payloads], - type_hints: Optional[List[Type]] = None, + self, + payloads: Optional[temporalio.api.common.v1.Payloads], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """:py:meth:`decode` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. 
@@ -1101,7 +1101,7 @@ async def decode_wrapper( return await self.decode(payloads.payloads, type_hints) async def encode_failure( - self, exception: BaseException, failure: temporalio.api.failure.v1.Failure + self, exception: BaseException, failure: temporalio.api.failure.v1.Failure ) -> None: """Convert and encode failure.""" self.failure_converter.to_failure(exception, self.payload_converter, failure) @@ -1109,7 +1109,7 @@ async def encode_failure( await self.payload_codec.encode_failure(failure) async def decode_failure( - self, failure: temporalio.api.failure.v1.Failure + self, failure: temporalio.api.failure.v1.Failure ) -> BaseException: """Decode and convert failure.""" if self.payload_codec: @@ -1142,10 +1142,10 @@ def default() -> DataConverter: def encode_search_attributes( - attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ], - api: temporalio.api.common.v1.SearchAttributes, + attributes: Union[ + temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes + ], + api: temporalio.api.common.v1.SearchAttributes, ) -> None: """Convert search attributes into an API message. @@ -1167,10 +1167,10 @@ def encode_search_attributes( def encode_typed_search_attribute_value( - key: temporalio.common.SearchAttributeKey[ - temporalio.common.SearchAttributeValueType - ], - value: Optional[temporalio.common.SearchAttributeValue], + key: temporalio.common.SearchAttributeKey[ + temporalio.common.SearchAttributeValueType + ], + value: Optional[temporalio.common.SearchAttributeValue], ) -> temporalio.api.common.v1.Payload: """Convert typed search attribute value into a payload. @@ -1205,7 +1205,7 @@ def encode_typed_search_attribute_value( def encode_search_attribute_values( - vals: temporalio.common.SearchAttributeValues, + vals: temporalio.common.SearchAttributeValues, ) -> temporalio.api.common.v1.Payload: """Convert search attribute values into a payload. @@ -1243,9 +1243,9 @@ def encode_search_attribute_values( def _encode_maybe_typed_search_attributes( - non_typed_attributes: Optional[temporalio.common.SearchAttributes], - typed_attributes: Optional[temporalio.common.TypedSearchAttributes], - api: temporalio.api.common.v1.SearchAttributes, + non_typed_attributes: Optional[temporalio.common.SearchAttributes], + typed_attributes: Optional[temporalio.common.TypedSearchAttributes], + api: temporalio.api.common.v1.SearchAttributes, ) -> None: if non_typed_attributes: if typed_attributes and typed_attributes.search_attributes: @@ -1271,7 +1271,7 @@ def _get_iso_datetime_parser() -> Callable[[str], datetime]: def decode_search_attributes( - api: temporalio.api.common.v1.SearchAttributes, + api: temporalio.api.common.v1.SearchAttributes, ) -> temporalio.common.SearchAttributes: """Decode API search attributes to values. @@ -1300,7 +1300,7 @@ def decode_search_attributes( def decode_typed_search_attributes( - api: temporalio.api.common.v1.SearchAttributes, + api: temporalio.api.common.v1.SearchAttributes, ) -> temporalio.common.TypedSearchAttributes: """Decode API search attributes to typed search attributes. 
@@ -1327,16 +1327,16 @@ def decode_typed_search_attributes( # If the value is a list but the type is not keyword list, pull out # single item or consider this an invalid value and ignore if ( - key.indexed_value_type - != temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST - and isinstance(val, list) + key.indexed_value_type + != temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST + and isinstance(val, list) ): if len(val) != 1: continue val = val[0] if ( - key.indexed_value_type - == temporalio.common.SearchAttributeIndexedValueType.DATETIME + key.indexed_value_type + == temporalio.common.SearchAttributeIndexedValueType.DATETIME ): parser = _get_iso_datetime_parser() # We will let this throw @@ -1348,7 +1348,7 @@ def decode_typed_search_attributes( def _decode_search_attribute_value( - payload: temporalio.api.common.v1.Payload, + payload: temporalio.api.common.v1.Payload, ) -> temporalio.common.SearchAttributeValue: val = default().payload_converter.from_payload(payload) if isinstance(val, str) and payload.metadata.get("type") == b"Datetime": @@ -1357,9 +1357,9 @@ def _decode_search_attribute_value( def value_to_type( - hint: Type, - value: Any, - custom_converters: Sequence[JSONTypeConverter] = [], + hint: Type, + value: Any, + custom_converters: Sequence[JSONTypeConverter] = [], ) -> Any: """Convert a given value to the given type hint. @@ -1429,6 +1429,14 @@ def value_to_type( raise TypeError(f"Value {value} not in literal values {type_args}") return value + # Has temporal_from_json method + from_json = "temporal_from_json" + if hasattr(hint, from_json): + attr = getattr(hint, from_json, None) + if not callable(attr, None) or not getattr(attr, "__self__", None) is hint: + raise TypeError(f"Type {hint}: temporal_from_json must be a class method") + return attr(value) + is_union = origin is Union if sys.version_info >= (3, 10): is_union = is_union or isinstance(origin, UnionType) @@ -1452,21 +1460,21 @@ def value_to_type( # and therefore can extract per-key types per_key_types: Optional[Dict[str, Type]] = None if getattr(origin, "__required_keys__", None) or getattr( - origin, "__optional_keys__", None + origin, "__optional_keys__", None ): per_key_types = get_type_hints(origin) key_type = ( type_args[0] if len(type_args) > 0 - and type_args[0] is not Any - and not isinstance(type_args[0], TypeVar) + and type_args[0] is not Any + and not isinstance(type_args[0], TypeVar) else None ) value_type = ( type_args[1] if len(type_args) > 1 - and type_args[1] is not Any - and not isinstance(type_args[1], TypeVar) + and type_args[1] is not Any + and not isinstance(type_args[1], TypeVar) else None ) # Convert each key/value @@ -1518,7 +1526,7 @@ def value_to_type( # attempted instantiation of the dataclass raise if a field is # missing if field_value is not dataclasses.MISSING and not field.metadata.get( - "skip", False + "skip", False ): try: field_values[field.name] = value_to_type( @@ -1541,7 +1549,7 @@ def value_to_type( # compatibility with pydantic v1 users, but this is deprecated. 
parse_obj_attr = inspect.getattr_static(hint, "parse_obj", None) if isinstance(parse_obj_attr, classmethod) or isinstance( - parse_obj_attr, staticmethod + parse_obj_attr, staticmethod ): if not isinstance(value, dict): raise TypeError( @@ -1577,8 +1585,8 @@ def value_to_type( ret_list = [] # If there is no type arg, just return value as is if not type_args or ( - len(type_args) == 1 - and (isinstance(type_args[0], TypeVar) or type_args[0] is Ellipsis) + len(type_args) == 1 + and (isinstance(type_args[0], TypeVar) or type_args[0] is Ellipsis) ): ret_list = list(value) else: From d68fee1f567c289af7f57ab9ca2b508bda755c9d Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 16 Apr 2025 18:27:05 -0700 Subject: [PATCH 05/12] Renamed to to_json, from_json. Added unit test. --- temporalio/converter.py | 243 ++++++++++++++++++---------------- tests/worker/test_workflow.py | 36 +++-- 2 files changed, 154 insertions(+), 125 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 2bc08c59..dc118b3c 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -68,7 +68,7 @@ class PayloadConverter(ABC): @abstractmethod def to_payloads( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> List[temporalio.api.common.v1.Payload]: """Encode values into payloads. @@ -90,9 +90,9 @@ def to_payloads( @abstractmethod def from_payloads( - self, - payloads: Sequence[temporalio.api.common.v1.Payload], - type_hints: Optional[List[Type]] = None, + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """Decode payloads into values. @@ -116,7 +116,7 @@ def from_payloads( raise NotImplementedError def to_payloads_wrapper( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> temporalio.api.common.v1.Payloads: """:py:meth:`to_payloads` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. @@ -124,7 +124,7 @@ def to_payloads_wrapper( return temporalio.api.common.v1.Payloads(payloads=self.to_payloads(values)) def from_payloads_wrapper( - self, payloads: Optional[temporalio.api.common.v1.Payloads] + self, payloads: Optional[temporalio.api.common.v1.Payloads] ) -> List[Any]: """:py:meth:`from_payloads` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. @@ -152,15 +152,15 @@ def from_payload(self, payload: temporalio.api.common.v1.Payload) -> Any: ... @overload def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Type[temporalio.types.AnyType], + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Type[temporalio.types.AnyType], ) -> temporalio.types.AnyType: ... def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """Convert a single payload to a value. @@ -205,9 +205,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: @abstractmethod def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """Decode a single payload to a Python value or raise exception. 
@@ -249,7 +249,7 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None: self.converters = {c.encoding.encode(): c for c in converters} def to_payloads( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> List[temporalio.api.common.v1.Payload]: """Encode values trying each converter. @@ -279,9 +279,9 @@ def to_payloads( return payloads def from_payloads( - self, - payloads: Sequence[temporalio.api.common.v1.Payload], - type_hints: Optional[List[Type]] = None, + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """Decode values trying each converter. @@ -346,9 +346,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" if len(payload.data) > 0: @@ -373,9 +373,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" return payload.data @@ -405,8 +405,8 @@ def encoding(self) -> str: def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: """See base class.""" if ( - isinstance(value, google.protobuf.message.Message) - and value.DESCRIPTOR is not None + isinstance(value, google.protobuf.message.Message) + and value.DESCRIPTOR is not None ): # We have to convert to dict then to JSON because MessageToJson does # not have a compact option removing spaces and newlines @@ -425,9 +425,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" message_type = payload.metadata.get("messageType", b"").decode() @@ -455,8 +455,8 @@ def encoding(self) -> str: def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: """See base class.""" if ( - isinstance(value, google.protobuf.message.Message) - and value.DESCRIPTOR is not None + isinstance(value, google.protobuf.message.Message) + and value.DESCRIPTOR is not None ): return temporalio.api.common.v1.Payload( metadata={ @@ -468,9 +468,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: return None def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" message_type = payload.metadata.get("messageType", b"").decode() @@ -499,6 +499,16 @@ def default(self, o: Any) -> Any: See :py:meth:`json.JSONEncoder.default`. 
""" + # Custom encoding and decoding through temporal_to_json and temporal_from_json + to_json = "temporal_to_json" + if hasattr(o, to_json): + attr = getattr(o, to_json) + if not callable(attr): + raise TypeError( + f"Type {o.__class__}: temporal_to_json must be a method" + ) + return attr() + # Dataclass support if dataclasses.is_dataclass(o): return dataclasses.asdict(o) @@ -531,12 +541,12 @@ class JSONPlainPayloadConverter(EncodingPayloadConverter): _encoding: str def __init__( - self, - *, - encoder: Optional[Type[json.JSONEncoder]] = AdvancedJSONEncoder, - decoder: Optional[Type[json.JSONDecoder]] = None, - encoding: str = "json/plain", - custom_type_converters: Sequence[JSONTypeConverter] = [], + self, + *, + encoder: Optional[Type[json.JSONEncoder]] = AdvancedJSONEncoder, + decoder: Optional[Type[json.JSONDecoder]] = None, + encoding: str = "json/plain", + custom_type_converters: Sequence[JSONTypeConverter] = [], ) -> None: """Initialize a JSON data converter. @@ -575,9 +585,9 @@ def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]: ) def from_payload( - self, - payload: temporalio.api.common.v1.Payload, - type_hint: Optional[Type] = None, + self, + payload: temporalio.api.common.v1.Payload, + type_hint: Optional[Type] = None, ) -> Any: """See base class.""" try: @@ -604,7 +614,7 @@ class JSONTypeConverter(ABC): @abstractmethod def to_typed_value( - self, hint: Type, value: Any + self, hint: Type, value: Any ) -> Union[Optional[Any], _JSONTypeConverterUnhandled]: """Convert the given value to a type based on the given hint. @@ -628,7 +638,7 @@ class PayloadCodec(ABC): @abstractmethod async def encode( - self, payloads: Sequence[temporalio.api.common.v1.Payload] + self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> List[temporalio.api.common.v1.Payload]: """Encode the given payloads. @@ -644,7 +654,7 @@ async def encode( @abstractmethod async def decode( - self, payloads: Sequence[temporalio.api.common.v1.Payload] + self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> List[temporalio.api.common.v1.Payload]: """Decode the given payloads. 
@@ -689,9 +699,9 @@ async def decode_failure(self, failure: temporalio.api.failure.v1.Failure) -> No await self._apply_to_failure_payloads(failure, self.decode_wrapper) async def _apply_to_failure_payloads( - self, - failure: temporalio.api.failure.v1.Failure, - cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], + self, + failure: temporalio.api.failure.v1.Failure, + cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], ) -> None: if failure.HasField("encoded_attributes"): # Wrap in payloads and merge back @@ -701,19 +711,19 @@ async def _apply_to_failure_payloads( await cb(payloads) failure.encoded_attributes.CopyFrom(payloads.payloads[0]) if failure.HasField( - "application_failure_info" + "application_failure_info" ) and failure.application_failure_info.HasField("details"): await cb(failure.application_failure_info.details) elif failure.HasField( - "timeout_failure_info" + "timeout_failure_info" ) and failure.timeout_failure_info.HasField("last_heartbeat_details"): await cb(failure.timeout_failure_info.last_heartbeat_details) elif failure.HasField( - "canceled_failure_info" + "canceled_failure_info" ) and failure.canceled_failure_info.HasField("details"): await cb(failure.canceled_failure_info.details) elif failure.HasField( - "reset_workflow_failure_info" + "reset_workflow_failure_info" ) and failure.reset_workflow_failure_info.HasField("last_heartbeat_details"): await cb(failure.reset_workflow_failure_info.last_heartbeat_details) if failure.HasField("cause"): @@ -734,10 +744,10 @@ class FailureConverter(ABC): @abstractmethod def to_failure( - self, - exception: BaseException, - payload_converter: PayloadConverter, - failure: temporalio.api.failure.v1.Failure, + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, ) -> None: """Convert the given exception to a Temporal failure. @@ -752,9 +762,9 @@ def to_failure( @abstractmethod def from_failure( - self, - failure: temporalio.api.failure.v1.Failure, - payload_converter: PayloadConverter, + self, + failure: temporalio.api.failure.v1.Failure, + payload_converter: PayloadConverter, ) -> BaseException: """Convert the given Temporal failure to an exception. 
@@ -789,10 +799,10 @@ def __init__(self, *, encode_common_attributes: bool = False) -> None: self._encode_common_attributes = encode_common_attributes def to_failure( - self, - exception: BaseException, - payload_converter: PayloadConverter, - failure: temporalio.api.failure.v1.Failure, + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, ) -> None: """See base class.""" # If already a failure error, use that @@ -818,10 +828,10 @@ def to_failure( failure.stack_trace = "" def _error_to_failure( - self, - error: temporalio.exceptions.FailureError, - payload_converter: PayloadConverter, - failure: temporalio.api.failure.v1.Failure, + self, + error: temporalio.exceptions.FailureError, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, ) -> None: # If there is an underlying proto already, just use that if error.failure: @@ -903,9 +913,9 @@ def _error_to_failure( ) def from_failure( - self, - failure: temporalio.api.failure.v1.Failure, - payload_converter: PayloadConverter, + self, + failure: temporalio.api.failure.v1.Failure, + payload_converter: PayloadConverter, ) -> BaseException: """See base class.""" # If encoded attributes are present and have the fields we expect, @@ -1042,7 +1052,7 @@ def __post_init__(self) -> None: # noqa: D105 object.__setattr__(self, "failure_converter", self.failure_converter_class()) async def encode( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> List[temporalio.api.common.v1.Payload]: """Encode values into payloads. @@ -1062,9 +1072,9 @@ async def encode( return payloads async def decode( - self, - payloads: Sequence[temporalio.api.common.v1.Payload], - type_hints: Optional[List[Type]] = None, + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """Decode payloads into values. @@ -1081,7 +1091,7 @@ async def decode( return self.payload_converter.from_payloads(payloads, type_hints) async def encode_wrapper( - self, values: Sequence[Any] + self, values: Sequence[Any] ) -> temporalio.api.common.v1.Payloads: """:py:meth:`encode` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. @@ -1089,9 +1099,9 @@ async def encode_wrapper( return temporalio.api.common.v1.Payloads(payloads=(await self.encode(values))) async def decode_wrapper( - self, - payloads: Optional[temporalio.api.common.v1.Payloads], - type_hints: Optional[List[Type]] = None, + self, + payloads: Optional[temporalio.api.common.v1.Payloads], + type_hints: Optional[List[Type]] = None, ) -> List[Any]: """:py:meth:`decode` for the :py:class:`temporalio.api.common.v1.Payloads` wrapper. 
@@ -1101,7 +1111,7 @@ async def decode_wrapper( return await self.decode(payloads.payloads, type_hints) async def encode_failure( - self, exception: BaseException, failure: temporalio.api.failure.v1.Failure + self, exception: BaseException, failure: temporalio.api.failure.v1.Failure ) -> None: """Convert and encode failure.""" self.failure_converter.to_failure(exception, self.payload_converter, failure) @@ -1109,7 +1119,7 @@ async def encode_failure( await self.payload_codec.encode_failure(failure) async def decode_failure( - self, failure: temporalio.api.failure.v1.Failure + self, failure: temporalio.api.failure.v1.Failure ) -> BaseException: """Decode and convert failure.""" if self.payload_codec: @@ -1142,10 +1152,10 @@ def default() -> DataConverter: def encode_search_attributes( - attributes: Union[ - temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes - ], - api: temporalio.api.common.v1.SearchAttributes, + attributes: Union[ + temporalio.common.SearchAttributes, temporalio.common.TypedSearchAttributes + ], + api: temporalio.api.common.v1.SearchAttributes, ) -> None: """Convert search attributes into an API message. @@ -1167,10 +1177,10 @@ def encode_search_attributes( def encode_typed_search_attribute_value( - key: temporalio.common.SearchAttributeKey[ - temporalio.common.SearchAttributeValueType - ], - value: Optional[temporalio.common.SearchAttributeValue], + key: temporalio.common.SearchAttributeKey[ + temporalio.common.SearchAttributeValueType + ], + value: Optional[temporalio.common.SearchAttributeValue], ) -> temporalio.api.common.v1.Payload: """Convert typed search attribute value into a payload. @@ -1205,7 +1215,7 @@ def encode_typed_search_attribute_value( def encode_search_attribute_values( - vals: temporalio.common.SearchAttributeValues, + vals: temporalio.common.SearchAttributeValues, ) -> temporalio.api.common.v1.Payload: """Convert search attribute values into a payload. @@ -1243,9 +1253,9 @@ def encode_search_attribute_values( def _encode_maybe_typed_search_attributes( - non_typed_attributes: Optional[temporalio.common.SearchAttributes], - typed_attributes: Optional[temporalio.common.TypedSearchAttributes], - api: temporalio.api.common.v1.SearchAttributes, + non_typed_attributes: Optional[temporalio.common.SearchAttributes], + typed_attributes: Optional[temporalio.common.TypedSearchAttributes], + api: temporalio.api.common.v1.SearchAttributes, ) -> None: if non_typed_attributes: if typed_attributes and typed_attributes.search_attributes: @@ -1271,7 +1281,7 @@ def _get_iso_datetime_parser() -> Callable[[str], datetime]: def decode_search_attributes( - api: temporalio.api.common.v1.SearchAttributes, + api: temporalio.api.common.v1.SearchAttributes, ) -> temporalio.common.SearchAttributes: """Decode API search attributes to values. @@ -1300,7 +1310,7 @@ def decode_search_attributes( def decode_typed_search_attributes( - api: temporalio.api.common.v1.SearchAttributes, + api: temporalio.api.common.v1.SearchAttributes, ) -> temporalio.common.TypedSearchAttributes: """Decode API search attributes to typed search attributes. 
@@ -1327,16 +1337,16 @@ def decode_typed_search_attributes( # If the value is a list but the type is not keyword list, pull out # single item or consider this an invalid value and ignore if ( - key.indexed_value_type - != temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST - and isinstance(val, list) + key.indexed_value_type + != temporalio.common.SearchAttributeIndexedValueType.KEYWORD_LIST + and isinstance(val, list) ): if len(val) != 1: continue val = val[0] if ( - key.indexed_value_type - == temporalio.common.SearchAttributeIndexedValueType.DATETIME + key.indexed_value_type + == temporalio.common.SearchAttributeIndexedValueType.DATETIME ): parser = _get_iso_datetime_parser() # We will let this throw @@ -1348,7 +1358,7 @@ def decode_typed_search_attributes( def _decode_search_attribute_value( - payload: temporalio.api.common.v1.Payload, + payload: temporalio.api.common.v1.Payload, ) -> temporalio.common.SearchAttributeValue: val = default().payload_converter.from_payload(payload) if isinstance(val, str) and payload.metadata.get("type") == b"Datetime": @@ -1357,9 +1367,9 @@ def _decode_search_attribute_value( def value_to_type( - hint: Type, - value: Any, - custom_converters: Sequence[JSONTypeConverter] = [], + hint: Type, + value: Any, + custom_converters: Sequence[JSONTypeConverter] = [], ) -> Any: """Convert a given value to the given type hint. @@ -1432,8 +1442,9 @@ def value_to_type( # Has temporal_from_json method from_json = "temporal_from_json" if hasattr(hint, from_json): - attr = getattr(hint, from_json, None) - if not callable(attr, None) or not getattr(attr, "__self__", None) is hint: + attr = getattr(hint, from_json) + attrCls = getattr(attr, "__self__") + if not callable(attr) or not attrCls == origin: raise TypeError(f"Type {hint}: temporal_from_json must be a class method") return attr(value) @@ -1460,21 +1471,21 @@ def value_to_type( # and therefore can extract per-key types per_key_types: Optional[Dict[str, Type]] = None if getattr(origin, "__required_keys__", None) or getattr( - origin, "__optional_keys__", None + origin, "__optional_keys__", None ): per_key_types = get_type_hints(origin) key_type = ( type_args[0] if len(type_args) > 0 - and type_args[0] is not Any - and not isinstance(type_args[0], TypeVar) + and type_args[0] is not Any + and not isinstance(type_args[0], TypeVar) else None ) value_type = ( type_args[1] if len(type_args) > 1 - and type_args[1] is not Any - and not isinstance(type_args[1], TypeVar) + and type_args[1] is not Any + and not isinstance(type_args[1], TypeVar) else None ) # Convert each key/value @@ -1526,7 +1537,7 @@ def value_to_type( # attempted instantiation of the dataclass raise if a field is # missing if field_value is not dataclasses.MISSING and not field.metadata.get( - "skip", False + "skip", False ): try: field_values[field.name] = value_to_type( @@ -1549,7 +1560,7 @@ def value_to_type( # compatibility with pydantic v1 users, but this is deprecated. 
parse_obj_attr = inspect.getattr_static(hint, "parse_obj", None) if isinstance(parse_obj_attr, classmethod) or isinstance( - parse_obj_attr, staticmethod + parse_obj_attr, staticmethod ): if not isinstance(value, dict): raise TypeError( @@ -1585,8 +1596,8 @@ def value_to_type( ret_list = [] # If there is no type arg, just return value as is if not type_args or ( - len(type_args) == 1 - and (isinstance(type_args[0], TypeVar) or type_args[0] is Ellipsis) + len(type_args) == 1 + and (isinstance(type_args[0], TypeVar) or type_args[0] is Ellipsis) ): ret_list = list(value) else: diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 08cba4a4..04ece679 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -2251,10 +2251,21 @@ class MyGenericDataClass(typing.Generic[T]): field1: str field2: T = dataclasses.field(metadata={"skip": True}, default=None) - def assert_expected(self) -> None: + def __init__(self, field1: str): + self.field1 = field1 + + @classmethod + def temporal_from_json(cls, json_obj: Dict[str, Any]) -> MyGenericDataClass: + json_obj["field1"] = json_obj["field1"] + "_from_json" + return cls(**json_obj) + + def temporal_to_json(self) -> Dict[str, Any]: + return {"field1": self.field1 + "_to_json"} + + def assert_expected(self, value: str) -> None: # Part of the assertion is that this is the right type, which is # confirmed just by calling the method. We also check the field. - assert str(self.field1) == "some value2" + assert str(self.field1) == value @activity.defn @@ -2267,7 +2278,6 @@ async def data_class_typed_activity(param: MyDataClass) -> MyDataClass: async def generic_data_class_typed_activity( param: MyGenericDataClass[str], ) -> MyGenericDataClass[str]: - param.assert_expected() return param @@ -2328,20 +2338,24 @@ async def run(self, param: MyDataClass) -> MyDataClass: start_to_close_timeout=timedelta(seconds=30), ) param.assert_expected() - generic_param = MyGenericDataClass[str]("some value2") + generic_param = MyGenericDataClass[str]("some_value2") generic_param = await workflow.execute_activity( generic_data_class_typed_activity, generic_param, start_to_close_timeout=timedelta(seconds=30), ) - generic_param.assert_expected() + generic_param.assert_expected( + "some_value2_to_json_from_json_to_json_from_json" + ) + generic_param = MyGenericDataClass[str]("some_value2") generic_param = await workflow.execute_local_activity( generic_data_class_typed_activity, generic_param, start_to_close_timeout=timedelta(seconds=30), ) - generic_param.assert_expected() - + generic_param.assert_expected( + "some_value2_to_json_from_json_to_json_from_json" + ) child_handle = await workflow.start_child_workflow( DataClassTypedWorkflow.run, param, @@ -2411,7 +2425,9 @@ async def test_workflow_separate_protocol(client: Client): # This test is to confirm that protocols can be used as "interfaces" for # when the workflow impl is absent async with new_worker( - client, DataClassTypedWorkflow, activities=[data_class_typed_activity] + client, + DataClassTypedWorkflow, + activities=[data_class_typed_activity, generic_data_class_typed_activity], ) as worker: assert isinstance(DataClassTypedWorkflow(), DataClassTypedWorkflowProto) val = MyDataClass(field1="some value") @@ -2433,7 +2449,9 @@ async def test_workflow_separate_abstract(client: Client): # This test is to confirm that abstract classes can be used as "interfaces" # for when the workflow impl is absent async with new_worker( - client, DataClassTypedWorkflow, 
activities=[data_class_typed_activity] + client, + DataClassTypedWorkflow, + activities=[data_class_typed_activity, generic_data_class_typed_activity], ) as worker: assert issubclass(DataClassTypedWorkflow, DataClassTypedWorkflowAbstract) val = MyDataClass(field1="some value") From 5deb4fbc2982d53abd5976718700a34c453333f2 Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 16 Apr 2025 19:01:03 -0700 Subject: [PATCH 06/12] Renamed to to_json, from_json. Added unit test. --- temporalio/converter.py | 69 +++++++++++++++++++++++++---------- tests/worker/test_workflow.py | 38 ++++++++++--------- 2 files changed, 69 insertions(+), 38 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index dc118b3c..495fc7da 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -489,6 +489,41 @@ class AdvancedJSONEncoder(json.JSONEncoder): This encoder supports dataclasses and all iterables as lists. + A class can implement to_json and from_json methods to support custom conversion logic. + These methods should have the following signatures: + + .. code-block:: python + + class MyClass: + ... + + @classmethod + def from_json(cls, json: Any) -> MyClass: + ... + + def to_json(self) -> Any: + ... + + The to_json should return the same Python JSON types produced by JSONEncoder: + + +-------------------+---------------+ + | Python | JSON | + +===================+===============+ + | dict | object | + +-------------------+---------------+ + | list, tuple | array | + +-------------------+---------------+ + | str | string | + +-------------------+---------------+ + | int, float | number | + +-------------------+---------------+ + | True | true | + +-------------------+---------------+ + | False | false | + +-------------------+---------------+ + | None | null | + +-------------------+---------------+ + It also uses Pydantic v1's "dict" methods if available on the object, but this is deprecated. Pydantic users should upgrade to v2 and use temporalio.contrib.pydantic.pydantic_data_converter. @@ -499,8 +534,9 @@ def default(self, o: Any) -> Any: See :py:meth:`json.JSONEncoder.default`. """ - # Custom encoding and decoding through temporal_to_json and temporal_from_json - to_json = "temporal_to_json" + # Custom encoding and decoding through to_json and from_json + # to_json should be an instance method with only self argument + to_json = "to_json" if hasattr(o, to_json): attr = getattr(o, to_json) if not callable(attr): @@ -1439,12 +1475,12 @@ def value_to_type( raise TypeError(f"Value {value} not in literal values {type_args}") return value - # Has temporal_from_json method - from_json = "temporal_from_json" + # Has from_json class method (must have to_json as well) + from_json = "from_json" if hasattr(hint, from_json): attr = getattr(hint, from_json) - attrCls = getattr(attr, "__self__") - if not callable(attr) or not attrCls == origin: + attr_cls = getattr(attr, "__self__") + if not callable(attr) or not attr_cls == origin: raise TypeError(f"Type {hint}: temporal_from_json must be a class method") return attr(value) @@ -1515,42 +1551,35 @@ def value_to_type( return ret_dict # Dataclass - h = hint if dataclasses.is_dataclass(hint) else None - # This allows for generic dataclasses to be passed in. - # Note that the field of a generic parameter type is still not deserializable. - # Such fields must be marked with dataclasses.field(metadata={"skip": True}, default=...). 
- h = origin if h is None and dataclasses.is_dataclass(origin) else h - if h is not None: + if dataclasses.is_dataclass(hint): if not isinstance(value, dict): raise TypeError( - f"Cannot convert to dataclass {h}, value is {type(value)} not dict" + f"Cannot convert to dataclass {hint}, value is {type(value)} not dict" ) # Obtain dataclass fields and check that all dict fields are there and # that no required fields are missing. Unknown fields are silently # ignored. - fields = dataclasses.fields(h) - field_hints = get_type_hints(h) + fields = dataclasses.fields(hint) + field_hints = get_type_hints(hint) field_values = {} for field in fields: field_value = value.get(field.name, dataclasses.MISSING) # We do not check whether field is required here. Rather, we let the # attempted instantiation of the dataclass raise if a field is # missing - if field_value is not dataclasses.MISSING and not field.metadata.get( - "skip", False - ): + if field_value is not dataclasses.MISSING: try: field_values[field.name] = value_to_type( field_hints[field.name], field_value, custom_converters ) except Exception as err: raise TypeError( - f"Failed converting field {field.name} on dataclass {h}" + f"Failed converting field {field.name} on dataclass {hint}" ) from err # Simply instantiate the dataclass. This will fail as expected when # missing required fields. # TODO(cretz): Want way to convert snake case to camel case? - return h(**field_values) + return hint(**field_values) # Pydantic model instance # Pydantic users should use Pydantic v2 with diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 04ece679..ababa453 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -2246,21 +2246,23 @@ def assert_expected(self) -> None: T = typing.TypeVar("T") -@dataclass -class MyGenericDataClass(typing.Generic[T]): +class MyGenericClass(typing.Generic[T]): + """ + Demonstrates custom conversion and that it works even with generic classes. 
+ """ + field1: str - field2: T = dataclasses.field(metadata={"skip": True}, default=None) + field2: str = None def __init__(self, field1: str): self.field1 = field1 @classmethod - def temporal_from_json(cls, json_obj: Dict[str, Any]) -> MyGenericDataClass: - json_obj["field1"] = json_obj["field1"] + "_from_json" - return cls(**json_obj) + def from_json(cls, json_obj: Any) -> MyGenericClass: + return MyGenericClass(str(json_obj) + "_from_json") - def temporal_to_json(self) -> Dict[str, Any]: - return {"field1": self.field1 + "_to_json"} + def to_json(self) -> Any: + return self.field1 + "_to_json" def assert_expected(self, value: str) -> None: # Part of the assertion is that this is the right type, which is @@ -2275,9 +2277,9 @@ async def data_class_typed_activity(param: MyDataClass) -> MyDataClass: @activity.defn -async def generic_data_class_typed_activity( - param: MyGenericDataClass[str], -) -> MyGenericDataClass[str]: +async def generic_class_typed_activity( + param: MyGenericClass[str], +) -> MyGenericClass[str]: return param @@ -2338,18 +2340,18 @@ async def run(self, param: MyDataClass) -> MyDataClass: start_to_close_timeout=timedelta(seconds=30), ) param.assert_expected() - generic_param = MyGenericDataClass[str]("some_value2") + generic_param = MyGenericClass[str]("some_value2") generic_param = await workflow.execute_activity( - generic_data_class_typed_activity, + generic_class_typed_activity, generic_param, start_to_close_timeout=timedelta(seconds=30), ) generic_param.assert_expected( "some_value2_to_json_from_json_to_json_from_json" ) - generic_param = MyGenericDataClass[str]("some_value2") + generic_param = MyGenericClass[str]("some_value2") generic_param = await workflow.execute_local_activity( - generic_data_class_typed_activity, + generic_class_typed_activity, generic_param, start_to_close_timeout=timedelta(seconds=30), ) @@ -2400,7 +2402,7 @@ async def test_workflow_dataclass_typed(client: Client, env: WorkflowEnvironment async with new_worker( client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_data_class_typed_activity], + activities=[data_class_typed_activity, generic_class_typed_activity], ) as worker: val = MyDataClass(field1="some value") handle = await client.start_workflow( @@ -2427,7 +2429,7 @@ async def test_workflow_separate_protocol(client: Client): async with new_worker( client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_data_class_typed_activity], + activities=[data_class_typed_activity, generic_class_typed_activity], ) as worker: assert isinstance(DataClassTypedWorkflow(), DataClassTypedWorkflowProto) val = MyDataClass(field1="some value") @@ -2451,7 +2453,7 @@ async def test_workflow_separate_abstract(client: Client): async with new_worker( client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_data_class_typed_activity], + activities=[data_class_typed_activity, generic_class_typed_activity], ) as worker: assert issubclass(DataClassTypedWorkflow, DataClassTypedWorkflowAbstract) val = MyDataClass(field1="some value") From e4a27ea17e8c2d1321d2a2cbd4e5ecf4d966be37 Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 16 Apr 2025 19:04:07 -0700 Subject: [PATCH 07/12] Updated comment --- temporalio/converter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/temporalio/converter.py b/temporalio/converter.py index 495fc7da..0cd37573 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -490,6 +490,7 @@ class AdvancedJSONEncoder(json.JSONEncoder): This encoder 
supports dataclasses and all iterables as lists. A class can implement to_json and from_json methods to support custom conversion logic. + Custom conversion of generic classes is supported. These methods should have the following signatures: .. code-block:: python From 2cfa0744d3137c08d18e3a3823830936d8a211ba Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 16 Apr 2025 19:18:50 -0700 Subject: [PATCH 08/12] Fixed lint error --- tests/worker/test_workflow.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 35da6b04..146e4a13 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -2251,11 +2251,9 @@ class MyGenericClass(typing.Generic[T]): Demonstrates custom conversion and that it works even with generic classes. """ - field1: str - field2: str = None - def __init__(self, field1: str): self.field1 = field1 + self.field2 = "foo" @classmethod def from_json(cls, json_obj: Any) -> MyGenericClass: From 82073c63418d2a2d9cd3ec62b0468dee513bb685 Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Wed, 16 Apr 2025 19:39:25 -0700 Subject: [PATCH 09/12] empty to force build --- temporalio/converter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/temporalio/converter.py b/temporalio/converter.py index 0cd37573..19a2aa19 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -525,6 +525,7 @@ def to_json(self) -> Any: | None | null | +-------------------+---------------+ + It also uses Pydantic v1's "dict" methods if available on the object, but this is deprecated. Pydantic users should upgrade to v2 and use temporalio.contrib.pydantic.pydantic_data_converter. From 34a360b4efee8d68c4c61d6d1c9ed55c491abf92 Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Thu, 17 Apr 2025 08:52:39 -0700 Subject: [PATCH 10/12] Fixed method name in the error message. --- temporalio/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 19a2aa19..36e20d57 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -543,7 +543,7 @@ def default(self, o: Any) -> Any: attr = getattr(o, to_json) if not callable(attr): raise TypeError( - f"Type {o.__class__}: temporal_to_json must be a method" + f"Type {o.__class__}: to_json must be a method" ) return attr() From 84528fb07644c573e0a5b6365c7e5607f668091a Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Thu, 17 Apr 2025 09:08:11 -0700 Subject: [PATCH 11/12] fixed format --- temporalio/converter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 36e20d57..576a25ce 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -542,9 +542,7 @@ def default(self, o: Any) -> Any: if hasattr(o, to_json): attr = getattr(o, to_json) if not callable(attr): - raise TypeError( - f"Type {o.__class__}: to_json must be a method" - ) + raise TypeError(f"Type {o.__class__}: to_json must be a method") return attr() # Dataclass support From 981878bc82ceab41497ea86bea7035f82926d023 Mon Sep 17 00:00:00 2001 From: Maxim Fateev Date: Sat, 19 Apr 2025 12:19:32 -0700 Subject: [PATCH 12/12] PR Feedback addressed. 
---
 README.md                     | 46 ++++++++++++++++++
 temporalio/converter.py       | 89 ++++++++++++++++++-----------------
 tests/worker/test_workflow.py | 61 ++++++++++++++++++++----
 3 files changed, 143 insertions(+), 53 deletions(-)

diff --git a/README.md b/README.md
index 4c58f676..c4ab1d43 100644
--- a/README.md
+++ b/README.md
@@ -446,6 +446,52 @@ my_data_converter = dataclasses.replace(
 
 Now `IPv4Address` can be used in type hints including collections, optionals, etc.
 
+When the `JSONPlainPayloadConverter` is used, a class can implement `to_temporal_json` and `from_temporal_json` methods to
+support custom conversion logic. Custom conversion of generic classes is supported.
+These methods should have the following signatures:
+
+```python
+    class MyClass:
+        ...
+```
+`from_temporal_json` can be either a classmethod:
+```python
+        @classmethod
+        def from_temporal_json(cls, json: Any) -> MyClass:
+            ...
+```
+or a static method:
+```python
+        @staticmethod
+        def from_temporal_json(json: Any) -> MyClass:
+            ...
+```
+`to_temporal_json` is always an instance method:
+```python
+        def to_temporal_json(self) -> Any:
+            ...
+```
+`to_temporal_json` should return the same Python JSON types produced by `JSONEncoder`:
+```
+    +-------------------+---------------+
+    | Python            | JSON          |
+    +===================+===============+
+    | dict              | object        |
+    +-------------------+---------------+
+    | list, tuple       | array         |
+    +-------------------+---------------+
+    | str               | string        |
+    +-------------------+---------------+
+    | int, float        | number        |
+    +-------------------+---------------+
+    | True              | true          |
+    +-------------------+---------------+
+    | False             | false         |
+    +-------------------+---------------+
+    | None              | null          |
+    +-------------------+---------------+
+```
+
 ### Workers
 
 Workers host workflows and/or activities. Here's how to run a worker:
diff --git a/temporalio/converter.py b/temporalio/converter.py
index 576a25ce..c7143dd4 100644
--- a/temporalio/converter.py
+++ b/temporalio/converter.py
@@ -489,43 +489,6 @@ class AdvancedJSONEncoder(json.JSONEncoder):
 
     This encoder supports dataclasses and all iterables as lists.
 
-    A class can implement to_json and from_json methods to support custom conversion logic.
-    Custom conversion of generic classes is supported.
-    These methods should have the following signatures:
-
-    .. code-block:: python
-
-        class MyClass:
-            ...
-
-            @classmethod
-            def from_json(cls, json: Any) -> MyClass:
-                ...
-
-            def to_json(self) -> Any:
-                ...
-
-    The to_json should return the same Python JSON types produced by JSONEncoder:
-
-    +-------------------+---------------+
-    | Python            | JSON          |
-    +===================+===============+
-    | dict              | object        |
-    +-------------------+---------------+
-    | list, tuple       | array         |
-    +-------------------+---------------+
-    | str               | string        |
-    +-------------------+---------------+
-    | int, float        | number        |
-    +-------------------+---------------+
-    | True              | true          |
-    +-------------------+---------------+
-    | False             | false         |
-    +-------------------+---------------+
-    | None              | null          |
-    +-------------------+---------------+
-
-
     It also uses Pydantic v1's "dict" methods if available on the object, but this is
     deprecated. Pydantic users should upgrade to v2 and use
     temporalio.contrib.pydantic.pydantic_data_converter.
@@ -538,11 +501,11 @@ def default(self, o: Any) -> Any:
         """
         # Custom encoding and decoding through to_json and from_json
         # to_json should be an instance method with only self argument
-        to_json = "to_json"
+        to_json = "to_temporal_json"
         if hasattr(o, to_json):
            attr = getattr(o, to_json)
             if not callable(attr):
-                raise TypeError(f"Type {o.__class__}: to_json must be a method")
+                raise TypeError(f"Type {o.__class__}: {to_json} must be a method")
             return attr()
 
         # Dataclass support
@@ -570,6 +533,44 @@ class JSONPlainPayloadConverter(EncodingPayloadConverter):
 
     For decoding, this uses type hints to attempt to rebuild the type from the
     type hint.
+
+    A class can implement to_temporal_json and from_temporal_json methods to support custom conversion logic.
+    Custom conversion of generic classes is supported.
+    These methods should have the following signatures:
+
+    .. code-block:: python
+
+        class MyClass:
+            ...
+
+            @classmethod
+            def from_temporal_json(cls, json: Any) -> MyClass:
+                ...
+
+            def to_temporal_json(self) -> Any:
+                ...
+
+    The to_temporal_json method should return the same Python JSON types produced by JSONEncoder:
+
+    +-------------------+---------------+
+    | Python            | JSON          |
+    +===================+===============+
+    | dict              | object        |
+    +-------------------+---------------+
+    | list, tuple       | array         |
+    +-------------------+---------------+
+    | str               | string        |
+    +-------------------+---------------+
+    | int, float        | number        |
+    +-------------------+---------------+
+    | True              | true          |
+    +-------------------+---------------+
+    | False             | false         |
+    +-------------------+---------------+
+    | None              | null          |
+    +-------------------+---------------+
+
+
     """
 
     _encoder: Optional[Type[json.JSONEncoder]]
@@ -1476,12 +1477,14 @@ def value_to_type(
         return value
 
     # Has from_json class method (must have to_json as well)
-    from_json = "from_json"
+    from_json = "from_temporal_json"
     if hasattr(hint, from_json):
         attr = getattr(hint, from_json)
-        attr_cls = getattr(attr, "__self__")
-        if not callable(attr) or not attr_cls == origin:
-            raise TypeError(f"Type {hint}: temporal_from_json must be a class method")
+        attr_cls = getattr(attr, "__self__", None)
+        if not callable(attr) or (attr_cls is not None and attr_cls is not origin):
+            raise TypeError(
+                f"Type {hint}: {from_json} must be a staticmethod or classmethod"
+            )
         return attr(value)
 
     is_union = origin is Union
diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py
index 146e4a13..bad025a8 100644
--- a/tests/worker/test_workflow.py
+++ b/tests/worker/test_workflow.py
@@ -2256,10 +2256,32 @@ def __init__(self, field1: str):
         self.field2 = "foo"
 
     @classmethod
-    def from_json(cls, json_obj: Any) -> MyGenericClass:
+    def from_temporal_json(cls, json_obj: Any) -> MyGenericClass:
         return MyGenericClass(str(json_obj) + "_from_json")
 
-    def to_json(self) -> Any:
+    def to_temporal_json(self) -> Any:
         return self.field1 + "_to_json"
 
     def assert_expected(self, value: str) -> None:
         # Part of the assertion is that this is the right type, which is
         # confirmed just by calling the method. We also check the field.
         assert str(self.field1) == value
 

+
+class MyGenericClassWithStatic(typing.Generic[T]):
+    """
+    Demonstrates custom conversion via staticmethod, even with generic classes.
+ """ + + def __init__(self, field1: str): + self.field1 = field1 + self.field2 = "foo" + + @staticmethod + def from_temporal_json(json_obj: Any) -> MyGenericClass: + return MyGenericClass(str(json_obj) + "_from_json") + + def to_temporal_json(self) -> Any: return self.field1 + "_to_json" def assert_expected(self, value: str) -> None: @@ -2281,6 +2303,13 @@ async def generic_class_typed_activity( return param +@activity.defn +async def generic_class_typed_activity_with_static( + param: MyGenericClassWithStatic[str], +) -> MyGenericClassWithStatic[str]: + return param + + @runtime_checkable @workflow.defn(name="DataClassTypedWorkflow") class DataClassTypedWorkflowProto(Protocol): @@ -2347,13 +2376,13 @@ async def run(self, param: MyDataClass) -> MyDataClass: generic_param.assert_expected( "some_value2_to_json_from_json_to_json_from_json" ) - generic_param = MyGenericClass[str]("some_value2") - generic_param = await workflow.execute_local_activity( - generic_class_typed_activity, - generic_param, + generic_param_s = MyGenericClassWithStatic[str]("some_value2") + generic_param_s = await workflow.execute_local_activity( + generic_class_typed_activity_with_static, + generic_param_s, start_to_close_timeout=timedelta(seconds=30), ) - generic_param.assert_expected( + generic_param_s.assert_expected( "some_value2_to_json_from_json_to_json_from_json" ) child_handle = await workflow.start_child_workflow( @@ -2400,7 +2429,11 @@ async def test_workflow_dataclass_typed(client: Client, env: WorkflowEnvironment async with new_worker( client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_class_typed_activity], + activities=[ + data_class_typed_activity, + generic_class_typed_activity, + generic_class_typed_activity_with_static, + ], ) as worker: val = MyDataClass(field1="some value") handle = await client.start_workflow( @@ -2427,7 +2460,11 @@ async def test_workflow_separate_protocol(client: Client): async with new_worker( client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_class_typed_activity], + activities=[ + data_class_typed_activity, + generic_class_typed_activity, + generic_class_typed_activity_with_static, + ], ) as worker: assert isinstance(DataClassTypedWorkflow(), DataClassTypedWorkflowProto) val = MyDataClass(field1="some value") @@ -2451,7 +2488,11 @@ async def test_workflow_separate_abstract(client: Client): async with new_worker( client, DataClassTypedWorkflow, - activities=[data_class_typed_activity, generic_class_typed_activity], + activities=[ + data_class_typed_activity, + generic_class_typed_activity, + generic_class_typed_activity_with_static, + ], ) as worker: assert issubclass(DataClassTypedWorkflow, DataClassTypedWorkflowAbstract) val = MyDataClass(field1="some value")
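
Below is a minimal, self-contained sketch of how the `to_temporal_json`/`from_temporal_json` hooks introduced in this series can be exercised directly against `JSONPlainPayloadConverter`. The `Money` class and the round-trip flow are illustrative assumptions for this note, not part of the patch; only the hook names, the converter class, and its `to_payload`/`from_payload` methods come from the changes above.

```python
from typing import Any

from temporalio.converter import JSONPlainPayloadConverter


class Money:
    """Hypothetical user-defined type that is not a dataclass."""

    def __init__(self, cents: int) -> None:
        self.cents = cents

    def to_temporal_json(self) -> Any:
        # Must return a plain JSON-compatible value (dict, list, str, number, ...)
        return {"cents": self.cents}

    @classmethod
    def from_temporal_json(cls, json: Any) -> "Money":
        # Rebuild the instance from the value produced by to_temporal_json
        return cls(json["cents"])


# Round-trip through the JSON payload converter (assumes the hook lookup
# added by this series is in place).
converter = JSONPlainPayloadConverter()
payload = converter.to_payload(Money(250))
assert payload is not None
restored = converter.from_payload(payload, Money)
assert isinstance(restored, Money) and restored.cents == 250
```

As I read the patch, the same hooks also apply when the type hint is a parameterized generic such as `MyGenericClass[str]`, since `value_to_type` resolves `from_temporal_json` through the origin class, and the hook may be declared as either a classmethod or a staticmethod.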