Commit c9de8c5

Add clear_num_new_engines and some code cleanup (THUDM#1349)

1 parent dacd086 · commit c9de8c5
4 files changed: +34, -31 lines

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 0 deletions

@@ -818,6 +818,8 @@ def update_weights(self) -> None: # type: ignore[override]
         if num_new_engines > 0:
             self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
             dist.barrier(group=get_gloo_group())
+            if dist.get_rank() == 0:
+                ray.get(self.rollout_manager.clear_num_new_engines.remote())
 
         self.weight_updater.update_weights()
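For context, num_new_engines is the counter maintained by the rollout manager and returned by recover_rollout_engines (see slime/ray/rollout.py below). How the actor obtains that value is not part of this hunk; a rough, non-authoritative sketch of the surrounding flow, with the recover call site being an assumption, might read:

# Hedged sketch of the actor-side flow; only the rank-0 clear_num_new_engines call
# is taken verbatim from this commit, the rest is a plausible reading of the hunk.
rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
    self.rollout_manager.recover_rollout_engines.remote()  # assumed call site
)
if num_new_engines > 0:
    # Reconnect the weight updater to the restarted engines, then synchronize all ranks.
    self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
    dist.barrier(group=get_gloo_group())
    # The rollout manager is a single Ray actor, so one rank resetting the counter is enough.
    if dist.get_rank() == 0:
        ray.get(self.rollout_manager.clear_num_new_engines.remote())

self.weight_updater.update_weights()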

slime/backends/megatron_utils/actor.py

Lines changed: 2 additions & 0 deletions

@@ -524,6 +524,8 @@ def update_weights(self) -> None:
         if num_new_engines > 0:
             self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
             dist.barrier(group=get_gloo_group())
+            if dist.get_rank() == 0:
+                ray.get(self.rollout_manager.clear_num_new_engines.remote())
 
         with torch_memory_saver.disable() if self.args.offload_train else nullcontext():
             print_memory("before update_weights")

slime/ray/rollout.py

Lines changed: 11 additions & 1 deletion

@@ -10,7 +10,7 @@
 import ray
 import torch
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 
 from slime.backends.sglang_utils.sglang_engine import SGLangEngine
 from slime.rollout.base_types import call_rollout_fn

@@ -170,6 +170,12 @@ def onload(self, tags: list[str] | None = None):
             ]
         )
 
+    def onload_weights(self):
+        self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS])
+
+    def onload_kv(self):
+        self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH])
+
     def recover_rollout_engines(self):
         """Restart any dead rollout engines and update num_new_engines for update_weights detection."""
         self.health_monitoring_pause()

@@ -187,6 +193,10 @@ def recover_rollout_engines(self):
 
         return self.rollout_engines, self.rollout_engine_lock, self.num_new_engines
 
+    def clear_num_new_engines(self):
+        # when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights
+        self.num_new_engines = 0
+
     def health_monitoring_pause(self) -> None:
         if self._health_monitor is not None:
             self._health_monitor.pause()
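The two onload wrappers exist so that callers no longer import the sglang memory-type constants themselves. From the driver they bracket a weight update, as the train.py changes below show; a minimal usage sketch, assuming rollout_manager is the Ray actor handle returned by create_rollout_manager:

if args.offload_rollout:
    # Load only the model weights back onto the GPU before syncing weights from training.
    ray.get(rollout_manager.onload_weights.remote())
actor_model.update_weights()
if args.offload_rollout:
    # Then restore KV cache and CUDA graph memory so generation can resume.
    ray.get(rollout_manager.onload_kv.remote())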

train.py

Lines changed: 19 additions & 30 deletions

@@ -1,10 +1,4 @@
 import ray
-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
-
-try:
-    from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
-except ImportError:
-    GPU_MEMORY_TYPE_CUDA_GRAPH = None
 
 from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
 from slime.utils.arguments import parse_args

@@ -27,7 +21,7 @@ def train(args):
     actor_model, critic_model = create_training_models(args, pgs, rollout_manager)
 
     if args.offload_rollout:
-        ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
+        ray.get(rollout_manager.onload_weights.remote())
 
     # always update weight first so that sglang has the loaded weights from training.
     actor_model.update_weights()

@@ -36,9 +30,7 @@ def train(args):
     ray.get(rollout_manager.check_weights.remote(action="compare"))
 
     if args.offload_rollout:
-        if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
-            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
-        ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
+        ray.get(rollout_manager.onload_kv.remote())
 
     # special case for eval-only
     if args.num_rollout == 0 and args.eval_interval is not None:

@@ -55,9 +47,19 @@ def offload_train():
         else:
             actor_model.clear_memory()
 
-    def onload_rollout():
-        if args.offload_rollout:
-            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
+    def save(rollout_id):
+        if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps):
+            actor_model.save_model(
+                rollout_id,
+                force_sync=rollout_id == args.num_rollout - 1,
+            )
+        if args.use_critic:
+            critic_model.save_model(
+                rollout_id,
+                force_sync=rollout_id == args.num_rollout - 1,
+            )
+        if args.rollout_global_dataset:
+            ray.get(rollout_manager.save.remote(rollout_id))
 
     # train loop.
     # note that for async training, one can change the position of the sync operation(ray.get).

@@ -79,27 +81,14 @@ def onload_rollout():
         ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
 
         if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout):
-            if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps):
-                actor_model.save_model(
-                    rollout_id,
-                    force_sync=rollout_id == args.num_rollout - 1,
-                )
-            if args.use_critic:
-                critic_model.save_model(
-                    rollout_id,
-                    force_sync=rollout_id == args.num_rollout - 1,
-                )
-            if args.rollout_global_dataset:
-                ray.get(rollout_manager.save.remote(rollout_id))
+            save(rollout_id)
 
         offload_train()
-        onload_rollout()
+        if args.offload_rollout:
+            ray.get(rollout_manager.onload_weights.remote())
         actor_model.update_weights()
-
         if args.offload_rollout:
-            if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
-                ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
-            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
+            ray.get(rollout_manager.onload_kv.remote())
 
         if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
            ray.get(rollout_manager.eval.remote(rollout_id))
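Putting the train.py hunks together, the per-iteration body now reads roughly as follows; this is a consolidated sketch, the loop header and the creation of rollout_data_ref are not part of this diff:

# Inside the training loop over rollout_id (loop header not shown in this commit).
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))

if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout):
    save(rollout_id)  # new helper: actor/critic checkpoints plus dataset state

offload_train()
if args.offload_rollout:
    ray.get(rollout_manager.onload_weights.remote())
actor_model.update_weights()
if args.offload_rollout:
    ray.get(rollout_manager.onload_kv.remote())

if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
    ray.get(rollout_manager.eval.remote(rollout_id))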
