Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions steptronoss/core/generators/flow_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,18 @@ def _generation_worker(self):

def flow_callback(genable: GenableItem, generated: dict):
nonlocal active_counter
if isinstance(generated, Exception):
if not self.cfg.genable_allow_errors:
raise generated
else: # ignore this
generated = []
try:
if isinstance(generated, Exception):
if not self.cfg.genable_allow_errors:
raise generated
else: # ignore this
generated = []

with self.flow.lock:
self.flow["pre_train"].put(generated)
self.flow["pre_gen"].ack(genable.meta.pop("ack_pre_id"))
active_counter -= 1
with self.flow.lock:
self.flow["pre_train"].put(generated)
self.flow["pre_gen"].ack(genable.meta.pop("ack_pre_id"))
finally:
active_counter -= 1

while 1:
ack_id, data = self.flow["pre_gen"].get()
Expand Down Expand Up @@ -430,16 +432,22 @@ def _control_worker(self):

def _generation_worker(self):
def flow_callback(genable: GenableItem, generated: dict):
if isinstance(generated, Exception):
if not self.cfg.genable_allow_errors:
raise generated
generated = []
try:
if isinstance(generated, Exception):
if not self.cfg.genable_allow_errors:
raise generated
generated = []

with self._cv:
self.flow["pre_train"].put(generated)
self.flow["pre_gen"].ack(genable.meta.pop("ack_pre_id"))
self.running_genables.pop(genable.meta.pop("running_ack_id"), None)
self._cv.notify_all()
with self._cv:
self.flow["pre_train"].put(generated)
self.flow["pre_gen"].ack(genable.meta.pop("ack_pre_id"))
self.running_genables.pop(genable.meta.pop("running_ack_id"), None)
self._cv.notify_all()
except Exception:
with self._cv:
self.running_genables.pop(genable.meta.get("running_ack_id"), None)
self._cv.notify_all()
raise

while True:
with self._cv:
Expand Down
129 changes: 129 additions & 0 deletions tests/generators/test_flow_controller_scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,35 @@ def submit_with_callback(self, genable, for_train=False, callback=None, task_id=
self._queue.put((genable, callback))


class ErrorGenerationController(FakeGenerationController):
"""Generation controller that returns errors for specific item IDs."""

def __init__(self, error_ids: set[int], **kwargs):
super().__init__(**kwargs)
self.error_ids = error_ids

def _worker(self):
while True:
genable, callback = self._queue.get()
time.sleep(genable.delay_s)
if genable.item_id in self.error_ids:
callback(genable, RuntimeError(f"generation failed for {genable.item_id}"))
else:
callback(
genable,
[
EnvTrajectory(
trajectory=[genable.item_id],
logprobs=[0.0],
is_gen_mask=[True],
meta={"item_id": genable.item_id},
stop_type=0,
raw_reward=1.0,
)
],
)


class DummyVLLMClient:
def wait_for_server(self):
return None
Expand Down Expand Up @@ -258,6 +287,106 @@ def test_one_step_off_runtime_matches_simulator(monkeypatch):
assert runtime_batches == [list(snapshot.yielded_ids) for snapshot in simulated.yield_snapshots]


def test_fully_async_error_with_allow_errors_does_not_hang(monkeypatch):
"""When genable_allow_errors=True and a generation fails, the pipeline
should skip the failed item and continue without hanging."""
cfg = build_cfg("fully-async")
cfg.prompt_per_iter = 2
cfg.max_untrained_prompts = 4
cfg.max_staleness = 2
cfg.max_concurrent_genables = 2
cfg.genable_allow_errors = True
cfg.vllm_cfg.build_cli = lambda: DummyVLLMClient()
cfg.vllm_cfg.deploy_training_model = lambda model: None

error_ids = {1}
monkeypatch.setattr(
"steptronoss.core.generators.flow_controller.GenerationController",
lambda **kwargs: ErrorGenerationController(error_ids=error_ids, **kwargs),
)

controller = FullyAsyncFlowController(flow_cfg=cfg)
controller.start(
dataloader=FakeNextable([FakeTrainableItem(i, 0.01) for i in range(4)]),
model=[],
)

# Should complete within timeout — no hang
batch0 = controller.get_train_samples()
controller.weight_dumped()
batch1 = controller.get_train_samples()
controller.weight_dumped()

all_ids = [traj.meta["item_id"] for traj in batch0 + batch1]
# Item 1 failed, so its trajectories are empty — remaining items should appear
assert 1 not in all_ids
assert 0 in all_ids


def test_fully_async_error_cleans_up_running_genables(monkeypatch):
"""When a generation error occurs, running_genables should be cleaned up
so staleness checks don't block forever."""
cfg = build_cfg("fully-async")
cfg.prompt_per_iter = 1
cfg.max_untrained_prompts = 4
cfg.max_staleness = 1
cfg.max_concurrent_genables = 2
cfg.genable_allow_errors = True
cfg.vllm_cfg.build_cli = lambda: DummyVLLMClient()
cfg.vllm_cfg.deploy_training_model = lambda model: None

error_ids = {0}
monkeypatch.setattr(
"steptronoss.core.generators.flow_controller.GenerationController",
lambda **kwargs: ErrorGenerationController(error_ids=error_ids, **kwargs),
)

controller = FullyAsyncFlowController(flow_cfg=cfg)
controller.start(
dataloader=FakeNextable([FakeTrainableItem(i, 0.01) for i in range(3)]),
model=[],
)

# Collect multiple batches — should not hang due to stale running_genables
for _ in range(3):
controller.get_train_samples()
controller.weight_dumped()

# All running_genables should have been cleaned up
assert len(controller.running_genables) == 0


def test_simple_controller_error_with_allow_errors_does_not_hang(monkeypatch):
"""SimpleFlowController: when genable_allow_errors=True and a generation
fails, active_counter should still decrement so the pipeline doesn't hang."""
cfg = build_cfg("one-step-off")
cfg.prompt_per_iter = 2
cfg.max_concurrent_genables = 2
cfg.genable_allow_errors = True
cfg.vllm_cfg.build_cli = lambda: DummyVLLMClient()
cfg.vllm_cfg.deploy_training_model = lambda model: None

error_ids = {1}
monkeypatch.setattr(
"steptronoss.core.generators.flow_controller.GenerationController",
lambda **kwargs: ErrorGenerationController(error_ids=error_ids, **kwargs),
)

controller = cfg.build_flow_controller()
controller.start(
dataloader=FakeNextable([FakeTrainableItem(i, 0.01) for i in range(4)]),
model=[],
)

# Should complete within timeout — no hang from active_counter leak
batch0 = controller.get_train_samples()
controller.weight_dumped()

all_ids = [traj.meta["item_id"] for traj in batch0]
assert 1 not in all_ids
assert 0 in all_ids


def test_fully_async_runtime_matches_simulator(monkeypatch):
cfg = build_cfg("fully-async")
cfg.prompt_per_iter = 2
Expand Down
Loading