Skip to content

Commit dde524e

Browse files
authored
[feat] support fault tolerant for rollout engines (THUDM#405)
* [feat] support fault tolerant for rollout engines
* support fault tolerant for UpdateWeightFromDistributed
* bugfix
* bugfix
1 parent e9c677b commit dde524e

File tree

9 files changed

+216
-59
lines changed

9 files changed

+216
-59
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def init(self, args, role, wandb_run_id, with_ref: bool = False): # type: ignor
9999

100100
self.update_cpu_params_dict(self.weights["actor"])
101101

102-
self.connected = False
103102
self.weight_updator = (
104103
UpdateWeightFromTensor(self.args, self.model)
105104
if self.args.colocate
@@ -405,9 +404,10 @@ def update_weights(self): # type: ignore[override]
405404
if self.args.debug_train_only or self.args.debug_rollout_only:
406405
return
407406

408-
if not self.connected:
409-
self.connected = True
410-
rollout_engines, rollout_engine_lock = ray.get(self.rollout_manager.get_rollout_engines_and_lock.remote())
407+
rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
408+
self.rollout_manager.get_rollout_engines_and_lock.remote()
409+
)
410+
if num_new_engines > 0:
411411
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
412412
dist.barrier(group=get_gloo_group())
413413

slime/backends/megatron_utils/actor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,10 @@ def update_weights(self):
401401
if self.args.offload and hasattr(mpu, "reload_process_groups"):
402402
mpu.reload_process_groups()
403403

404-
if not self.connected:
405-
self.connected = True
406-
rollout_engines, rollout_engine_lock = ray.get(self.rollout_manager.get_rollout_engines_and_lock.remote())
404+
rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
405+
self.rollout_manager.get_rollout_engines_and_lock.remote()
406+
)
407+
if num_new_engines > 0:
407408
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
408409
dist.barrier(group=get_gloo_group())
409410

