Skip to content

Commit db71d37

Browse files
committed
Always commit offsets on termination
Even when an error occurs. Added a regression test. [ML-11919](https://iguazio.atlassian.net/browse/ML-11919)
1 parent 1cc6d88 commit db71d37

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

storey/sources.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,6 @@ async def _run_loop(self):
662662
if event is _termination_obj:
663663
# Termination of all downstream steps completed successfully; the
664664
# outstanding offsets are committed in the `finally` block below.
665-
await _commit_handled_events(self._outstanding_offsets, committer, self.logger, commit_all=True)
666665
return termination_result
667666
except BaseException as ex:
668667
if self.logger:
@@ -680,6 +679,8 @@ async def _run_loop(self):
680679
await self._q.get()
681680
self._raise_on_error()
682681
finally:
682+
# Commit on termination regardless of errors (ML-11919)
683+
await _commit_handled_events(self._outstanding_offsets, committer, self.logger, commit_all=True)
683684
if event is _termination_obj or self._ex:
684685
for closeable in self._closeables:
685686
try:

tests/test_flow.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,13 @@ async def _do(self, event):
176176
return await self._do_downstream(event)
177177

178178

179+
class ErrorOnTermination(storey.Flow):
    """Pass-through step that raises ``ATestException`` on the termination object.

    Every other event is handed to the downstream steps unchanged.
    """

    async def _do(self, event):
        # Ordinary events are simply forwarded.
        if event is not storey.dtypes._termination_obj:
            return await self._do_downstream(event)
        # Fail deliberately once the flow is being terminated.
        raise ATestException("We raise this error on termination on purpose")
184+
185+
179186
def test_offset_commit():
180187
platform = Committer()
181188
context = CommitterContext(platform)
@@ -544,6 +551,40 @@ def test_offset_commit_error():
544551
assert "RuntimeError: Something went wrong" in log_message
545552

546553

554+
async def async_offset_commit_error_on_termination():
    """Check that offsets are committed even when termination raises.

    Emits events with offsets 1..10 for each of 10 shards through an
    ``AsyncEmitSource`` (explicit ack enabled) followed by
    ``ErrorOnTermination``, expects ``ATestException`` from ``terminate``,
    and then asserts that the last offset of every shard was committed.
    """
    num_shards = 10
    num_records_per_shard = 10

    platform = Committer()
    logger = MockLogger()
    context = CommitterContext(platform, logger=logger)

    source = AsyncEmitSource(context=context, explicit_ack=True, max_wait_before_commit=1)
    controller = build_flow([source, ErrorOnTermination()]).run()

    for offset in range(1, num_records_per_shard + 1):
        for shard in range(num_shards):
            event = Event(shard)
            event.shard_id = shard
            event.offset = offset
            await controller.emit(event)
            # NOTE(review): presumably releases the local reference so the
            # event can be considered handled — confirm against the source.
            del event

    # The terminating step raises on purpose; the error must reach the caller.
    with pytest.raises(ATestException):
        await controller.terminate(wait=True)

    # Despite the error, every shard's final offset must have been committed.
    expected_offsets = {("/", shard): num_records_per_shard for shard in range(num_shards)}
    assert platform.offsets == expected_offsets
581+
582+
583+
# ML-11919
584+
def test_async_offset_commit_error_on_termination():
    """Regression test for ML-11919: drive the async scenario to completion."""
    scenario = async_offset_commit_error_on_termination()
    asyncio.run(scenario)
586+
587+
547588
def test_multiple_upstreams():
548589
source = SyncEmitSource()
549590
map1 = Map(lambda x: x + 1)

0 commit comments

Comments
 (0)