Skip to content

Commit 85dde16

Browse files
authored
Fix flaky pause_and_assert helper (#1493)
1 parent 19aa13e commit 85dde16

2 files changed

Lines changed: 39 additions & 9 deletions

File tree

tests/helpers/__init__.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging.handlers
44
import queue
55
import socket
6+
import threading
67
import time
78
import uuid
89
from collections.abc import Awaitable, Callable, Iterator, Sequence
@@ -265,8 +266,28 @@ async def get_pending_activity_info(
265266
return None
266267

267268

269+
_wait_for_pause_events: dict[str, threading.Event] = {}
270+
271+
272+
def wait_for_pause_event(activity_id: str) -> None:
273+
event = _wait_for_pause_events.get(activity_id)
274+
if event is not None:
275+
event.wait()
276+
277+
278+
async def async_wait_for_pause_event(activity_id: str) -> None:
279+
event = _wait_for_pause_events.get(activity_id)
280+
if event is not None:
281+
await asyncio.get_running_loop().run_in_executor(None, event.wait)
282+
283+
268284
async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):
269-
"""Pause the given activity and assert it becomes paused."""
285+
"""Pause the given activity and assert it becomes paused.
286+
287+
Registers an event before calling the pause API so cooperating test
288+
activities (those that catch the pause-induced cancel via
289+
wait_for_pause_release) hang until we have observed paused=true.
290+
"""
270291
desc = await handle.describe()
271292
req = PauseActivityRequest(
272293
namespace=client.namespace,
@@ -276,14 +297,19 @@ async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id:
276297
),
277298
id=activity_id,
278299
)
279-
await client.workflow_service.pause_activity(req)
280300

281-
# Assert eventually paused
282-
async def check_paused() -> bool:
283-
info = await assert_pending_activity_exists_eventually(handle, activity_id)
284-
return info.paused
301+
_wait_for_pause_events[activity_id] = threading.Event()
302+
try:
303+
await client.workflow_service.pause_activity(req)
304+
305+
async def check_paused() -> None:
306+
info = await assert_pending_activity_exists_eventually(handle, activity_id)
307+
assert info.paused, f"Activity {activity_id} not yet paused"
285308

286-
await assert_eventually(check_paused)
309+
await assert_eventually(check_paused)
310+
finally:
311+
_wait_for_pause_events[activity_id].set()
312+
del _wait_for_pause_events[activity_id]
287313

288314

289315
async def unpause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):
@@ -300,9 +326,9 @@ async def unpause_and_assert(client: Client, handle: WorkflowHandle, activity_id
300326
await client.workflow_service.unpause_activity(req)
301327

302328
# Assert eventually not paused
303-
async def check_unpaused() -> bool:
329+
async def check_unpaused() -> None:
304330
info = await assert_pending_activity_exists_eventually(handle, activity_id)
305-
return not info.paused
331+
assert not info.paused, f"Activity {activity_id} still paused"
306332

307333
await assert_eventually(check_unpaused)
308334

tests/worker/test_workflow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,14 @@
130130
assert_pending_activity_exists_eventually,
131131
assert_task_fail_eventually,
132132
assert_workflow_exists_eventually,
133+
async_wait_for_pause_event,
133134
ensure_search_attributes_present,
134135
find_free_port,
135136
get_pending_activity_info,
136137
new_worker,
137138
pause_and_assert,
138139
unpause_and_assert,
140+
wait_for_pause_event,
139141
workflow_update_exists,
140142
)
141143
from tests.helpers.cache_eviction import (
@@ -7782,6 +7784,7 @@ async def heartbeat_activity(
77827784
except (CancelledError, asyncio.CancelledError) as err:
77837785
if not catch_err:
77847786
raise err
7787+
await async_wait_for_pause_event(activity.info().activity_id)
77857788
return activity.cancellation_details()
77867789
finally:
77877790
activity.heartbeat("finally-complete")
@@ -7801,6 +7804,7 @@ def sync_heartbeat_activity(
78017804
except (CancelledError, asyncio.CancelledError) as err:
78027805
if not catch_err:
78037806
raise err
7807+
wait_for_pause_event(activity.info().activity_id)
78047808
return activity.cancellation_details()
78057809
finally:
78067810
activity.heartbeat("finally-complete")

0 commit comments

Comments
 (0)