Skip to content

Commit 6d01709

Browse files
authored
Tiny refactor and extract RolloutHealthMonitor (#465)
1 parent add8e8c commit 6d01709

File tree

2 files changed

+88
-72
lines changed

2 files changed

+88
-72
lines changed

slime/ray/rollout.py

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22
import multiprocessing
33
import random
4-
import threading
54
import time
65
from pathlib import Path
76
from typing import List, Union
@@ -13,6 +12,7 @@
1312

1413
from slime.backends.sglang_utils.sglang_engine import SGLangEngine
1514
from slime.ray.rollout_data_source import RolloutDataSourceWithBuffer
15+
from slime.utils.health_monitor import RolloutHealthMonitor
1616
from slime.utils.http_utils import find_available_port, get_host_info, init_http_client
1717
from slime.utils.metric_checker import MetricChecker
1818
from slime.utils.misc import load_function
@@ -63,13 +63,7 @@ def __init__(self, args, pg, wandb_run_id):
6363
self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote()
6464

6565
self._metric_checker = MetricChecker.maybe_create(args)
66-
67-
# fault tolerance
68-
self._health_monitor_thread = None
69-
self._health_monitor_stop_event = None
70-
self._health_check_interval = args.rollout_health_check_interval
71-
self._health_check_timeout = args.rollout_health_check_timeout
72-
self._health_check_first_wait = args.rollout_health_check_first_wait
66+
self._health_monitor = RolloutHealthMonitor(self, args)
7367

7468
def dispose(self):
7569
if self._metric_checker is not None:
@@ -83,7 +77,7 @@ def get_num_rollout_per_epoch(self):
8377
return len(self.data_source.dataset) // self.args.rollout_batch_size
8478

8579
def generate(self, rollout_id):
86-
monitor_started = self._start_health_monitor()
80+
monitor_started = self._health_monitor.start()
8781
start_time = time.time()
8882
try:
8983
data = self._get_rollout_data(rollout_id=rollout_id)
@@ -93,7 +87,7 @@ def generate(self, rollout_id):
9387
return Box(ray.put(data))
9488
finally:
9589
if monitor_started:
96-
self._stop_health_monitor()
90+
self._health_monitor.stop()
9791
self.num_new_engines = init_rollout_engines(self.args, self.pg, self.all_rollout_engines)
9892
self.rollout_engines = self.all_rollout_engines[:: self.nodes_per_engine]
9993

@@ -119,68 +113,6 @@ def offload(self):
119113
def onload(self, tags: List[str] = None):
120114
return [engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]
121115

122-
def _start_health_monitor(self) -> bool:
123-
if not self.rollout_engines:
124-
return False
125-
126-
assert self._health_monitor_thread is None, "Health monitor thread is already running."
127-
128-
self._health_monitor_stop_event = threading.Event()
129-
self._health_monitor_thread = threading.Thread(
130-
target=self._health_monitor_loop,
131-
name="RolloutHealthMonitor",
132-
daemon=True,
133-
)
134-
self._health_monitor_thread.start()
135-
return True
136-
137-
def _stop_health_monitor(self) -> None:
138-
if not self._health_monitor_thread:
139-
return
140-
141-
assert self._health_monitor_stop_event is not None
142-
self._health_monitor_stop_event.set()
143-
timeout = self._health_check_timeout + self._health_check_interval + 5
144-
self._health_monitor_thread.join(timeout=timeout)
145-
if self._health_monitor_thread.is_alive():
146-
logging.warning("Rollout health monitor thread did not terminate within %.1fs", timeout)
147-
148-
self._health_monitor_thread = None
149-
self._health_monitor_stop_event = None
150-
151-
def _health_monitor_loop(self) -> None:
152-
assert self._health_monitor_stop_event is not None
153-
# TODO: need to be waiting for the large moe to be ready. this is hacky.
154-
if self._health_monitor_stop_event.wait(self._health_check_first_wait):
155-
return
156-
while not self._health_monitor_stop_event.is_set():
157-
self._run_health_checks()
158-
if self._health_monitor_stop_event.wait(self._health_check_interval):
159-
break
160-
161-
def _run_health_checks(self) -> None:
162-
for rollout_engine_id, engine in enumerate(self.rollout_engines):
163-
if self._health_monitor_stop_event is not None and self._health_monitor_stop_event.is_set():
164-
break
165-
self._check_engine_health(rollout_engine_id, engine)
166-
167-
def _check_engine_health(self, rollout_engine_id, engine) -> None:
168-
if engine is None:
169-
return
170-
171-
try:
172-
ray.get(engine.health_generate.remote(timeout=self._health_check_timeout))
173-
except Exception as e:
174-
print(f"Health check timed out for rollout engine {rollout_engine_id} (ray timeout). Killing actor.")
175-
for i in range(rollout_engine_id * self.nodes_per_engine, (rollout_engine_id + 1) * self.nodes_per_engine):
176-
engine = self.all_rollout_engines[i]
177-
try:
178-
ray.kill(engine)
179-
except Exception:
180-
pass
181-
self.all_rollout_engines[i] = None
182-
self.rollout_engines[rollout_engine_id] = None
183-
184116
def _get_rollout_data(self, rollout_id):
185117
if self.args.load_debug_rollout_data:
186118
data = torch.load(

slime/utils/health_monitor.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import logging
2+
import threading
3+
4+
import ray
5+
6+
7+
class RolloutHealthMonitor:
    """Background watchdog that periodically health-checks rollout engines.

    Runs a daemon thread that calls each engine's ``health_generate`` and,
    when a check fails, kills that engine's Ray actors and marks the slots
    ``None`` in the owning rollout manager so they can be re-initialized
    by the caller later.
    """

    def __init__(self, rollout_manager, args):
        # TODO may remove this dependency after refactoring
        self._rollout_manager = rollout_manager

        # Thread/event are created lazily in start() and torn down in stop().
        self._thread = None
        self._stop_event = None
        self._check_interval = args.rollout_health_check_interval
        self._check_timeout = args.rollout_health_check_timeout
        self._check_first_wait = args.rollout_health_check_first_wait

    def start(self) -> bool:
        """Start the monitor thread.

        Returns:
            True when the thread was started, False when there are no
            rollout engines to watch (nothing to do).
        """
        if not self._rollout_manager.rollout_engines:
            return False

        assert self._thread is None, "Health monitor thread is already running."

        self._stop_event = threading.Event()
        self._thread = threading.Thread(
            target=self._health_monitor_loop,
            name="RolloutHealthMonitor",
            daemon=True,
        )
        self._thread.start()
        return True

    def stop(self) -> None:
        """Signal the monitor thread to exit and wait for it with a bounded join."""
        if not self._thread:
            return

        assert self._stop_event is not None
        self._stop_event.set()
        # Worst case the thread is mid-check: one health-check timeout plus
        # one interval wait, with a little slack.
        timeout = self._check_timeout + self._check_interval + 5
        self._thread.join(timeout=timeout)
        if self._thread.is_alive():
            logging.warning("Rollout health monitor thread did not terminate within %.1fs", timeout)

        self._thread = None
        self._stop_event = None

    def _health_monitor_loop(self) -> None:
        """Thread body: initial grace period, then check engines every interval."""
        assert self._stop_event is not None
        # TODO: need to be waiting for the large moe to be ready. this is hacky.
        # Event.wait returns True only when the stop event was set, i.e. we
        # were asked to shut down during the initial grace period.
        if self._stop_event.wait(self._check_first_wait):
            return
        while not self._stop_event.is_set():
            self._run_health_checks()
            if self._stop_event.wait(self._check_interval):
                break

    def _run_health_checks(self) -> None:
        # Check engines one at a time; bail out early if a stop was requested.
        for rollout_engine_id, engine in enumerate(self._rollout_manager.rollout_engines):
            if self._stop_event is not None and self._stop_event.is_set():
                break
            self._check_engine_health(rollout_engine_id, engine)

    def _check_engine_health(self, rollout_engine_id, engine) -> None:
        """Probe one engine; on any failure, kill its actors and clear its slot."""
        # Engines already killed by a previous failed check are left as None.
        if engine is None:
            return

        try:
            # NOTE(review): the timeout is forwarded to health_generate itself,
            # not to ray.get — a fully hung actor may block here; confirm the
            # engine enforces the timeout internally.
            ray.get(engine.health_generate.remote(timeout=self._check_timeout))
        except Exception:
            # Any failure (timeout or otherwise) is treated as a dead engine.
            # logging.exception records the traceback for diagnosis.
            logging.exception(
                "Health check failed for rollout engine %s. Killing actor.", rollout_engine_id
            )
            self._kill_engine(rollout_engine_id=rollout_engine_id)

    def _kill_engine(self, rollout_engine_id: int):
        """Kill every node-level actor backing one logical engine and null the slots."""
        # One logical engine spans nodes_per_engine consecutive actor slots.
        for i in range(
            rollout_engine_id * self._rollout_manager.nodes_per_engine,
            (rollout_engine_id + 1) * self._rollout_manager.nodes_per_engine,
        ):
            engine = self._rollout_manager.all_rollout_engines[i]
            try:
                ray.kill(engine)
            except Exception as e:
                # Best-effort: the actor may already be dead; record and continue.
                logging.warning("Failed to kill rollout engine actor %s, skipping (error: %s)", i, e)
            self._rollout_manager.all_rollout_engines[i] = None
        self._rollout_manager.rollout_engines[rollout_engine_id] = None

0 commit comments

Comments
 (0)