Skip to content

Commit aa46390

Browse files
committed
Merge remote-tracking branch 'origin/develop' into develop
2 parents 70b4926 + d06d29c commit aa46390

2 files changed

Lines changed: 99 additions & 22 deletions

File tree

securedrop/journalist_app/api2/events.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import asdict
22

33
from db import db
4+
from flask import current_app
45
from journalist_app import utils
56
from journalist_app.api2.shared import json_version, mark_source_deleted, save_reply
67
from journalist_app.api2.types import (
@@ -35,13 +36,17 @@ class EventHandler:
3536
`journalist_api2.types.EVENT_DATA_TYPES`;
3637
3738
3. define the handler as a static method `handle_thing_done(event: Event)`
38-
in this class
39+
in this class; and
3940
4041
4. explicitly register `{"thing_done": self.handle_thing_done}` inside
4142
`EventHandler.process()`.
4243
4344
This is belt-and-suspenders for ensuring that only the intended methods are
4445
exposed as callable event handlers.
46+
47+
To preserve transaction separation between events, handlers MUST return with
48+
a clean SQLAlchemy session: in other words, having either successfully
49+
committed or rolled back all of their changes.
4550
"""
4651

4752
def __init__(self, session: Session, redis: Redis) -> None:
@@ -59,7 +64,7 @@ def process(self, event: Event, minor: int) -> EventResult:
5964
"""The per-event entry-point for handling a single event."""
6065

6166
try:
62-
if self.has_progress(event):
67+
if self.is_duplicate(event):
6368
return EventResult(
6469
event_id=event.id,
6570
status=(EventStatusCode.AlreadyReported, None),
@@ -83,35 +88,45 @@ def process(self, event: Event, minor: int) -> EventResult:
8388
),
8489
)
8590

86-
self.mark_progress(event) # prevent races
87-
result = handler(event, minor)
88-
self.mark_progress(event, result.status[0]) # enforce idempotence
91+
try:
92+
result = handler(event, minor)
93+
94+
# Enforce "handlers MUST return with a clean SQLAlchemy session" above:
95+
if db.session.dirty or db.session.new or db.session.deleted:
96+
raise RuntimeError(f"{handler} returned with a pending database transaction")
97+
98+
# Catch anything not handled by the handler:
99+
except Exception:
100+
current_app.logger.error(f"unhandled exception in handler for {event}", exc_info=True)
101+
db.session.rollback()
102+
result = EventResult(
103+
event.id, (EventStatusCode.InternalServerError, "failed to process event")
104+
)
105+
106+
self.record_status(event, result.status[0])
89107
return result
90108

91109
def idempotence_key(self, event: Event) -> str:
92110
return f"{REDIS_EVENT_PREFIX}/{self._session.user.uuid}/{event.id}"
93111

94-
def has_progress(self, event: Event) -> EventStatusCode:
95-
return self._redis.get(self.idempotence_key(event))
96-
97-
def mark_progress(
98-
self, event: Event, status: EventStatusCode = EventStatusCode.Processing
99-
) -> None:
100-
"""
101-
If `status` is a non-error code, mark it as the progress of `event`, to
102-
be returned later as "Already Reported".
103-
104-
If `status` is an error code, clear it, since `event` MAY be resubmitted
105-
later.
106-
"""
107-
if status >= EventStatusCode.BadRequest:
108-
self._redis.delete(self.idempotence_key(event))
109-
else:
112+
def is_duplicate(self, event: Event) -> bool:
113+
"""Returns `True` if this event is already registered (i.e., a replay)."""
114+
return (
110115
self._redis.set(
111116
self.idempotence_key(event),
112-
status,
117+
EventStatusCode.Processing,
113118
ex=IDEMPOTENCE_PERIOD,
119+
nx=True,
114120
)
121+
is None
122+
)
123+
124+
def record_status(self, event: Event, status: EventStatusCode) -> None:
125+
"""Record the event's final status for idempotence, or clear on error to permit retry."""
126+
if status >= EventStatusCode.BadRequest:
127+
self._redis.delete(self.idempotence_key(event))
128+
else:
129+
self._redis.set(self.idempotence_key(event), status, ex=IDEMPOTENCE_PERIOD)
115130

116131
@staticmethod
117132
def handle_item_deleted(event: Event, minor: int) -> EventResult:

securedrop/tests/test_journalist_api2.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,41 @@
1+
import threading
12
import uuid
23
from contextlib import contextmanager
34
from copy import deepcopy
45
from dataclasses import asdict
56
from datetime import UTC, datetime
7+
from unittest.mock import MagicMock, patch
68
from uuid import uuid4
79

810
import journalist_app as journalist_app_module
911
import pytest
1012
from flask import url_for
1113
from flask_sqlalchemy import get_debug_queries
1214
from journalist_app import api2
15+
from journalist_app.api2.events import EventHandler
1316
from journalist_app.api2.shared import json_version
1417
from journalist_app.api2.types import (
1518
VERSION_LEN,
1619
Event,
20+
EventResult,
21+
EventStatusCode,
1722
EventType,
1823
ItemTarget,
1924
SourceTarget,
2025
)
2126
from models import Reply, Source, SourceStar, Submission, db
27+
from redis import Redis
2228
from sqlalchemy.orm.exc import MultipleResultsFound
2329
from tests.utils import ascii_armor, decrypt_as_journalist
2430
from tests.utils.api_helper import get_api_headers
2531
from tests.utils.db_helper import init_source, submit
2632

2733

34+
@pytest.fixture
35+
def redis(config):
36+
return Redis(decode_responses=True, **config.REDIS_KWARGS)
37+
38+
2839
def filtered_queries():
2940
# filter out PRAGMA, instance_config loading, etc.
3041
return [
@@ -832,6 +843,57 @@ def test_api2_idempotence_period(journalist_app):
832843
assert journalist_app.config["SESSION_LIFETIME"] <= api2.events.IDEMPOTENCE_PERIOD
833844

834845

846+
def test_api2_atomic_idempotence(redis):
847+
"""
848+
Check that identical events received simultaneously are still only accepted
849+
once.
850+
"""
851+
852+
session = MagicMock()
853+
session.user.uuid = str(uuid4())
854+
855+
event = MagicMock()
856+
event.id = str(uuid4())
857+
event.type = EventType.SOURCE_STARRED
858+
859+
barrier = threading.Barrier(2)
860+
handler_call_count = [0]
861+
original_set = redis.set
862+
863+
def slow_set(*args, **kwargs):
864+
if kwargs.get("nx"):
865+
# Widen the possible TOCTOU window: hold both threads at the SETNX
866+
# gate until both are about to set the not-yet-existing key, then
867+
# release together so atomicity is what enforces the invariant.
868+
barrier.wait(timeout=5)
869+
return original_set(*args, **kwargs)
870+
871+
def counting_handler(ev, minor):
872+
handler_call_count[0] += 1
873+
return EventResult(event_id=ev.id, status=(EventStatusCode.OK, None))
874+
875+
with (
876+
patch.object(redis, "set", slow_set),
877+
patch.object(EventHandler, "handle_source_starred", staticmethod(counting_handler)),
878+
):
879+
threads = [
880+
threading.Thread(
881+
target=EventHandler(session=session, redis=redis).process,
882+
args=(event, 1),
883+
)
884+
for _ in range(2)
885+
]
886+
for t in threads:
887+
t.start()
888+
for t in threads:
889+
t.join(timeout=10)
890+
891+
# Should be 1: a correct implementation relies on atomic SETNX to let only
892+
# one thread past the check. Fails with count == 2 if mark_progress() drops
893+
# the NX flag or claim_progress() reverts to a non-atomic check-then-act.
894+
assert handler_call_count[0] == 1
895+
896+
835897
def test_api2_event_ordering(journalist_app, journalist_api_token, test_files):
836898
"""
837899
If two `item_deleted` events for the same item arrive out of order, the

0 commit comments

Comments
 (0)