Skip to content

Commit a0cd351

Browse files
author
anon
committed
fix(parallelism): mix sequence_id into parallel sub-app keys
1 parent a7e4489 commit a0cd351

2 files changed

Lines changed: 74 additions & 3 deletions

File tree

burr/core/parallelism.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def _create_task(key: str, action: Action, substate: State) -> SubGraphTask:
518518
def _tasks() -> Generator[SubGraphTask, None, None]:
519519
for i, action in enumerate(self.actions(state, context, inputs)):
520520
for j, substate in enumerate(self.states(state, context, inputs)):
521-
key = f"{i}-{j}" # this is a stable hash for now but will not handle caching
521+
key = f"{context.sequence_id}-{i}-{j}"
522522
yield _create_task(key, action, substate)
523523

524524
async def _atasks() -> AsyncGenerator[SubGraphTask, None]:
@@ -528,7 +528,7 @@ async def _atasks() -> AsyncGenerator[SubGraphTask, None]:
528528
states = await async_utils.arealize(state_generator)
529529
for i, action in enumerate(actions):
530530
for j, substate in enumerate(states):
531-
key = f"{i}-{j}"
531+
key = f"{context.sequence_id}-{i}-{j}"
532532
yield _create_task(key, action, substate)
533533

534534
return _atasks() if self.is_async() else _tasks()

tests/core/test_parallelism.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@
4545
_cascade_adapter,
4646
map_reduce_action,
4747
)
48-
from burr.core.persistence import BaseStateLoader, BaseStateSaver, PersistedStateData
48+
from burr.core.persistence import (
49+
BaseStateLoader,
50+
BaseStateSaver,
51+
InMemoryPersister,
52+
PersistedStateData,
53+
)
4954
from burr.tracking.base import SyncTrackingClient
5055
from burr.visibility import ActionSpan
5156

@@ -1227,3 +1232,69 @@ def reads(self) -> list[str]:
12271232
assert task.state_initializer is not None
12281233
assert task.tracker is not None
12291234
assert task.state_persister is task.state_initializer # This ensures they're the same
1235+
1236+
1237+
def test_map_states_no_stale_replay_on_repeated_invocation():
1238+
"""Regression test for #761.
1239+
1240+
When the parent app cascades a state initializer (via ``initialize_from``)
1241+
and a MapStates action is invoked more than once, each invocation must
1242+
spawn fresh sub-applications. Before the fix, sub-app ids were keyed only
1243+
on ``(i, j)``, so they collided across invocations and the cascaded
1244+
initializer replayed the prior call's persisted sub-state.
1245+
"""
1246+
1247+
@old_action(reads=["round"], writes=["output_number"])
1248+
def emit_round(state: State) -> State:
1249+
return state.update(output_number=state["round"])
1250+
1251+
@old_action(reads=["round"], writes=["round"])
1252+
def bump(state: State) -> State:
1253+
return state.update(round=state["round"] + 1)
1254+
1255+
class Fan(MapStates):
1256+
def action(self, state: State, inputs: Dict[str, Any]):
1257+
return emit_round
1258+
1259+
def states(
1260+
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
1261+
) -> Generator[State, None, None]:
1262+
for _ in range(3):
1263+
yield state
1264+
1265+
def reduce(self, state: State, states: Generator[State, None, None]) -> State:
1266+
new_state = state
1267+
for output_state in states:
1268+
new_state = new_state.append(outputs=output_state["output_number"])
1269+
return new_state
1270+
1271+
@property
1272+
def reads(self) -> list[str]:
1273+
return ["round"]
1274+
1275+
@property
1276+
def writes(self) -> list[str]:
1277+
return ["outputs"]
1278+
1279+
persister = InMemoryPersister()
1280+
app = (
1281+
ApplicationBuilder()
1282+
.with_actions(fan=Fan(), bump=bump)
1283+
.with_transitions(("fan", "bump"), ("bump", "fan"))
1284+
.with_state_persister(persister)
1285+
.with_identifiers(app_id="parent-app")
1286+
.initialize_from(
1287+
persister,
1288+
resume_at_next_action=True,
1289+
default_state={"round": 1, "outputs": []},
1290+
default_entrypoint="fan",
1291+
)
1292+
.build()
1293+
)
1294+
1295+
app.run(halt_after=["fan"]) # first fan invocation, round=1
1296+
app.run(halt_after=["fan"]) # bump runs, then second fan invocation, round=2
1297+
1298+
# Each fan invocation contributes 3 outputs. Fresh execution -> [1,1,1,2,2,2].
1299+
# Buggy behavior replays the first invocation's persisted sub-state -> [1,1,1,1,1,1].
1300+
assert list(app.state["outputs"]) == [1, 1, 1, 2, 2, 2]

0 commit comments

Comments
 (0)