Skip to content

Commit e07bde4

Browse files
authored
Surface RLJob errors to Ray, fail the job accordingly (#2311)
1 parent 3f9b0ec commit e07bde4

3 files changed

Lines changed: 21 additions & 6 deletions

File tree

lib/marin/src/marin/rl/rollout_worker.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -629,13 +629,26 @@ def run(self):
629629

630630
# For inflight weight updates, wait for first weights before generating rollouts
631631
if self.config.inflight_weight_updates:
632-
logger.info("Waiting for first weight transfer before starting inference...")
633-
while not self._first_weights_received.wait(timeout=10.0):
632+
max_wait_time = self.config.weight_transfer.max_weight_transfer_wait_time
633+
logger.info(
634+
"Waiting for first weight transfer before starting inference (timeout %.1fs)...",
635+
max_wait_time,
636+
)
637+
start_time = time.time()
638+
while True:
639+
if self._first_weights_received.wait(timeout=10.0):
640+
break
641+
634642
if not self._running:
635643
logger.info("Shutdown requested while waiting for first weights")
636644
self._shutdown_complete.set()
637645
return
638-
logger.info("Still waiting for first weight transfer...")
646+
647+
elapsed = time.time() - start_time
648+
if max_wait_time - elapsed <= 0:
649+
raise RuntimeError("Timed out waiting for initial weight transfer.")
650+
651+
logger.info("Still waiting for first weight transfer (elapsed: %.1fs)", elapsed)
639652
logger.info("First weights received, starting inference loop")
640653

641654
step = 0

lib/marin/src/marin/rl/train_worker.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -308,7 +308,8 @@ def _loss_function(model, batch, key):
308308
self.replay_buffer.set_current_step(-1)
309309

310310
# Wait for initial rollouts to ensure we have baseline measurements
311-
self._wait_for_initial_rollouts()
311+
if not self._wait_for_initial_rollouts():
312+
raise RuntimeError("Timed out waiting for initial rollouts; aborting training.")
312313

313314
self._configure_training_hooks(trainer)
314315

lib/marin/src/marin/rl/weight_transfer/arrow_flight.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -435,9 +435,10 @@ def serve_weights(self, weight_id: int, model: PyTree) -> None:
435435

436436
barrier_sync()
437437

438-
except Exception as e:
438+
except Exception:
439439
self.metrics.failed_transfers += 1
440-
logger.error(f"Failed to serve weights {weight_id} via Arrow Flight: {e}")
440+
logger.exception(f"Failed to serve weights {weight_id} via Arrow Flight")
441+
raise
441442

442443
def cleanup(self) -> None:
443444
"""Cleanup Flight server resources."""

0 commit comments

Comments
 (0)