slime/backends/megatron_utils/update_weight_utils.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,17 @@ def __init__(self, args, model, weights, *, model_name, quantization_config, voc
305305
self.param_info_buckets = get_param_info_buckets(self.args, self.model)
306306
self.weight_version = 0
307307

308+
# create the group within megatron.
309+
for start_rank in range(0, dist.get_world_size(), self.args.rollout_num_gpus_per_engine):
310+
end_rank = start_rank + self.args.rollout_num_gpus_per_engine
311+
group_ranks = list(range(start_rank, end_rank))
312+
new_group = dist.new_group(ranks=group_ranks, backend="gloo")
313+
if dist.get_rank() in group_ranks:
314+
self._ipc_gather_group = new_group
315+
self._ipc_gather_src = start_rank
316+
317+
self._model_update_groups = None
318+
308319
def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
309320
self.rollout_engines = rollout_engines
310321
colocate_engine_nums = (
@@ -322,6 +333,11 @@ def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
322333
)
323334
self._group_name = "slime"
324335
if self._is_distributed_src_rank:
336+
if self._model_update_groups is not None:
337+
disconnect_rollout_engines_from_distributed(
338+
self.args, self._group_name, self._model_update_groups, self.distributed_rollout_engines
339+
)
340+
325341
self._model_update_groups = connect_rollout_engines_from_distributed(
326342
self.args, self._group_name, self.distributed_rollout_engines
327343
)
@@ -331,13 +347,7 @@ def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
331347
start_rank = i * self.args.rollout_num_gpus_per_engine
332348
end_rank = (i + 1) * self.args.rollout_num_gpus_per_engine
333349
group_ranks = list(range(start_rank, end_rank))
334-
new_group = dist.new_group(
335-
ranks=group_ranks,
336-
backend="gloo",
337-
)
338350
if dist.get_rank() in group_ranks:
339-
self._ipc_gather_src = start_rank
340-
self._ipc_gather_group = new_group
341351
self._ipc_engine = engine
342352

343353
@torch.no_grad()
@@ -496,6 +506,7 @@ def __init__(self, args, model, weights, *, model_name, quantization_config, voc
496506
self.vocab_size = vocab_size
497507
self.quantization_config = quantization_config
498508
self.weight_version = 0
509+
self._model_update_groups = None
499510

500511
def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
501512
self.rollout_engines = rollout_engines
@@ -512,6 +523,10 @@ def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
512523
self._group_name = f"slime-pp_{pp_rank}"
513524

514525
if self._is_pp_src_rank:
526+
if self._model_update_groups is not None:
527+
disconnect_rollout_engines_from_distributed(
528+
self.args, self._group_name, self._model_update_groups, self.rollout_engines
529+
)
515530
self._model_update_groups = connect_rollout_engines_from_distributed(
516531
self.args, self._group_name, rollout_engines
517532
)
@@ -670,6 +685,12 @@ def connect_rollout_engines_from_distributed(args, group_name, rollout_engines):
670685
return model_update_groups
671686

672687

688+
def disconnect_rollout_engines_from_distributed(args, group_name, model_update_groups, rollout_engines):
689+
refs = [engine.destroy_weights_update_group.remote(group_name) for engine in rollout_engines]
690+
dist.destroy_process_group(model_update_groups)
691+
ray.get(refs)
692+
693+
673694
def update_weights_from_distributed(args, group_name, group, weight_version, rollout_engines, converted_named_tensors):
674695
refs = [
675696
engine.update_weights_from_distributed.remote(

slime/backends/sglang_utils/sglang_engine.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,28 @@ def _make_request(self, endpoint: str, payload: Optional[dict] = None):
149149
response.raise_for_status()
150150
return response.json()
151151

152+
def health_generate(self, timeout: float = 5.0) -> bool:
153+
"""Run /health_generate on the underlying SGLang HTTP server.
154+
155+
Args:
156+
timeout: Timeout for the health request in seconds.
157+
158+
Returns:
159+
True if the server responds with HTTP 200.
160+
161+
Raises:
162+
requests.RequestException: If the request fails for any reason, including timeout.
163+
"""
164+
if self.node_rank != 0:
165+
return True
166+
167+
response = requests.get(
168+
f"http://{self.server_args.host}:{self.server_args.port}/health_generate",
169+
timeout=timeout,
170+
)
171+
response.raise_for_status()
172+
return True
173+
152174
def update_weights_from_tensor(
153175
self,
154176
serialized_named_tensors: List[str],
@@ -179,7 +201,7 @@ def flush_cache(self):
179201
if self.node_rank != 0:
180202
return
181203
# flush cache will not return status_code 200 when there are pending requests
182-
while True:
204+
for _ in range(60):
183205
try:
184206
response = requests.get(f"http://{self.server_args.host}:{self.server_args.port}/flush_cache")
185207
if response.status_code == 200:
@@ -188,7 +210,10 @@ def flush_cache(self):
188210
raise e
189211
except Exception as e:
190212
print(f"Error flushing cache: {e}")
213+
time.sleep(1)
191214
continue
215+
else:
216+
raise TimeoutError("Timeout while flushing cache.")
192217

193218
def shutdown(self):
194219
requests.post(
@@ -230,6 +255,18 @@ def init_weights_update_group(self, master_address, master_port, rank_offset, wo
230255
},
231256
)
232257

258+
def destroy_weights_update_group(self, group_name):
259+
try:
260+
return self._make_request(
261+
"destroy_weights_update_group",
262+
{
263+
"group_name": group_name,
264+
},
265+
)
266+
except:
267+
# catch the case there the engine is just created and does not have the group.
268+
pass
269+
233270
def update_weights_from_distributed(
234271
self, names, dtypes, shapes, group_name, flush_cache=False, weight_version: Optional[str] = None
235272
):

slime/backends/xtuner_utils/actor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,10 @@ def update_weights(self): # type: ignore[override]
260260
if self.args.debug_train_only or self.args.debug_rollout_only:
261261
return
262262

263-
if not self.connected:
264-
self.connected = True
265-
rollout_engines, rollout_engine_lock = ray.get(self.rollout_manager.get_rollout_engines_and_lock.remote())
263+
rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
264+
self.rollout_manager.get_rollout_engines_and_lock.remote()
265+
)
266+
if num_new_engines > 0:
266267
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
267268
dist.barrier(group=get_gloo_group())
268269

slime/ray/placement_group.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,6 @@ def create_rollout_manager(args, pg, wandb_run_id):
175175
if args.rollout_global_dataset:
176176
ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))
177177

178-
# TODO: extract this to single function
179-
rollout_engines, rollout_engine_lock = ray.get(rollout_manager.get_rollout_engines_and_lock.remote())
180-
181178
# calculate num_rollout from num_epoch
182179
num_rollout_per_epoch = None
183180
if args.num_rollout is None:

0 commit comments

Comments
 (0)