Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion storey/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(
raise TypeError("Table can not be string if no context was provided to the step")
self._table = self.context.get_table(table)
self._table._set_aggregation_metadata(aggregates, use_windows_from_schema=use_windows_from_schema)
self._closeables = [self._table]

self._aggregates_metadata = aggregates

Expand Down Expand Up @@ -141,6 +140,7 @@ def f(element, features):

def _init(self):
super()._init()
self._closeables = [self._table]
self._events_in_batch = {}
self._emit_worker_running = False
self._terminate_worker = False
Expand Down
4 changes: 4 additions & 0 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
self._create_name_to_outlet = True

def _init(self):
self._closeables = []
self._termination_received = 0
self._termination_result = None
self._name_to_outlet = {}
Expand Down Expand Up @@ -654,6 +655,9 @@ def __init__(self, initial_state, fn, group_by_key=False, **kwargs):
self._state = self.context.get_table(self._state)
self._fn = fn
self._group_by_key = group_by_key

def _init(self):
    """Re-initialise per-run state and register the state object for closing.

    Delegates to the parent ``_init`` first, then records the state object
    in ``_closeables`` only when it exposes a ``close`` attribute.
    """
    super()._init()
    state = self._state
    # Only states that expose `close` are tracked as closeables.
    if hasattr(state, "close"):
        self._closeables = [state]

Expand Down
2 changes: 1 addition & 1 deletion storey/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,14 +1428,14 @@ def __init__(
if not self.context:
raise TypeError("Table can not be string if no context was provided to the step")
self._table = self.context.get_table(table)
self._closeables = [self._table]

self._field_extractor = lambda event_body, field_name: event_body.get(field_name)
self._write_missing_fields = False

def _init(self):
    """Re-initialise this step by running both base initialisers, then register the table.

    Calls ``Flow._init`` and ``_Writer._init`` explicitly and afterwards sets
    ``_closeables`` to a single-element list holding ``self._table``.
    """
    Flow._init(self)
    _Writer._init(self)
    # NOTE(review): this assignment presumably has to follow Flow._init, which
    # (in the flow.py hunk of this change) resets _closeables to [] — confirm.
    self._closeables = [self._table]

async def _handle_completed(self, event, response):
await self._do_downstream(event)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_aggregate_by_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3977,3 +3977,36 @@ def maybe_parse(t):
assert (
termination_result == expected
), f"actual did not match expected. \n actual: {termination_result} \n expected: {expected}"


def test_aggregate_by_key_reuse_resets_closeables():
    """Verify that rerunning a flow containing AggregateByKey keeps one closeable on the source.

    Regression test for ML-11518: if _closeables is not reset in _init(), entries
    pile up in the source's _closeables list with every additional run.
    """
    agg_table = Table("test", NoopDriver())
    aggregate_step = AggregateByKey(
        [FieldAggregator("sum_col1", "col1", ["sum"], FixedWindows(["1h"]))],
        agg_table,
        time_field="time",
    )

    source = SyncEmitSource()
    collector = Reduce([], lambda acc, x: acc + [x])

    source.to(aggregate_step).to(collector)

    event_time = datetime(2020, 7, 21, 12, 0, 0)

    # Run the very same flow three times; the closeable count must stay at one.
    for _ in range(3):
        controller = source.run()
        try:
            assert len(source._closeables) == 1
            # Push a handful of keyed events through the aggregator.
            for i in range(3):
                controller.emit({"col1": i, "time": event_time}, f"key{i}")
        finally:
            controller.terminate()
            controller.await_termination()
117 changes: 117 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5636,3 +5636,120 @@ def test_flow_reuse_with_cycle():

controller.terminate()
controller.await_termination()


def test_flow_reuse_resets_closeables():
    """Baseline: a reused flow without closeable steps keeps an empty _closeables list.

    Map and Reduce register no closeables, so the source's list is expected to be
    empty on every run, and each run must still produce the same reduced result.
    """
    source = SyncEmitSource()
    increment = Map(lambda x: x + 1)
    summation = Reduce(0, lambda acc, x: acc + x)

    source.to(increment).to(summation)

    for run_num in range(3):
        controller = source.run()
        try:
            assert len(source._closeables) == 0
            for value in range(5):
                controller.emit(value)
        finally:
            controller.terminate()
            result = controller.await_termination()
        # (0+1) + (1+1) + (2+1) + (3+1) + (4+1) == 15
        assert result == 15, f"Run {run_num}: Expected 15 but got {result}"


def test_map_with_state_reuse_resets_closeables():
    """Verify that a reused flow closes the MapWithState state exactly once per run.

    Regression test for ML-11518: if _closeables is not reset in _init(), entries
    accumulate between runs and close() ends up being invoked several times per run.
    """

    class CloseCounter:
        """Minimal state object that only counts how often close() is called."""

        def __init__(self):
            self.close_count = 0

        def close(self):
            self.close_count += 1

    state = CloseCounter()
    source = SyncEmitSource()
    stateful_step = MapWithState(state, lambda x, s: (x, s))
    collector = Reduce([], lambda acc, x: acc + [x])

    source.to(stateful_step).to(collector)

    for run_num in range(3):
        closes_before = state.close_count
        controller = source.run()
        try:
            controller.emit(1)
        finally:
            controller.terminate()
            controller.await_termination()

        # Each individual run is expected to trigger exactly one close() call.
        closes_this_run = state.close_count - closes_before
        assert closes_this_run == 1, (
            f"Run {run_num}: state.close() should be called exactly once, "
            f"but was called {closes_this_run} times (total: {state.close_count})"
        )


def test_map_with_state_no_closeables_without_close_method():
    """Verify that a state lacking close() is never registered as a closeable.

    Edge case: a plain dict has no close method, so the step's _closeables must
    stay empty on every run, exercising the hasattr(self._state, "close") guard.
    """
    initial_state = {"count": 0}

    def state_fn(event, state):
        # Bump the shared counter, then scale the event value by the new count.
        count = state["count"] + 1
        state["count"] = count
        return event["value"] * count, state

    source = SyncEmitSource()
    stateful_step = MapWithState(initial_state, state_fn, group_by_key=False)
    summation = Reduce(0, lambda acc, x: acc + x)

    source.to(stateful_step).to(summation)

    for _ in range(3):
        controller = source.run()
        try:
            # A dict exposes no close(), hence nothing should be registered.
            assert stateful_step._closeables == []
            for value in range(1, 4):
                controller.emit({"value": value})
        finally:
            controller.terminate()
            controller.await_termination()


def test_nosql_target_reuse_resets_closeables():
    """Verify that rerunning a flow ending in NoSqlTarget keeps one closeable on the source.

    Regression test for ML-11518: if _closeables is not reset in _init(), entries
    pile up in the source's _closeables list with every additional run.
    """
    target_table = Table("test_nosql", NoopDriver())

    source = SyncEmitSource()
    target_step = NoSqlTarget(target_table)

    source.to(target_step)

    # Run the very same flow three times; the closeable count must stay at one.
    for _ in range(3):
        controller = source.run()
        try:
            assert len(source._closeables) == 1
            # Write a few keyed events through the target.
            for i in range(3):
                keyed_event = Event(body={"col": i}, key=f"key{i}")
                controller.emit(keyed_event)
        finally:
            controller.terminate()
            controller.await_termination()