Skip to content

Commit d71bf78

Browse files
alxtkr77Alex Toker
andauthored
Reset _closeables in _init for flow reuse (#597)
* Reset _closeables in _init for flow reuse * Add unit tests for _closeables reset in _init for flow reuse Tests verify that _closeables is properly reset when flows are reused: - test_flow_reuse_resets_closeables: Base Flow class resets to [] - test_aggregate_by_key_reuse_resets_closeables: AggregateByKey sets [table] - test_map_with_state_reuse_resets_closeables: MapWithState sets [state] for Table - test_map_with_state_no_closeables_without_close_method: No closeables for dict state - test_nosql_target_reuse_resets_closeables: NoSqlTarget sets [table] * Fix unit tests to properly detect ML-11518 bug and add clarifying comments Updated tests to check source._closeables instead of downstream step's _closeables, as the bug manifests in upstream steps. Added comments clarifying which tests detect the ML-11518 regression (3 tests) vs baseline/edge case tests (2 tests). Verified: - 3 tests FAIL on commit 74fec1a (before fix) - correctly detecting bug - All 5 tests PASS on commit 321dfec (with fix) and current branch * Simplify ML-11518 regression tests based on code review - Replace dynamic closeables tracking with direct assertions - Add try/finally blocks to prevent test hangs on failure - Remove PR-specific docstring notes that don't apply post-merge - Move AggregateByKey test to test_aggregate_by_key.py --------- Co-authored-by: Alex Toker <alext@mckinsey.com>
1 parent 0bfaa08 commit d71bf78

File tree

5 files changed

+156
-2
lines changed

5 files changed

+156
-2
lines changed

storey/aggregations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def __init__(
8989
raise TypeError("Table can not be string if no context was provided to the step")
9090
self._table = self.context.get_table(table)
9191
self._table._set_aggregation_metadata(aggregates, use_windows_from_schema=use_windows_from_schema)
92-
self._closeables = [self._table]
9392

9493
self._aggregates_metadata = aggregates
9594

@@ -141,6 +140,7 @@ def f(element, features):
141140

142141
def _init(self):
143142
super()._init()
143+
self._closeables = [self._table]
144144
self._events_in_batch = {}
145145
self._emit_worker_running = False
146146
self._terminate_worker = False

storey/flow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(
8080
self._create_name_to_outlet = True
8181

8282
def _init(self):
83+
self._closeables = []
8384
self._termination_received = 0
8485
self._termination_result = None
8586
self._name_to_outlet = {}
@@ -654,6 +655,9 @@ def __init__(self, initial_state, fn, group_by_key=False, **kwargs):
654655
self._state = self.context.get_table(self._state)
655656
self._fn = fn
656657
self._group_by_key = group_by_key
658+
659+
def _init(self):
660+
super()._init()
657661
if hasattr(self._state, "close"):
658662
self._closeables = [self._state]
659663

storey/targets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1428,14 +1428,14 @@ def __init__(
14281428
if not self.context:
14291429
raise TypeError("Table can not be string if no context was provided to the step")
14301430
self._table = self.context.get_table(table)
1431-
self._closeables = [self._table]
14321431

14331432
self._field_extractor = lambda event_body, field_name: event_body.get(field_name)
14341433
self._write_missing_fields = False
14351434

14361435
def _init(self):
14371436
Flow._init(self)
14381437
_Writer._init(self)
1438+
self._closeables = [self._table]
14391439

14401440
async def _handle_completed(self, event, response):
14411441
await self._do_downstream(event)

tests/test_aggregate_by_key.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3977,3 +3977,36 @@ def maybe_parse(t):
39773977
assert (
39783978
termination_result == expected
39793979
), f"actual did not match expected. \n actual: {termination_result} \n expected: {expected}"
3980+
3981+
3982+
def test_aggregate_by_key_reuse_resets_closeables():
3983+
"""Test that AggregateByKey properly resets _closeables on flow reuse.
3984+
3985+
Regression test for ML-11518: Without resetting _closeables in _init(), closeables
3986+
accumulate in the source's _closeables list across runs.
3987+
"""
3988+
table = Table("test", NoopDriver())
3989+
aggregator = AggregateByKey(
3990+
[FieldAggregator("sum_col1", "col1", ["sum"], FixedWindows(["1h"]))],
3991+
table,
3992+
time_field="time",
3993+
)
3994+
3995+
source = SyncEmitSource()
3996+
reduce_step = Reduce([], lambda acc, x: acc + [x])
3997+
3998+
source.to(aggregator).to(reduce_step)
3999+
4000+
base_time = datetime(2020, 7, 21, 12, 0, 0)
4001+
4002+
for _ in range(3):
4003+
controller = source.run()
4004+
try:
4005+
assert len(source._closeables) == 1
4006+
# Emit some data
4007+
for i in range(3):
4008+
data = {"col1": i, "time": base_time}
4009+
controller.emit(data, f"key{i}")
4010+
finally:
4011+
controller.terminate()
4012+
controller.await_termination()

tests/test_flow.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5636,3 +5636,120 @@ def test_flow_reuse_with_cycle():
56365636

56375637
controller.terminate()
56385638
controller.await_termination()
5639+
5640+
5641+
def test_flow_reuse_resets_closeables():
5642+
"""Test that _closeables is properly reset when a flow is reused.
5643+
5644+
Baseline test for flow reuse without closeables. This test uses steps (Map, Reduce)
5645+
that don't create closeables, so it verifies the baseline behavior.
5646+
"""
5647+
source = SyncEmitSource()
5648+
map_step = Map(lambda x: x + 1)
5649+
reduce_step = Reduce(0, lambda acc, x: acc + x)
5650+
5651+
source.to(map_step).to(reduce_step)
5652+
5653+
for run_num in range(3):
5654+
controller = source.run()
5655+
try:
5656+
assert len(source._closeables) == 0
5657+
for i in range(5):
5658+
controller.emit(i)
5659+
finally:
5660+
controller.terminate()
5661+
result = controller.await_termination()
5662+
assert result == 15, f"Run {run_num}: Expected 15 but got {result}"
5663+
5664+
5665+
def test_map_with_state_reuse_resets_closeables():
5666+
"""Test that MapWithState properly resets _closeables on flow reuse.
5667+
5668+
Regression test for ML-11518: Without resetting _closeables in _init(), closeables
5669+
accumulate across runs, causing close() to be called multiple times per run.
5670+
"""
5671+
5672+
class CloseCounter:
5673+
def __init__(self):
5674+
self.close_count = 0
5675+
5676+
def close(self):
5677+
self.close_count += 1
5678+
5679+
state = CloseCounter()
5680+
source = SyncEmitSource()
5681+
map_with_state = MapWithState(state, lambda x, s: (x, s))
5682+
reduce_step = Reduce([], lambda acc, x: acc + [x])
5683+
5684+
source.to(map_with_state).to(reduce_step)
5685+
5686+
for run_num in range(3):
5687+
initial_close_count = state.close_count
5688+
controller = source.run()
5689+
try:
5690+
controller.emit(1)
5691+
finally:
5692+
controller.terminate()
5693+
controller.await_termination()
5694+
5695+
# close() should be called exactly once per run
5696+
closes_this_run = state.close_count - initial_close_count
5697+
assert closes_this_run == 1, (
5698+
f"Run {run_num}: state.close() should be called exactly once, "
5699+
f"but was called {closes_this_run} times (total: {state.close_count})"
5700+
)
5701+
5702+
5703+
def test_map_with_state_no_closeables_without_close_method():
5704+
"""Test that MapWithState doesn't add state to _closeables if it has no close method.
5705+
5706+
Edge case test: When using a plain dict as state (no close method), _closeables
5707+
should remain empty. This verifies the hasattr(self._state, "close") check works.
5708+
"""
5709+
initial_state = {"count": 0}
5710+
5711+
def state_fn(event, state):
5712+
state["count"] += 1
5713+
return event["value"] * state["count"], state
5714+
5715+
source = SyncEmitSource()
5716+
map_with_state = MapWithState(initial_state, state_fn, group_by_key=False)
5717+
reduce_step = Reduce(0, lambda acc, x: acc + x)
5718+
5719+
source.to(map_with_state).to(reduce_step)
5720+
5721+
for _ in range(3):
5722+
controller = source.run()
5723+
try:
5724+
# Dict has no close method, so _closeables should be empty
5725+
assert map_with_state._closeables == []
5726+
for i in range(3):
5727+
controller.emit({"value": i + 1})
5728+
finally:
5729+
controller.terminate()
5730+
controller.await_termination()
5731+
5732+
5733+
def test_nosql_target_reuse_resets_closeables():
5734+
"""Test that NoSqlTarget properly resets _closeables on flow reuse.
5735+
5736+
Regression test for ML-11518: Without resetting _closeables in _init(), closeables
5737+
accumulate in the source's _closeables list across runs.
5738+
"""
5739+
table = Table("test_nosql", NoopDriver())
5740+
5741+
source = SyncEmitSource()
5742+
nosql_target = NoSqlTarget(table)
5743+
5744+
source.to(nosql_target)
5745+
5746+
for _ in range(3):
5747+
controller = source.run()
5748+
try:
5749+
assert len(source._closeables) == 1
5750+
# Emit some data with keys
5751+
for i in range(3):
5752+
controller.emit(Event(body={"col": i}, key=f"key{i}"))
5753+
finally:
5754+
controller.terminate()
5755+
controller.await_termination()

0 commit comments

Comments
 (0)