diff --git a/src/prefect/_internal/retries.py b/src/prefect/_internal/retries.py index 552b40b9c4a7..4e349aa266b2 100644 --- a/src/prefect/_internal/retries.py +++ b/src/prefect/_internal/retries.py @@ -28,6 +28,7 @@ def retry_async_fn( max_delay: float = 10, retry_on_exceptions: tuple[type[Exception], ...] = (Exception,), operation_name: Optional[str] = None, + should_not_retry: Callable[[Exception], bool] = lambda e: False, ) -> Callable[ [Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]] ]: @@ -44,6 +45,11 @@ def retry_async_fn( retrying on all exceptions. operation_name: Optional name to use for logging the operation instead of the function name. If None, uses the function name. + should_not_retry: An optional callable that takes the caught exception and + returns True if retries should be skipped immediately. Useful for + short-circuiting retries on non-transient errors (e.g. HTTP 410 Gone) + where retrying will never succeed. When it returns True the exception + is re-raised immediately without any backoff or retry logging. """ def decorator( @@ -56,6 +62,8 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: try: return await func(*args, **kwargs) except retry_on_exceptions as e: + if should_not_retry(e): + raise if attempt == max_attempts - 1: logger.exception( f"Function {name!r} failed after {max_attempts} attempts" diff --git a/src/prefect/concurrency/_leases.py b/src/prefect/concurrency/_leases.py index c46bd650b974..31c2febbed80 100644 --- a/src/prefect/concurrency/_leases.py +++ b/src/prefect/concurrency/_leases.py @@ -1,9 +1,11 @@ import asyncio import concurrent.futures from contextlib import asynccontextmanager, contextmanager -from typing import AsyncGenerator, Generator +from typing import AsyncGenerator, Callable, Generator from uuid import UUID +import httpx + from prefect._internal.concurrency.api import create_call from prefect._internal.concurrency.cancellation import ( AsyncCancelScope, @@ -18,6 +20,7 @@ async def _lease_renewal_loop( lease_id: UUID, lease_duration: float, + should_stop: Callable[[], bool] = lambda: False, ) -> None: """ Maintain a concurrency lease by renewing it after the given interval. @@ -25,16 +28,29 @@ async def _lease_renewal_loop( Args: lease_id: The ID of the lease to maintain. lease_duration: The duration of the lease in seconds. + should_stop: An optional callable that returns True when the renewal loop + should exit cleanly. Checked before each renewal attempt so that the + loop can stop without raising when the flow has already reached a + terminal state (e.g. the server will release the lease itself). """ async with get_client() as client: - @retry_async_fn(max_attempts=3, operation_name="concurrency lease renewal") + @retry_async_fn( + max_attempts=3, + operation_name="concurrency lease renewal", + should_not_retry=lambda e: ( + isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 410 + ), + ) async def renew() -> None: await client.renew_concurrency_lease( lease_id=lease_id, lease_duration=lease_duration ) while True: + # Exit cleanly if the caller signals that the flow is done. + if should_stop(): + return await renew() await asyncio.sleep( # Renew the lease 3/4 of the way through the lease duration lease_duration * 0.75 @@ -47,6 +63,7 @@ def maintain_concurrency_lease( lease_duration: float, raise_on_lease_renewal_failure: bool = False, suppress_warnings: bool = False, + should_stop: Callable[[], bool] = lambda: False, ) -> Generator[None, None, None]: """ Maintain a concurrency lease for the given lease ID. @@ -55,6 +72,11 @@ def maintain_concurrency_lease( lease_id: The ID of the lease to maintain. lease_duration: The duration of the lease in seconds. raise_on_lease_renewal_failure: A boolean specifying whether to raise an error if the lease renewal fails. + should_stop: An optional callable that returns True when the renewal loop + should exit cleanly without treating a failure as a crash. Typically + set to a check on the engine's current flow run state so that renewal + failures that occur after a successful terminal state transition are + silently ignored instead of propagated as crashes. """ # Start a loop to renew the lease on the global event loop to avoid blocking the main thread global_loop = get_global_loop() @@ -62,6 +84,7 @@ def maintain_concurrency_lease( _lease_renewal_loop, lease_id, lease_duration, + should_stop, ) global_loop.submit(lease_renewal_call) @@ -73,10 +96,17 @@ def handle_lease_renewal_failure(future: concurrent.futures.Future[None]): exc = future.exception() if exc: try: - # Use a run logger if available logger = get_run_logger() except Exception: logger = get_logger("concurrency") + + if should_stop(): + logger.debug( + "Concurrency lease renewal failed after flow reached terminal state - this is expected.", + exc_info=(type(exc), exc, exc.__traceback__), + ) + return + if raise_on_lease_renewal_failure: logger.error( "Concurrency lease renewal failed - slots are no longer reserved. Terminating execution to prevent over-allocation.", @@ -110,6 +140,7 @@ async def amaintain_concurrency_lease( lease_duration: float, raise_on_lease_renewal_failure: bool = False, suppress_warnings: bool = False, + should_stop: Callable[[], bool] = lambda: False, ) -> AsyncGenerator[None, None]: """ Maintain a concurrency lease for the given lease ID. @@ -118,9 +149,14 @@ async def amaintain_concurrency_lease( lease_id: The ID of the lease to maintain. lease_duration: The duration of the lease in seconds. raise_on_lease_renewal_failure: A boolean specifying whether to raise an error if the lease renewal fails. + should_stop: An optional callable that returns True when the renewal loop + should exit cleanly without treating a failure as a crash. Typically + set to a check on the engine's current flow run state so that renewal + failures that occur after a successful terminal state transition are + silently ignored instead of propagated as crashes. """ lease_renewal_task = asyncio.create_task( - _lease_renewal_loop(lease_id, lease_duration) + _lease_renewal_loop(lease_id, lease_duration, should_stop) ) with AsyncCancelScope() as cancel_scope: @@ -131,10 +167,17 @@ def handle_lease_renewal_failure(task: asyncio.Task[None]): exc = task.exception() if exc: try: - # Use a run logger if available logger = get_run_logger() except Exception: logger = get_logger("concurrency") + + if should_stop(): + logger.debug( + "Concurrency lease renewal failed after flow reached terminal state - this is expected.", + exc_info=(type(exc), exc, exc.__traceback__), + ) + return + if raise_on_lease_renewal_failure: logger.error( "Concurrency lease renewal failed - slots are no longer reserved. Terminating execution to prevent over-allocation.", diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 1eb95d2226bd..9e4b5ec1971b 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -303,6 +303,7 @@ class BaseFlowRunEngine(Generic[P, R]): _is_started: bool = False short_circuit: bool = False _flow_run_name_set: bool = False + _flow_executed: bool = False _telemetry: RunTelemetry = field(default_factory=RunTelemetry) def __post_init__(self) -> None: @@ -548,6 +549,7 @@ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": return _result def handle_success(self, result: R) -> R: + self._flow_executed = True result_store = getattr(FlowRunContext.get(), "result_store", None) if result_store is None: raise ValueError("Result store is not set") @@ -831,7 +833,17 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): if lease_id := self.state.state_details.deployment_concurrency_lease_id: stack.enter_context( maintain_concurrency_lease( - lease_id, 300, raise_on_lease_renewal_failure=True + lease_id, + 300, + raise_on_lease_renewal_failure=True, + should_stop=lambda: ( + self._flow_executed + or bool( + self.flow_run + and self.flow_run.state + and self.flow_run.state.is_final() + ) + ), ) ) @@ -932,7 +944,11 @@ def initialize_run(self): raise except BaseException as exc: # We don't want to crash a flow run if the user code finished executing - if self.flow_run.state and not self.flow_run.state.is_final(): + if ( + self.flow_run.state + and not self.flow_run.state.is_final() + and not self._flow_executed + ): # BaseExceptions are caught and handled as crashes self.handle_crash(exc) raise @@ -1151,6 +1167,7 @@ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]" return await self.state.aresult(raise_on_failure=raise_on_failure) # type: ignore async def handle_success(self, result: R) -> R: + self._flow_executed = True result_store = getattr(FlowRunContext.get(), "result_store", None) if result_store is None: raise ValueError("Result store is not set") @@ -1432,7 +1449,17 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None): if lease_id := self.state.state_details.deployment_concurrency_lease_id: await stack.enter_async_context( amaintain_concurrency_lease( - lease_id, 300, raise_on_lease_renewal_failure=True + lease_id, + 300, + raise_on_lease_renewal_failure=True, + should_stop=lambda: ( + self._flow_executed + or bool( + self.flow_run + and self.flow_run.state + and self.flow_run.state.is_final() + ) + ), ) ) @@ -1537,7 +1564,11 @@ async def initialize_run(self): raise except BaseException as exc: # We don't want to crash a flow run if the user code finished executing - if self.flow_run.state and not self.flow_run.state.is_final(): + if ( + self.flow_run.state + and not self.flow_run.state.is_final() + and not self._flow_executed + ): # BaseExceptions are caught and handled as crashes await self.handle_crash(exc) raise diff --git a/src/prefect/server/orchestration/core_policy.py b/src/prefect/server/orchestration/core_policy.py index 0b40382de3db..03be3715f2d9 100644 --- a/src/prefect/server/orchestration/core_policy.py +++ b/src/prefect/server/orchestration/core_policy.py @@ -665,15 +665,17 @@ async def cleanup( # type: ignore if not deployment or not deployment.concurrency_limit_id: return - await concurrency_limits_v2.bulk_decrement_active_slots( - session=context.session, - concurrency_limit_ids=[deployment.concurrency_limit_id], - slots=1, - ) + # Only decrement active slots if a lease was actually acquired + # (i.e., if deployment_concurrency_lease_id exists in validated_state) if ( validated_state and validated_state.state_details.deployment_concurrency_lease_id ): + await concurrency_limits_v2.bulk_decrement_active_slots( + session=context.session, + concurrency_limit_ids=[deployment.concurrency_limit_id], + slots=1, + ) lease_storage = get_concurrency_lease_storage() await lease_storage.revoke_lease( lease_id=validated_state.state_details.deployment_concurrency_lease_id, @@ -844,7 +846,7 @@ class ReleaseFlowConcurrencySlots(FlowRunUniversalTransform): Releases deployment concurrency slots held by a flow run. This rule releases a concurrency slot for a deployment when a flow run - transitions out of the Running or Cancelling state. + transitions out of the Running, Cancelling, or Pending state. """ async def after_transition( diff --git a/tests/_internal/test_retries.py b/tests/_internal/test_retries.py index 0d4376baa6ab..2dc15c418f52 100644 --- a/tests/_internal/test_retries.py +++ b/tests/_internal/test_retries.py @@ -110,3 +110,52 @@ async def eventual_success_func(): assert result == "Success" assert mock_func.call_count == 3 assert mock_sleep.call_count == 2 + + async def test_should_not_retry_short_circuits(self, mock_sleep): + """When should_not_retry returns True, raises immediately without retrying.""" + mock_func = AsyncMock(side_effect=ValueError("no retry")) + + @retry_async_fn(max_attempts=3, should_not_retry=lambda e: True) + async def fail_func(): + await mock_func() + + with pytest.raises(ValueError, match="no retry"): + await fail_func() + + assert mock_func.call_count == 1 + assert mock_sleep.call_count == 0 + + async def test_should_not_retry_false_does_not_affect_retries(self, mock_sleep): + """When should_not_retry returns False, normal retry behavior is preserved.""" + mock_func = AsyncMock(side_effect=ValueError("retry me")) + + @retry_async_fn(max_attempts=3, should_not_retry=lambda e: False) + async def fail_func(): + await mock_func() + + with pytest.raises(ValueError, match="retry me"): + await fail_func() + + assert mock_func.call_count == 3 + assert mock_sleep.call_count == 2 + + async def test_should_not_retry_only_short_circuits_matching_exceptions( + self, mock_sleep + ): + """should_not_retry inspects the exception — only short-circuits when it returns True.""" + mock_func = AsyncMock( + side_effect=[ValueError("retry this"), ValueError("stop here")] + ) + + @retry_async_fn( + max_attempts=5, + should_not_retry=lambda e: "stop" in str(e), + ) + async def mixed_func(): + await mock_func() + + with pytest.raises(ValueError, match="stop here"): + await mixed_func() + + assert mock_func.call_count == 2 # retried once, then short-circuited + assert mock_sleep.call_count == 1 diff --git a/tests/concurrency/test_leases.py b/tests/concurrency/test_leases.py index 6986b1919cde..0bcfe741981f 100644 --- a/tests/concurrency/test_leases.py +++ b/tests/concurrency/test_leases.py @@ -2,11 +2,18 @@ from unittest import mock from uuid import uuid4 +import httpx import pytest from prefect.concurrency._leases import _lease_renewal_loop +def _make_http_status_error(status_code: int) -> httpx.HTTPStatusError: + request = httpx.Request("POST", "http://test/leases/renew") + response = httpx.Response(status_code, request=request) + return httpx.HTTPStatusError(f"{status_code}", request=request, response=response) + + async def test_lease_renewal_loop_renews_lease(): mock_client = mock.AsyncMock() mock_client.renew_concurrency_lease.side_effect = [ @@ -59,3 +66,55 @@ async def test_lease_renewal_loop_raises_after_max_retry_attempts(): # retry_async_fn with max_attempts=3 tries exactly 3 times assert mock_client.renew_concurrency_lease.call_count == 3 + + +async def test_lease_renewal_loop_does_not_retry_on_410(): + """A 410 response raises httpx.HTTPStatusError immediately without any retries.""" + mock_client = mock.AsyncMock() + mock_client.renew_concurrency_lease.side_effect = _make_http_status_error(410) + + with ( + mock.patch("prefect.concurrency._leases.get_client") as mock_get_client, + mock.patch("asyncio.sleep", new_callable=mock.AsyncMock), + ): + mock_get_client.return_value.__aenter__.return_value = mock_client + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await _lease_renewal_loop(lease_id=uuid4(), lease_duration=10.0) + + assert exc_info.value.response.status_code == 410 + # No retries — called exactly once + assert mock_client.renew_concurrency_lease.call_count == 1 + + +async def test_lease_renewal_loop_retries_on_non_410_http_error(): + """Non-410 HTTP errors (e.g. 500) are retried up to 3 times like any other Exception.""" + mock_client = mock.AsyncMock() + mock_client.renew_concurrency_lease.side_effect = _make_http_status_error(500) + + with ( + mock.patch("prefect.concurrency._leases.get_client") as mock_get_client, + mock.patch("asyncio.sleep", new_callable=mock.AsyncMock), + ): + mock_get_client.return_value.__aenter__.return_value = mock_client + with pytest.raises(httpx.HTTPStatusError): + await _lease_renewal_loop(lease_id=uuid4(), lease_duration=10.0) + + # Retried up to max_attempts=3 + assert mock_client.renew_concurrency_lease.call_count == 3 + + +async def test_lease_renewal_loop_exits_cleanly_when_should_stop(): + """If should_stop() is True from the start, the loop exits without making any requests.""" + mock_client = mock.AsyncMock() + mock_client.renew_concurrency_lease.side_effect = _make_http_status_error(410) + + with ( + mock.patch("prefect.concurrency._leases.get_client") as mock_get_client, + mock.patch("asyncio.sleep", new_callable=mock.AsyncMock), + ): + mock_get_client.return_value.__aenter__.return_value = mock_client + await _lease_renewal_loop( + lease_id=uuid4(), lease_duration=10.0, should_stop=lambda: True + ) + + assert mock_client.renew_concurrency_lease.call_count == 0 diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index a4811382d96f..4f9c9ea0db85 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -1192,6 +1192,112 @@ async def begin_run_with_exception(self): # The flow run should be crashed assert flow_run.state.is_crashed() + async def test_lease_renewal_failure_during_state_transition_does_not_crash_sync( + self, prefect_client, monkeypatch, caplog + ): + """ + Test that a flow run that completes successfully but has a lease renewal + failure during the state transition API call does not get marked as crashed. + This simulates the exact scenario from issue #19068. + """ + from prefect._internal.concurrency.cancellation import CancelledError + from prefect.exceptions import UnfinishedRun + + flow_name = f"my-flow-{uuid.uuid4()}" + + @flow(name=flow_name) + def my_flow(): + return 42 + + # Mock set_state to raise CancelledError (simulating lease cancellation) + original_set_state = FlowRunEngine.set_state + call_count = {"count": 0} + + def set_state_with_cancellation(self, state, force=False): + call_count["count"] += 1 + # First call is to set Running state - let it succeed + if call_count["count"] == 1: + return original_set_state(self, state, force) + # Second call is to set Completed state - simulate cancellation + # But first mark as executed to match real behavior + self._flow_executed = True + raise CancelledError() + + monkeypatch.setattr(FlowRunEngine, "set_state", set_state_with_cancellation) + + # Run the flow, expecting it to finish without crashing + # The state transition will fail but the flow itself executes successfully + # Since the state never transitions to Completed, calling my_flow() will + # raise UnfinishedRun when trying to get the result + with pytest.raises(UnfinishedRun): + my_flow() + + flow_runs = await prefect_client.read_flow_runs( + flow_filter=FlowFilter(name=FlowFilterName(any_=[flow_name])) + ) + assert len(flow_runs) == 1 + flow_run = flow_runs[0] + # The flow run should NOT be crashed - it should stay in Running + # because the state transition to Completed failed + assert not flow_run.state.is_crashed() + # Verify the debug log message was recorded + assert ( + "BaseException was raised after user code finished executing" in caplog.text + ) + + async def test_lease_renewal_failure_during_state_transition_does_not_crash_async( + self, prefect_client, monkeypatch, caplog + ): + """ + Test that an async flow run that completes successfully but has a lease + renewal failure during the state transition API call does not get marked as crashed. + """ + from prefect._internal.concurrency.cancellation import CancelledError + from prefect.exceptions import UnfinishedRun + + flow_name = f"my-flow-{uuid.uuid4()}" + + @flow(name=flow_name) + async def my_flow(): + return 42 + + # Mock set_state to raise CancelledError (simulating lease cancellation) + original_set_state = AsyncFlowRunEngine.set_state + call_count = {"count": 0} + + async def set_state_with_cancellation(self, state, force=False): + call_count["count"] += 1 + # First call is to set Running state - let it succeed + if call_count["count"] == 1: + return await original_set_state(self, state, force) + # Second call is to set Completed state - simulate cancellation + # But first mark as executed to match real behavior + self._flow_executed = True + raise CancelledError() + + monkeypatch.setattr( + AsyncFlowRunEngine, "set_state", set_state_with_cancellation + ) + + # Run the flow, expecting it to finish without crashing + # The state transition will fail but the flow itself executes successfully + # Since the state never transitions to Completed, calling my_flow() will + # raise UnfinishedRun when trying to get the result + with pytest.raises(UnfinishedRun): + await my_flow() + + flow_runs = await prefect_client.read_flow_runs( + flow_filter=FlowFilter(name=FlowFilterName(any_=[flow_name])) + ) + assert len(flow_runs) == 1 + flow_run = flow_runs[0] + # The flow run should NOT be crashed + assert not flow_run.state.is_crashed() + # Verify the debug log message was recorded + assert ( + "BaseException was raised after user code finished executing" in caplog.text + ) + class TestPauseFlowRun: async def test_pause_flow_run_from_task_pauses_parent_flow( @@ -2577,7 +2683,7 @@ def foo(): run_flow(foo, flow_run) mock_maintain_concurrency_lease.assert_called_once_with( - ANY, 300, raise_on_lease_renewal_failure=True + ANY, 300, raise_on_lease_renewal_failure=True, should_stop=ANY ) async def test_lease_renewal_async( @@ -2612,7 +2718,7 @@ async def foo(): await run_flow(foo, flow_run) mock_maintain_concurrency_lease.assert_called_once_with( - ANY, 300, raise_on_lease_renewal_failure=True + ANY, 300, raise_on_lease_renewal_failure=True, should_stop=ANY )