11import logging
22import multiprocessing
33import random
4- import threading
54import time
65from pathlib import Path
76from typing import List , Union
1312
1413from slime .backends .sglang_utils .sglang_engine import SGLangEngine
1514from slime .ray .rollout_data_source import RolloutDataSourceWithBuffer
15+ from slime .utils .health_monitor import RolloutHealthMonitor
1616from slime .utils .http_utils import find_available_port , get_host_info , init_http_client
1717from slime .utils .metric_checker import MetricChecker
1818from slime .utils .misc import load_function
@@ -63,13 +63,7 @@ def __init__(self, args, pg, wandb_run_id):
6363 self .rollout_engine_lock = Lock .options (num_cpus = 1 , num_gpus = 0 ).remote ()
6464
6565 self ._metric_checker = MetricChecker .maybe_create (args )
66-
67- # fault tolerance
68- self ._health_monitor_thread = None
69- self ._health_monitor_stop_event = None
70- self ._health_check_interval = args .rollout_health_check_interval
71- self ._health_check_timeout = args .rollout_health_check_timeout
72- self ._health_check_first_wait = args .rollout_health_check_first_wait
66+ self ._health_monitor = RolloutHealthMonitor (self , args )
7367
7468 def dispose (self ):
7569 if self ._metric_checker is not None :
@@ -83,7 +77,7 @@ def get_num_rollout_per_epoch(self):
8377 return len (self .data_source .dataset ) // self .args .rollout_batch_size
8478
8579 def generate (self , rollout_id ):
86- monitor_started = self ._start_health_monitor ()
80+ monitor_started = self ._health_monitor . start ()
8781 start_time = time .time ()
8882 try :
8983 data = self ._get_rollout_data (rollout_id = rollout_id )
@@ -93,7 +87,7 @@ def generate(self, rollout_id):
9387 return Box (ray .put (data ))
9488 finally :
9589 if monitor_started :
96- self ._stop_health_monitor ()
90+ self ._health_monitor . stop ()
9791 self .num_new_engines = init_rollout_engines (self .args , self .pg , self .all_rollout_engines )
9892 self .rollout_engines = self .all_rollout_engines [:: self .nodes_per_engine ]
9993
@@ -119,68 +113,6 @@ def offload(self):
def onload(self, tags: List[str] = None):
    """Ask every rollout engine to resume memory occupation.

    Issues ``resume_memory_occupation`` on each engine and returns the
    per-engine results of the ``.remote(...)`` calls (in engine order) so
    the caller can wait on them.
    """
    pending = []
    for engine in self.rollout_engines:
        pending.append(engine.resume_memory_occupation.remote(tags=tags))
    return pending
121115
122- def _start_health_monitor (self ) -> bool :
123- if not self .rollout_engines :
124- return False
125-
126- assert self ._health_monitor_thread is None , "Health monitor thread is already running."
127-
128- self ._health_monitor_stop_event = threading .Event ()
129- self ._health_monitor_thread = threading .Thread (
130- target = self ._health_monitor_loop ,
131- name = "RolloutHealthMonitor" ,
132- daemon = True ,
133- )
134- self ._health_monitor_thread .start ()
135- return True
136-
137- def _stop_health_monitor (self ) -> None :
138- if not self ._health_monitor_thread :
139- return
140-
141- assert self ._health_monitor_stop_event is not None
142- self ._health_monitor_stop_event .set ()
143- timeout = self ._health_check_timeout + self ._health_check_interval + 5
144- self ._health_monitor_thread .join (timeout = timeout )
145- if self ._health_monitor_thread .is_alive ():
146- logging .warning ("Rollout health monitor thread did not terminate within %.1fs" , timeout )
147-
148- self ._health_monitor_thread = None
149- self ._health_monitor_stop_event = None
150-
151- def _health_monitor_loop (self ) -> None :
152- assert self ._health_monitor_stop_event is not None
153- # TODO: need to be waiting for the large moe to be ready. this is hacky.
154- if self ._health_monitor_stop_event .wait (self ._health_check_first_wait ):
155- return
156- while not self ._health_monitor_stop_event .is_set ():
157- self ._run_health_checks ()
158- if self ._health_monitor_stop_event .wait (self ._health_check_interval ):
159- break
160-
161- def _run_health_checks (self ) -> None :
162- for rollout_engine_id , engine in enumerate (self .rollout_engines ):
163- if self ._health_monitor_stop_event is not None and self ._health_monitor_stop_event .is_set ():
164- break
165- self ._check_engine_health (rollout_engine_id , engine )
166-
167- def _check_engine_health (self , rollout_engine_id , engine ) -> None :
168- if engine is None :
169- return
170-
171- try :
172- ray .get (engine .health_generate .remote (timeout = self ._health_check_timeout ))
173- except Exception as e :
174- print (f"Health check timed out for rollout engine { rollout_engine_id } (ray timeout). Killing actor." )
175- for i in range (rollout_engine_id * self .nodes_per_engine , (rollout_engine_id + 1 ) * self .nodes_per_engine ):
176- engine = self .all_rollout_engines [i ]
177- try :
178- ray .kill (engine )
179- except Exception :
180- pass
181- self .all_rollout_engines [i ] = None
182- self .rollout_engines [rollout_engine_id ] = None
183-
184116 def _get_rollout_data (self , rollout_id ):
185117 if self .args .load_debug_rollout_data :
186118 data = torch .load (
0 commit comments