Skip to content

Commit 958fbce

Browse files
authored
Reconcile automations when Postgres notifications are unavailable (#21089)
1 parent 0b734db commit 958fbce

File tree

2 files changed

+121
-7
lines changed

2 files changed

+121
-7
lines changed

src/prefect/server/events/triggers.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@
7575

7676
AutomationID: TypeAlias = UUID
7777
TriggerID: TypeAlias = UUID
78+
AutomationStateSnapshot: TypeAlias = Tuple[
79+
Optional[prefect.types._datetime.DateTime], int
80+
]
7881

7982

8083
AUTOMATION_BUCKET_BATCH_SIZE = 500
@@ -679,6 +682,8 @@ async def periodic_evaluation(now: prefect.types._datetime.DateTime) -> None:
679682

680683
logger.debug("Running periodic evaluation as of %s (offset %ss)", as_of, offset)
681684

685+
await reconcile_automations()
686+
682687
# Any followers that have been sitting around longer than our lookback are never
683688
# going to see their leader event (maybe it was lost or took too long to arrive).
684689
# These events can just be evaluated now in the order they occurred.
@@ -713,6 +718,7 @@ async def evaluate_periodically(periodic_granularity: timedelta) -> None:
713718
automations_by_id: Dict[UUID, Automation] = {}
714719
triggers: Dict[TriggerID, EventTrigger] = {}
715720
next_proactive_runs: Dict[TriggerID, prefect.types._datetime.DateTime] = {}
721+
automation_state_snapshot: Optional[AutomationStateSnapshot] = None
716722

717723
# This lock governs any changes to the set of loaded automations; any routine that will
718724
# add/remove automations must be holding this lock when it does so. It's best to use
@@ -732,6 +738,12 @@ def find_interested_triggers(event: ReceivedEvent) -> Collection[EventTrigger]:
732738
return [trigger for trigger in candidates if trigger.covers(event)]
733739

734740

741+
def clear_loaded_automations() -> None:
742+
automations_by_id.clear()
743+
triggers.clear()
744+
next_proactive_runs.clear()
745+
746+
735747
def load_automation(automation: Optional[Automation]) -> None:
736748
"""Loads the given automation into memory so that it is available for evaluations"""
737749
if not automation:
@@ -762,14 +774,17 @@ async def automation_changed(
762774
automation_id: UUID,
763775
event: Literal["automation__created", "automation__updated", "automation__deleted"],
764776
) -> None:
777+
global automation_state_snapshot
778+
765779
async with _automations_lock():
766780
if event in ("automation__deleted", "automation__updated"):
767781
forget_automation(automation_id)
768782

769-
if event in ("automation__created", "automation__updated"):
770-
async with automations_session() as session:
783+
async with automations_session() as session:
784+
if event in ("automation__created", "automation__updated"):
771785
automation = await read_automation(session, automation_id)
772786
load_automation(automation)
787+
automation_state_snapshot = await read_automation_state_snapshot(session)
773788

774789

775790
@db_injector
@@ -788,6 +803,53 @@ async def load_automations(db: PrefectDBInterface, session: AsyncSession):
788803
)
789804

790805

