Skip to content

Commit bc7c966

Browse files
LogPusher: flush() blocks until every prior entry is processed
Adds a per-entry monotonic seq to the queue. `flush(timeout=None)` snapshots _pushed_seq and waits on the condvar until the drain has advanced _processed_seq past it, where "processed" means either sent or overflow-dropped. Dropped entries advance the marker so a flush doesn't wait forever on entries that will never reach the server. Makes task_attempt._cleanup's `self._log_pusher.flush()` actually wait for final logs to ship, restoring the pre-rewrite contract. close() stays best-effort per directive. Also replaces the _wait_for polling in test_failures_always_deliver_via_retry with deterministic flush(timeout=...) synchronization — removes the xdist-load flake that showed up as assert 2==1 / 6==1 on the seed/seed-then-push race. Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
1 parent f56c8f8 commit bc7c966

2 files changed

Lines changed: 158 additions & 43 deletions

File tree

lib/iris/src/iris/log_server/client.py

Lines changed: 94 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,16 @@
5555

5656

5757
class LogPusher:
58-
"""Non-blocking buffered client for pushing log entries to a remote LogService.
58+
"""Buffered client for pushing log entries to a remote LogService.
5959
60-
``push`` appends to an in-memory queue; a background thread drains it
61-
in per-key batches. Send failures re-buffer and back off exponentially
62-
— only the ``MAX_LOG_BUFFER_SIZE`` overflow path drops entries.
60+
``push`` is non-blocking: it appends to an in-memory queue and returns.
61+
A background thread drains the queue in per-key batches. Send failures
62+
re-buffer and back off exponentially — only the ``MAX_LOG_BUFFER_SIZE``
63+
overflow path drops entries.
64+
65+
``flush`` blocks until every entry enqueued before the call has been
66+
processed (sent or overflow-dropped). Use the ``timeout`` argument to
67+
bound the wait — by default ``flush`` waits indefinitely.
6368
6469
``server_url`` is passed to ``resolver`` (default: identity) to obtain
6570
the actual http address. Retryable failures invalidate the cached RPC
@@ -87,12 +92,20 @@ def __init__(
8792

8893
# All shared state is guarded by _cond. The drain thread is the
8994
# only owner of _client, so no separate client lock. ``_queue`` is
90-
# a single FIFO of (key, entry); the drain thread groups by key
91-
# just before sending. Trimming on overflow is one popleft.
95+
# a single FIFO of (seq, key, entry); the drain thread groups by
96+
# key just before sending. Trimming on overflow is one popleft.
97+
# ``seq`` is a monotonic per-entry counter used by blocking flush.
9298
self._cond = threading.Condition()
93-
self._queue: deque[tuple[str, logging_pb2.LogEntry]] = deque()
99+
self._queue: deque[tuple[int, str, logging_pb2.LogEntry]] = deque()
94100
self._closed = False
95101

102+
# Monotonic counters for blocking flush(). ``_pushed_seq`` advances
103+
# on every entry enqueued. ``_processed_seq`` advances when the
104+
# drain thread acks an entry as either successfully sent or
105+
# overflow-dropped — both terminal states from flush's POV.
106+
self._pushed_seq = 0
107+
self._processed_seq = 0
108+
96109
# Built lazily by the drain thread on first send; invalidated on
97110
# any failure so the next attempt re-resolves.
98111
self._client: LogServiceClientSync | None = None
@@ -115,27 +128,53 @@ def push(self, key: str, entries: list[logging_pb2.LogEntry]) -> None:
115128
if self._closed:
116129
return
117130
for e in entries:
118-
self._queue.append((key, e))
131+
self._pushed_seq += 1
132+
self._queue.append((self._pushed_seq, key, e))
119133
self._trim_oldest_locked()
120134
if len(self._queue) >= self._batch_size:
121-
self._cond.notify()
135+
self._cond.notify_all()
122136

123-
def flush(self) -> None:
124-
"""Poke the drain thread to send whatever is buffered now.
137+
def flush(self, timeout: float | None = None) -> bool:
138+
"""Block until every entry enqueued before this call has been processed.
125139
126-
Non-blocking. For draining on shutdown, use ``close``.
140+
"Processed" means either successfully sent or overflow-dropped —
141+
both terminal states. Returns ``True`` if the drain caught up,
142+
``False`` on timeout. ``timeout=None`` waits indefinitely.
143+
144+
For shutdown drain, prefer ``close`` (best-effort, won't block on
145+
a stuck server).
127146
"""
128147
with self._cond:
129-
if self._queue:
130-
self._cond.notify()
148+
target = self._pushed_seq
149+
if target == 0 or self._processed_seq >= target:
150+
return True
151+
self._cond.notify_all()
152+
deadline = (time.monotonic() + timeout) if timeout is not None else None
153+
while self._processed_seq < target:
154+
if self._closed:
155+
return self._processed_seq >= target
156+
if deadline is None:
157+
# Re-check periodically so a wedged drain still surfaces.
158+
self._cond.wait(timeout=1.0)
159+
else:
160+
remaining = deadline - time.monotonic()
161+
if remaining <= 0:
162+
return False
163+
self._cond.wait(timeout=remaining)
164+
return True
131165

132166
def close(self) -> None:
133-
"""Stop the drain thread after one best-effort drain, close the RPC client."""
167+
"""Stop the drain thread after one best-effort drain, close the RPC client.
168+
169+
Best-effort: if a send is in flight when ``close()`` returns the
170+
join timeout, we still close the cached client. Use ``flush()``
171+
first if you need to guarantee final delivery.
172+
"""
134173
with self._cond:
135174
if self._closed:
136175
return
137176
self._closed = True
138-
self._cond.notify()
177+
self._cond.notify_all()
139178
# Join the drain thread; it will send what it can and exit.
140179
self._thread.join(timeout=max(self._flush_interval * 2, 10.0))
141180
if self._client is not None:
@@ -150,28 +189,38 @@ def close(self) -> None:
150189
# ------------------------------------------------------------------
151190

152191
def _trim_oldest_locked(self) -> None:
153-
"""Drop oldest entries until under ``_max_buffer_size``."""
192+
"""Drop oldest entries until under ``_max_buffer_size``.
193+
194+
Dropped entries advance ``_processed_seq`` so blocking ``flush``
195+
doesn't wait forever on entries that will never reach the server.
196+
"""
154197
dropped = 0
198+
max_dropped_seq = 0
155199
while len(self._queue) > self._max_buffer_size:
156-
self._queue.popleft()
200+
seq, _key, _entry = self._queue.popleft()
201+
if seq > max_dropped_seq:
202+
max_dropped_seq = seq
157203
dropped += 1
158204
if dropped:
159205
logger.warning(
160206
"LogPusher buffer overflow: dropped %d oldest entries (cap=%d)",
161207
dropped,
162208
self._max_buffer_size,
163209
)
210+
if max_dropped_seq > self._processed_seq:
211+
self._processed_seq = max_dropped_seq
212+
self._cond.notify_all()
164213

165-
def _take_queue_locked(self) -> list[tuple[str, logging_pb2.LogEntry]]:
214+
def _take_queue_locked(self) -> list[tuple[int, str, logging_pb2.LogEntry]]:
166215
"""Drain the entire queue, preserving arrival order."""
167216
items = list(self._queue)
168217
self._queue.clear()
169218
return items
170219

171-
def _rebuffer_at_head_locked(self, items: list[tuple[str, logging_pb2.LogEntry]]) -> None:
220+
def _rebuffer_at_head_locked(self, items: list[tuple[int, str, logging_pb2.LogEntry]]) -> None:
172221
"""Put unsent items back at the head of the queue (original order)."""
173-
for pair in reversed(items):
174-
self._queue.appendleft(pair)
222+
for triple in reversed(items):
223+
self._queue.appendleft(triple)
175224
self._trim_oldest_locked()
176225

177226
# ------------------------------------------------------------------
@@ -194,7 +243,11 @@ def _run(self) -> None:
194243
return
195244
items = self._take_queue_locked()
196245

197-
unsent = self._send_items(items)
246+
sent_max_seq, unsent = self._send_items(items)
247+
with self._cond:
248+
if sent_max_seq > self._processed_seq:
249+
self._processed_seq = sent_max_seq
250+
self._cond.notify_all()
198251
if not unsent:
199252
self._backoff.reset()
200253
continue
@@ -216,44 +269,50 @@ def _run(self) -> None:
216269

217270
def _send_items(
218271
self,
219-
items: list[tuple[str, logging_pb2.LogEntry]],
220-
) -> list[tuple[str, logging_pb2.LogEntry]]:
272+
items: list[tuple[int, str, logging_pb2.LogEntry]],
273+
) -> tuple[int, list[tuple[int, str, logging_pb2.LogEntry]]]:
221274
"""Group ``items`` by key (stable on first occurrence) and push one
222-
RPC per key. On any failure, return every item from that key onward
223-
so the caller can re-buffer it at the head of the queue.
275+
RPC per key. Returns ``(max_sent_seq, unsent_items)``.
224276
277+
On any failure, every item from that key onward is returned as
278+
unsent so the caller can re-buffer it at the head of the queue.
225279
Every failure mode — resolver error, retryable RPC error, or
226280
non-retryable RPC error — re-buffers so no log entries are silently
227281
dropped. Retryable errors additionally invalidate the cached client
228282
so the next attempt re-resolves the endpoint.
229283
"""
230-
groups: dict[str, list[logging_pb2.LogEntry]] = {}
231-
for key, entry in items:
232-
groups.setdefault(key, []).append(entry)
284+
groups: dict[str, list[tuple[int, logging_pb2.LogEntry]]] = {}
285+
for seq, key, entry in items:
286+
groups.setdefault(key, []).append((seq, entry))
233287

234288
sent_keys: set[str] = set()
235-
for key, entries in groups.items():
289+
max_sent_seq = 0
290+
for key, seq_entries in groups.items():
236291
try:
237292
client = self._get_client()
238293
except Exception as exc:
239294
logger.warning("LogPusher: endpoint resolution failed: %s", exc)
240-
return [p for p in items if p[0] not in sent_keys]
295+
return max_sent_seq, [p for p in items if p[1] not in sent_keys]
241296
try:
297+
entries = [e for _s, e in seq_entries]
242298
client.push_logs(logging_pb2.PushLogsRequest(key=key, entries=entries))
243299
sent_keys.add(key)
300+
for seq, _e in seq_entries:
301+
if seq > max_sent_seq:
302+
max_sent_seq = seq
244303
except Exception as exc:
245304
retryable = is_retryable_error(exc)
246305
logger.warning(
247306
"LogPusher: send failure for key=%s (%d entries, retryable=%s): %s",
248307
key,
249-
len(entries),
308+
len(seq_entries),
250309
retryable,
251310
exc,
252311
)
253312
if retryable:
254313
self._invalidate(str(exc))
255-
return [p for p in items if p[0] not in sent_keys]
256-
return []
314+
return max_sent_seq, [p for p in items if p[1] not in sent_keys]
315+
return max_sent_seq, []
257316

258317
def _build_client(self, address: str) -> LogServiceClientSync:
259318
return LogServiceClientSync(

lib/iris/tests/test_remote_log_handler.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,11 @@ def _wait_for(predicate, timeout: float = 5.0) -> None:
117117

118118

119119
def test_log_pusher_buffers_and_flushes_on_demand(tracked_log_service_client):
120-
"""Entries buffered below batch_size are drained on flush()."""
120+
"""Entries buffered below batch_size are drained on flush().
121+
122+
flush() blocks until every entry enqueued before the call has shipped,
123+
so the assertions can run immediately without polling.
124+
"""
121125
pusher = LogPusher(
122126
"http://h:1",
123127
batch_size=1000,
@@ -127,15 +131,62 @@ def test_log_pusher_buffers_and_flushes_on_demand(tracked_log_service_client):
127131
entry = logging_pb2.LogEntry(source="test", data="line1")
128132
pusher.push("key-a", [entry, entry, entry])
129133
pusher.push("key-b", [entry])
130-
pusher.flush()
131-
_wait_for(lambda: len(tracked_log_service_client) == 1 and len(tracked_log_service_client[0].pushes) >= 2)
134+
assert pusher.flush(timeout=5.0)
132135

133136
totals = {p.key: len(p.entries) for p in tracked_log_service_client[0].pushes}
134137
assert totals == {"key-a": 3, "key-b": 1}
135138
finally:
136139
pusher.close()
137140

138141

142+
def test_log_pusher_flush_is_blocking(tracked_log_service_client):
143+
"""flush() returns only after every previously-pushed entry has been sent."""
144+
pusher = LogPusher(
145+
"http://h:1",
146+
batch_size=1000,
147+
flush_interval=999.0,
148+
)
149+
try:
150+
entry = logging_pb2.LogEntry(source="test", data="line")
151+
pusher.push("k", [entry, entry])
152+
# No polling — flush must block until shipped.
153+
assert pusher.flush(timeout=5.0) is True
154+
assert len(tracked_log_service_client[0].pushes) == 1
155+
assert len(tracked_log_service_client[0].pushes[0].entries) == 2
156+
finally:
157+
pusher.close()
158+
159+
160+
def test_log_pusher_flush_timeout_returns_false(monkeypatch):
161+
"""flush(timeout=...) returns False when the drain can't catch up in time.
162+
163+
Seeds a non-retryable error so the drain rebuffers and enters the
164+
backoff window; flush is given less time than the backoff interval.
165+
"""
166+
created: list[_FakeLogServiceClient] = []
167+
168+
def factory(address, timeout_ms=10_000, interceptors=()):
169+
c = _FakeLogServiceClient(address, timeout_ms=timeout_ms, interceptors=interceptors)
170+
created.append(c)
171+
return c
172+
173+
monkeypatch.setattr(client_mod, "LogServiceClientSync", factory)
174+
175+
pusher = LogPusher("http://h:1", batch_size=1, flush_interval=999.0)
176+
try:
177+
entry = logging_pb2.LogEntry(source="test", data="primer")
178+
pusher.push("k", [entry])
179+
# Wait for the cached client to exist, then seed a non-retryable
180+
# error so the next send rebuffers and the drain enters backoff.
181+
assert pusher.flush(timeout=5.0) is True
182+
created[0].errors.append(ConnectError(Code.NOT_FOUND, "missing"))
183+
pusher.push("k", [logging_pb2.LogEntry(source="test", data="stuck")])
184+
# Backoff is 0.5s; a 0.05s flush cannot catch up.
185+
assert pusher.flush(timeout=0.05) is False
186+
finally:
187+
pusher.close()
188+
189+
139190
def test_log_pusher_flushes_at_batch_size(tracked_log_service_client):
140191
"""Reaching batch_size wakes the drain thread without waiting for a timer."""
141192
pusher = LogPusher(
@@ -305,8 +356,9 @@ def resolver(_url):
305356
# for resolver_raises, which has no client to seed).
306357
pusher.push("k", [_entry("a")])
307358
if scenario != "resolver_raises":
308-
_wait_for(lambda: created and created[0].pushes)
309-
# Seed the cached client with the scenario-appropriate error.
359+
# Block until "a" has shipped, so seeding the next error is
360+
# race-free with the drain thread's next iteration.
361+
assert pusher.flush(timeout=5.0)
310362
err = (
311363
ConnectError(Code.NOT_FOUND, "missing")
312364
if scenario == "non_retryable"
@@ -316,15 +368,19 @@ def resolver(_url):
316368

317369
pusher.push("k", [_entry("b")])
318370

319-
# "b" must eventually land somewhere.
371+
# Wait deterministically for "b" to be processed (sent or dropped).
372+
assert pusher.flush(timeout=10.0)
373+
374+
# "b" must have landed somewhere — the buffer-overflow path is not
375+
# exercised here, so processed implies delivered.
320376
def delivered():
321377
return any(any(e.data == "b" for p in c.pushes for e in p.entries) for c in created)
322378

323-
_wait_for(delivered, timeout=10.0)
379+
assert delivered(), "entry 'b' was never delivered to any client"
324380

325381
if scenario.startswith("retryable"):
326382
# Retryable RPC failure invalidated the first client; second built.
327-
_wait_for(lambda: len(created) >= 2)
383+
assert len(created) >= 2
328384
assert created[0].closed is True
329385
elif scenario == "resolver_raises":
330386
# Resolver raised on first call → no client yet. Second call

0 commit comments

Comments
 (0)