Skip to content

Commit ce1bade

Browse files
committed
Add test demonstrating error chain. Ensure that failures are wrapped with a NexusOperationFailureError
1 parent 883c7c7 commit ce1bade

2 files changed

Lines changed: 100 additions & 37 deletions

File tree

temporalio/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10066,8 +10066,10 @@ async def get_nexus_operation_result(
1006610066
return result
1006710067

1006810068
case "failure":
10069-
raise await self._client.data_converter.decode_failure(
10070-
res.failure
10069+
raise NexusOperationFailureError(
10070+
cause=await self._client.data_converter.decode_failure(
10071+
res.failure
10072+
)
1007110073
)
1007210074

1007310075
case None:

tests/nexus/test_standalone_operations.py

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import uuid
1313
from dataclasses import dataclass
1414
from datetime import timedelta
15-
from typing import Any
15+
from typing import Any, Literal
1616

1717
import nexusrpc
1818
import pytest
@@ -47,14 +47,16 @@
4747
WorkflowIDReusePolicy,
4848
)
4949
from temporalio.exceptions import (
50+
ApplicationError,
5051
CancelledError,
5152
NexusOperationAlreadyStartedError,
53+
NexusOperationError,
5254
TerminatedError,
5355
)
5456
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
5557
from temporalio.service import RPCError
5658
from temporalio.testing import WorkflowEnvironment
57-
from temporalio.types import ReturnType
59+
from temporalio.types import ParamType, ReturnType
5860
from temporalio.worker import Worker
5961
from tests.helpers import assert_eventually
6062
from tests.helpers.nexus import make_nexus_endpoint_name
@@ -74,6 +76,11 @@ class EchoOutput:
7476
value: str
7577

7678

79+
@dataclass
80+
class RaiseErrInput:
81+
err_type: Literal["handler_err", "application_err"]
82+
83+
7784
# ---------------------------------------------------------------------------
7885
# Service definition
7986
# ---------------------------------------------------------------------------
@@ -84,6 +91,7 @@ class StandaloneTestService:
8491
echo_sync: nexusrpc.Operation[EchoInput, EchoOutput]
8592
echo_async: nexusrpc.Operation[EchoInput, EchoOutput]
8693
blocking_async: nexusrpc.Operation[EchoInput, EchoOutput]
94+
raise_err: nexusrpc.Operation[RaiseErrInput, None]
8795

8896

8997
# ---------------------------------------------------------------------------
@@ -155,6 +163,20 @@ async def blocking_async(
155163
self.started_blocking.set()
156164
return handle
157165

166+
@sync_operation
167+
async def raise_err(
168+
self, _ctx: StartOperationContext, input: RaiseErrInput
169+
) -> None:
170+
match input.err_type:
171+
case "handler_err":
172+
raise nexusrpc.HandlerError(
173+
"test handler error",
174+
type=nexusrpc.HandlerErrorType.INTERNAL,
175+
retryable_override=False,
176+
)
177+
case "application_err":
178+
raise ApplicationError("test application error", non_retryable=True)
179+
158180

159181
# ---------------------------------------------------------------------------
160182
# Retry helper for endpoint propagation
@@ -163,8 +185,8 @@ async def blocking_async(
163185

164186
async def start_with_retry(
165187
nexus_client: Any,
166-
operation: nexusrpc.Operation[Any, ReturnType],
167-
arg: Any,
188+
operation: nexusrpc.Operation[ParamType, ReturnType],
189+
arg: ParamType,
168190
*,
169191
id: str,
170192
id_reuse_policy: NexusOperationIDReusePolicy = NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
@@ -199,8 +221,8 @@ async def start_with_retry(
199221

200222
async def execute_with_retry(
201223
nexus_client: Any,
202-
operation: nexusrpc.Operation[Any, ReturnType],
203-
arg: Any,
224+
operation: nexusrpc.Operation[ParamType, ReturnType],
225+
arg: ParamType,
204226
*,
205227
id: str,
206228
id_reuse_policy: NexusOperationIDReusePolicy = NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
@@ -238,9 +260,6 @@ async def test_start_sync_operation_and_get_result(
238260
client: Client, env: WorkflowEnvironment
239261
):
240262
"""Start a sync nexus operation, call handle.result(), verify return value."""
241-
if env.supports_time_skipping:
242-
pytest.skip("Nexus tests don't work with time-skipping server")
243-
244263
task_queue = str(uuid.uuid4())
245264
endpoint_name = make_nexus_endpoint_name(task_queue)
246265

@@ -273,9 +292,6 @@ async def test_start_async_operation_and_poll_result(
273292
client: Client, env: WorkflowEnvironment
274293
):
275294
"""Start a workflow_run operation, poll result, verify."""
276-
if env.supports_time_skipping:
277-
pytest.skip("Nexus tests don't work with time-skipping server")
278-
279295
task_queue = str(uuid.uuid4())
280296
endpoint_name = make_nexus_endpoint_name(task_queue)
281297

@@ -304,9 +320,6 @@ async def test_start_async_operation_and_poll_result(
304320

305321
async def test_execute_operation(client: Client, env: WorkflowEnvironment):
306322
"""Use execute_operation convenience method, verify it returns result directly."""
307-
if env.supports_time_skipping:
308-
pytest.skip("Nexus tests don't work with time-skipping server")
309-
310323
task_queue = str(uuid.uuid4())
311324
endpoint_name = make_nexus_endpoint_name(task_queue)
312325

@@ -334,11 +347,61 @@ async def test_execute_operation(client: Client, env: WorkflowEnvironment):
334347
assert result.value == "execute"
335348

336349

350+
async def test_errors(client: Client, env: WorkflowEnvironment):
351+
"""Execute operations that raise errors"""
352+
task_queue = str(uuid.uuid4())
353+
endpoint_name = make_nexus_endpoint_name(task_queue)
354+
355+
async with Worker(
356+
client,
357+
task_queue=task_queue,
358+
nexus_service_handlers=[StandaloneTestServiceHandler()],
359+
workflows=[EchoHandlerWorkflow, BlockingHandlerWorkflow],
360+
):
361+
await env.create_nexus_endpoint(endpoint_name, task_queue)
362+
363+
nexus_client = client.create_nexus_client(
364+
service=StandaloneTestService, endpoint=endpoint_name
365+
)
366+
367+
# Expect temporalio.exceptions.NexusOperationError
368+
handle = await start_with_retry(
369+
nexus_client,
370+
StandaloneTestService.raise_err,
371+
RaiseErrInput("handler_err"),
372+
id=str(uuid.uuid4()),
373+
id_reuse_policy=NexusOperationIDReusePolicy.REJECT_DUPLICATE,
374+
id_conflict_policy=NexusOperationIDConflictPolicy.FAIL,
375+
schedule_to_close_timeout=timedelta(seconds=30),
376+
)
377+
378+
with pytest.raises(NexusOperationFailureError) as err:
379+
await handle.result()
380+
381+
assert err.value.__cause__
382+
assert isinstance(err.value.__cause__, nexusrpc.HandlerError)
383+
384+
handle = await start_with_retry(
385+
nexus_client,
386+
StandaloneTestService.raise_err,
387+
RaiseErrInput("application_err"),
388+
id=str(uuid.uuid4()),
389+
id_reuse_policy=NexusOperationIDReusePolicy.REJECT_DUPLICATE,
390+
id_conflict_policy=NexusOperationIDConflictPolicy.FAIL,
391+
schedule_to_close_timeout=timedelta(seconds=30),
392+
)
393+
394+
with pytest.raises(NexusOperationFailureError) as err:
395+
await handle.result()
396+
397+
assert err.value.__cause__
398+
assert isinstance(err.value.__cause__, nexusrpc.HandlerError)
399+
assert err.value.__cause__.__cause__
400+
assert isinstance(err.value.__cause__.__cause__, ApplicationError)
401+
402+
337403
async def test_describe_operation(client: Client, env: WorkflowEnvironment):
338404
"""Start op, get result first, then describe, verify fields populated."""
339-
if env.supports_time_skipping:
340-
pytest.skip("Nexus tests don't work with time-skipping server")
341-
342405
task_queue = str(uuid.uuid4())
343406
endpoint_name = make_nexus_endpoint_name(task_queue)
344407

@@ -384,10 +447,9 @@ async def test_describe_operation(client: Client, env: WorkflowEnvironment):
384447

385448

386449
async def test_cancel_operation(client: Client, env: WorkflowEnvironment):
387-
"""Start blocking async op, cancel it, verify awaiting result raises CancelledError."""
388-
if env.supports_time_skipping:
389-
pytest.skip("Nexus tests don't work with time-skipping server")
390-
450+
"""Start blocking async op, cancel it, verify awaiting result raises NexusOperationFailureError
451+
from a CancelledError.
452+
"""
391453
task_queue = str(uuid.uuid4())
392454
endpoint_name = make_nexus_endpoint_name(task_queue)
393455

@@ -415,12 +477,17 @@ async def test_cancel_operation(client: Client, env: WorkflowEnvironment):
415477
# Cancel the operation
416478
await handle.cancel()
417479

418-
with pytest.raises(CancelledError):
480+
with pytest.raises(NexusOperationFailureError) as err:
419481
await handle.result()
420482

483+
assert err.value.__cause__
484+
assert isinstance(err.value.__cause__, CancelledError)
485+
421486

422487
async def test_terminate_operation(client: Client, env: WorkflowEnvironment):
423-
"""Start blocking async op, terminate it, verify awaiting the result raises TerminatedError."""
488+
"""Start blocking async op, terminate it, verify awaiting the result raises NexusOperationFailureError
489+
from a TerminatedError.
490+
"""
424491
task_queue = str(uuid.uuid4())
425492
endpoint_name = make_nexus_endpoint_name(task_queue)
426493

@@ -448,15 +515,15 @@ async def test_terminate_operation(client: Client, env: WorkflowEnvironment):
448515
# Terminate the operation
449516
await handle.terminate(reason="test termination")
450517

451-
with pytest.raises(TerminatedError):
518+
with pytest.raises(NexusOperationFailureError) as err:
452519
await handle.result()
453520

521+
assert err.value.__cause__
522+
assert isinstance(err.value.__cause__, TerminatedError)
523+
454524

455525
async def test_list_operations(client: Client, env: WorkflowEnvironment):
456526
"""Start multiple ops, list them, verify iteration yields correct results."""
457-
if env.supports_time_skipping:
458-
pytest.skip("Nexus tests don't work with time-skipping server")
459-
460527
task_queue = str(uuid.uuid4())
461528
endpoint_name = make_nexus_endpoint_name(task_queue)
462529

@@ -540,9 +607,6 @@ async def check_count() -> None:
540607

541608
async def test_get_nexus_operation_handle(client: Client, env: WorkflowEnvironment):
542609
"""Start op, get result, then get handle by ID and get result again."""
543-
if env.supports_time_skipping:
544-
pytest.skip("Nexus tests don't work with time-skipping server")
545-
546610
task_queue = str(uuid.uuid4())
547611
endpoint_name = make_nexus_endpoint_name(task_queue)
548612

@@ -761,9 +825,6 @@ async def test_interceptor_receives_inputs(client: Client, env: WorkflowEnvironm
761825
762826
Also verifies that result() does NOT trigger any interceptor call.
763827
"""
764-
if env.supports_time_skipping:
765-
pytest.skip("Nexus tests don't work with time-skipping server")
766-
767828
task_queue = str(uuid.uuid4())
768829
endpoint_name = make_nexus_endpoint_name(task_queue)
769830

@@ -821,7 +882,7 @@ async def test_interceptor_receives_inputs(client: Client, env: WorkflowEnvironm
821882
assert cancel_input.operation_id == op_id
822883

823884
# GetResult
824-
with pytest.raises(CancelledError):
885+
with pytest.raises(NexusOperationFailureError):
825886
await handle.result()
826887
assert len(interceptor.result_calls) == 1
827888
result_input = interceptor.result_calls[0]

0 commit comments

Comments
 (0)