Skip to content

Commit ee998ab

Browse files
authored
Add API to terminate and await termination (#549)
To make it easier for the caller to do both regardless of whether source is sync or async. [ML-8738](https://iguazio.atlassian.net/browse/ML-8738)
1 parent a00e997 commit ee998ab

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

storey/sources.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,17 @@ def emit(
175175
self._emit_fn(event)
176176
return awaitable_result
177177

178-
def terminate(self):
179-
"""Terminates the associated flow."""
178+
def terminate(self, wait=False):
179+
"""
180+
Terminates the associated flow.
181+
182+
:param wait: Whether to wait for the flow to terminate before returning.
183+
184+
:returns: None if wait=False. If wait=True, the termination result will be returned.
185+
"""
180186
self._emit_fn(_termination_obj)
187+
if wait:
188+
return self._await_termination_fn()
181189

182190
def await_termination(self):
183191
"""Awaits the termination of the flow. To be called after terminate. Returns the termination result of the
@@ -482,9 +490,17 @@ async def emit(
482490
raise result
483491
return result
484492

485-
async def terminate(self):
486-
"""Terminates the associated flow."""
493+
async def terminate(self, wait=False):
494+
"""
495+
Terminates the associated flow.
496+
497+
:param wait: Whether to wait for the flow to terminate before returning.
498+
499+
:returns: None if wait=False. If wait=True, the termination result will be returned.
500+
"""
487501
await self._emit_fn(_termination_obj)
502+
if wait:
503+
return await self.await_termination()
488504

489505
async def await_termination(self):
490506
"""

tests/test_flow.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ def test_offset_commit():
185185
event.shard_id = shard
186186
event.offset = offset
187187
controller.emit(event)
188-
controller.terminate()
189-
termination_result = controller.await_termination()
188+
termination_result = controller.terminate(wait=True)
190189
assert termination_result == 330
191190

192191
offsets = copy.copy(platform.offsets)
@@ -225,9 +224,8 @@ async def async_offset_commit():
225224
try:
226225
assert offsets == {("/", i): num_records_per_shard for i in range(num_shards)}
227226
finally:
228-
await controller.terminate()
227+
termination_result = await controller.terminate(wait=True)
229228

230-
termination_result = await controller.await_termination()
231229
assert termination_result == 330
232230

233231

0 commit comments

Comments
 (0)