806+
@db_injector
807+
async def read_automation_state_snapshot(
808+
db: PrefectDBInterface, session: AsyncSession
809+
) -> AutomationStateSnapshot:
810+
query = sa.select(
811+
sa.func.max(db.Automation.updated),
812+
sa.func.count(db.Automation.id),
813+
).select_from(db.Automation)
814+
815+
latest_updated, count = (await session.execute(query)).one()
816+
817+
return (
818+
prefect.types._datetime.create_datetime_instance(latest_updated)
819+
if latest_updated
820+
else None,
821+
count or 0,
822+
)
823+
824+
825+
async def reconcile_automations(force: bool = False) -> bool:
826+
global automation_state_snapshot
827+
828+
async with _automations_lock():
829+
async with automations_session() as session:
830+
current_snapshot = await read_automation_state_snapshot(session)
831+
if not force and current_snapshot == automation_state_snapshot:
832+
return False
833+
834+
previous_automations = automations_by_id.copy()
835+
previous_triggers = triggers.copy()
836+
previous_next_proactive_runs = next_proactive_runs.copy()
837+
838+
clear_loaded_automations()
839+
840+
try:
841+
await load_automations(session)
842+
except Exception:
843+
clear_loaded_automations()
844+
automations_by_id.update(previous_automations)
845+
triggers.update(previous_triggers)
846+
next_proactive_runs.update(previous_next_proactive_runs)
847+
raise
848+
849+
automation_state_snapshot = current_snapshot
850+
return True
851+
852+
791853
@db_injector
792854
async def remove_buckets_exceeding_threshold(
793855
db: PrefectDBInterface, session: AsyncSession, trigger: EventTrigger
@@ -1068,10 +1130,11 @@ async def sweep_closed_buckets(
10681130

10691131
async def reset() -> None:
10701132
"""Resets the in-memory state of the service"""
1133+
global automation_state_snapshot
1134+
10711135
await reset_events_clock()
1072-
automations_by_id.clear()
1073-
triggers.clear()
1074-
next_proactive_runs.clear()
1136+
clear_loaded_automations()
1137+
automation_state_snapshot = None
10751138

10761139

10771140
async def listen_for_automation_changes() -> None:
@@ -1096,6 +1159,8 @@ async def listen_for_automation_changes() -> None:
10961159
f"Listening for automation changes on {AUTOMATION_CHANGES_CHANNEL}"
10971160
)
10981161

1162+
await reconcile_automations()
1163+
10991164
async for payload in pg_listen(
11001165
conn,
11011166
AUTOMATION_CHANGES_CHANNEL,
@@ -1152,8 +1217,7 @@ async def consumer(
11521217
# Start the automation change listener task
11531218
sync_task = asyncio.create_task(listen_for_automation_changes())
11541219

1155-
async with automations_session() as session:
1156-
await load_automations(session)
1220+
await reconcile_automations(force=True)
11571221

11581222
proactive_task = asyncio.create_task(evaluate_periodically(periodic_granularity))
11591223

tests/events/server/triggers/test_service.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ def load_automations(monkeypatch: pytest.MonkeyPatch) -> mock.AsyncMock:
103103
return m
104104

105105

106+
@pytest.fixture(autouse=True)
107+
def read_automation_state_snapshot(
108+
monkeypatch: pytest.MonkeyPatch,
109+
) -> mock.AsyncMock:
110+
m = mock.AsyncMock(return_value=(None, 0))
111+
monkeypatch.setattr(
112+
"prefect.server.events.triggers.read_automation_state_snapshot", m
113+
)
114+
return m
115+
116+
106117
@pytest.fixture(autouse=True)
107118
def periodic_evaluation(monkeypatch: pytest.MonkeyPatch) -> mock.AsyncMock:
108119
m = mock.AsyncMock(spec=triggers.periodic_evaluation)
@@ -243,6 +254,45 @@ async def test_loads_automations_at_startup(
243254
load_automations.assert_awaited_once_with(open_automations_session)
244255

245256

257+
async def test_reconcile_automations_skips_reload_when_snapshot_matches(
258+
open_automations_session: mock.Mock,
259+
load_automations: mock.AsyncMock,
260+
read_automation_state_snapshot: mock.AsyncMock,
261+
):
262+
snapshot = (now("UTC"), 2)
263+
triggers.automation_state_snapshot = snapshot
264+
read_automation_state_snapshot.return_value = snapshot
265+
266+
changed = await triggers.reconcile_automations()
267+
268+
assert changed is False
269+
load_automations.assert_not_awaited()
270+
271+
272+
async def test_reconcile_automations_reloads_when_snapshot_changes(
273+
arachnophobia: Automation,
274+
open_automations_session: mock.Mock,
275+
load_automations: mock.AsyncMock,
276+
read_automation_state_snapshot: mock.AsyncMock,
277+
):
278+
previous_snapshot = (now("UTC"), 1)
279+
current_snapshot = (now("UTC") + timedelta(seconds=1), 2)
280+
triggers.automation_state_snapshot = previous_snapshot
281+
triggers.load_automation(arachnophobia)
282+
(trigger,) = arachnophobia.triggers_of_type(EventTrigger)
283+
triggers.next_proactive_runs[trigger.id] = now("UTC")
284+
read_automation_state_snapshot.return_value = current_snapshot
285+
286+
changed = await triggers.reconcile_automations()
287+
288+
assert changed is True
289+
load_automations.assert_awaited_once_with(open_automations_session)
290+
assert triggers.automation_state_snapshot == current_snapshot
291+
assert triggers.automations_by_id == {}
292+
assert triggers.triggers == {}
293+
assert triggers.next_proactive_runs == {}
294+
295+
246296
async def test_only_considers_messages_with_attributes(
247297
effective_automations,
248298
reactive_evaluation: mock.AsyncMock,

0 commit comments

Comments
 (0)