Skip to content

Commit 584657a

Browse files
authored
[refactor] remove Registry and change the order of init (THUDM#398)
* [refactor] remove Registry and change the order of init * bugfix
1 parent 371c030 commit 584657a

File tree

7 files changed

+99
-148
lines changed

7 files changed

+99
-148
lines changed

slime/backends/megatron_utils/actor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from megatron.core import mpu
1717
from transformers import AutoConfig, AutoTokenizer
1818

19-
from slime.ray.registry import get_actors
2019
from slime.ray.train_actor import TrainRayActor
2120
from slime.utils.data import process_rollout_data
2221
from slime.utils.distributed_utils import get_gloo_group, init_process_group
@@ -404,8 +403,7 @@ def update_weights(self):
404403

405404
if not self.connected:
406405
self.connected = True
407-
rollout_engines = get_actors("rollout")
408-
rollout_engine_lock = get_actors("rollout_lock", 0)
406+
rollout_engines, rollout_engine_lock = ray.get(self.rollout_manager.get_rollout_engines_and_lock.remote())
409407
self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock)
410408
dist.barrier(group=get_gloo_group())
411409

slime/ray/actor_group.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,24 @@ def async_train(self, rollout_id, rollout_data_ref):
122122
"""Do one rollout training"""
123123
return [actor.train.remote(rollout_id, rollout_data_ref) for actor in self._actor_handlers]
124124

125-
def async_save_model(self, step_id):
125+
def save_model(self, step_id):
126126
"""Save actor model on rank 0."""
127-
return [actor.save_model.remote(step_id) for actor in self._actor_handlers]
127+
return ray.get([actor.save_model.remote(step_id) for actor in self._actor_handlers])
128128

129-
def async_update_weights(self):
129+
def update_weights(self):
130130
"""Broadcast weights from rank 0 to all other ranks."""
131-
return [actor.update_weights.remote() for actor in self._actor_handlers]
131+
return ray.get([actor.update_weights.remote() for actor in self._actor_handlers])
132132

133-
def async_offload(self):
134-
return [actor.sleep.remote(("model")) for actor in self._actor_handlers]
133+
def offload(self):
134+
return ray.get([actor.sleep.remote(("model")) for actor in self._actor_handlers])
135135

136-
def async_connect(self, critic_group):
137-
return [
138-
actor.connect_actor_critic.remote((critic))
139-
for actor, critic in zip(self._actor_handlers, critic_group._actor_handlers)
140-
]
136+
def connect(self, critic_group):
137+
return ray.get(
138+
[
139+
actor.connect_actor_critic.remote((critic))
140+
for actor, critic in zip(self._actor_handlers, critic_group._actor_handlers)
141+
]
142+
)
143+
144+
def set_rollout_manager(self, rollout_manager):
145+
return ray.get([actor.set_rollout_manager.remote(rollout_manager) for actor in self._actor_handlers])

slime/ray/placement_group.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,58 @@ def create_training_group(args, pg, wandb_run_id):
131131
return actor_model
132132

133133

134+
def create_training_models(args, pgs, wandb_run_id):
135+
actor_model = allocate_train_group(
136+
args=args,
137+
num_nodes=args.actor_num_nodes,
138+
num_gpus_per_node=args.actor_num_gpus_per_node,
139+
pg=pgs["actor"],
140+
wandb_run_id=wandb_run_id,
141+
)
142+
if args.use_critic:
143+
critic_model = allocate_train_group(
144+
args=args,
145+
num_nodes=args.critic_num_nodes,
146+
num_gpus_per_node=args.critic_num_gpus_per_node,
147+
pg=pgs["critic"],
148+
wandb_run_id=wandb_run_id,
149+
)
150+
critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
151+
else:
152+
critic_model = None
153+
154+
start_rollout_ids = ray.get(
155+
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
156+
)
157+
158+
assert len(set(start_rollout_ids)) == 1
159+
if args.start_rollout_id is None:
160+
args.start_rollout_id = start_rollout_ids[0]
161+
162+
if args.use_critic:
163+
ray.get(critic_init_handle)
164+
actor_model.connect(critic_model)
165+
166+
return actor_model, critic_model
167+
168+
134169
def create_rollout_manager(args, pg, wandb_run_id):
135-
return RolloutManager.options(
170+
rollout_manager = RolloutManager.options(
136171
num_cpus=1,
137172
num_gpus=0,
138173
).remote(args, pg, wandb_run_id=wandb_run_id)
174+
175+
if args.rollout_global_dataset:
176+
ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))
177+
178+
# TODO: extract this to single function
179+
rollout_engines, rollout_engine_lock = ray.get(rollout_manager.get_rollout_engines_and_lock.remote())
180+
181+
# calculate num_rollout from num_epoch
182+
num_rollout_per_epoch = None
183+
if args.num_rollout is None:
184+
num_rollout_per_epoch = ray.get(rollout_manager.get_num_rollout_per_epoch.remote())
185+
args.num_rollout = num_rollout_per_epoch * args.num_epoch
186+
assert args.num_rollout > 0
187+
188+
return rollout_manager, num_rollout_per_epoch

slime/ray/registry.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

slime/ray/train_actor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,10 @@ def save_model(self, iteration):
9393
@abc.abstractmethod
9494
def update_weights(self):
9595
raise NotImplementedError
96+
97+
@abc.abstractmethod
98+
def connect_actor_critic(self, critic_group):
99+
raise NotImplementedError
100+
101+
def set_rollout_manager(self, rollout_manager):
102+
self.rollout_manager = rollout_manager

train.py

Lines changed: 14 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import ray
22
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
33

4-
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_group
5-
from slime.ray.registry import register_actor
4+
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
65
from slime.utils.arguments import parse_args
76
from slime.utils.wandb_utils import init_wandb_primary
87

@@ -12,54 +11,19 @@ def train(args):
1211
pgs = create_placement_groups(args)
1312
wandb_run_id = init_wandb_primary(args)
1413

15-
actor_model = create_training_group(args, pgs["actor"], wandb_run_id=wandb_run_id)
16-
if args.use_critic:
17-
critic_model = create_training_group(args, pgs["critic"], wandb_run_id=wandb_run_id)
14+
# create the actor and critic models
15+
actor_model, critic_model = create_training_models(args, pgs, wandb_run_id=wandb_run_id)
1816

1917
# create the rollout manager, with sglang engines inside.
20-
rollout_manager = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
21-
22-
# TODO: extract this to single function
23-
rollout_engines, rollout_engine_lock = ray.get(rollout_manager.get_rollout_engines_and_lock.remote())
24-
for i, rollout_engine in enumerate(rollout_engines):
25-
register_actor("rollout", i, rollout_engine)
26-
register_actor("rollout_lock", 0, rollout_engine_lock)
27-
for i, actor in enumerate(actor_model._actor_handlers):
28-
register_actor("actor", i, actor)
29-
if args.use_critic:
30-
for i, critic in enumerate(critic_model._actor_handlers):
31-
register_actor("critic", i, critic)
32-
33-
# calculate num_rollout from num_epoch
34-
num_rollout_per_epoch = None
35-
if args.num_rollout is None:
36-
num_rollout_per_epoch = ray.get(rollout_manager.get_num_rollout_per_epoch.remote())
37-
args.num_rollout = num_rollout_per_epoch * args.num_epoch
38-
assert args.num_rollout > 0
39-
40-
# sync the initialization (model initalization, load checkpoint, etc.)
41-
if args.use_critic:
42-
critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
43-
44-
start_rollout_ids = ray.get(
45-
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
46-
)
47-
assert len(set(start_rollout_ids)) == 1
48-
if args.start_rollout_id is None:
49-
args.start_rollout_id = start_rollout_ids[0]
50-
51-
if args.rollout_global_dataset:
52-
ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))
53-
54-
if args.use_critic:
55-
ray.get(critic_init_handle)
56-
ray.get(actor_model.async_connect(critic_model))
18+
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
19+
20+
actor_model.set_rollout_manager(rollout_manager)
5721

5822
if args.offload:
5923
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
6024

6125
# always update weight first so that sglang has the loaded weights from training.
62-
ray.get(actor_model.async_update_weights())
26+
actor_model.update_weights()
6327

6428
if args.offload:
6529
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
@@ -88,21 +52,23 @@ def train(args):
8852
(rollout_id + 1) % args.save_interval == 0
8953
or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
9054
):
91-
ray.get(actor_model.async_save_model(rollout_id))
55+
actor_model.save_model(rollout_id)
56+
if args.use_critic:
57+
critic_model.save_model(rollout_id)
9258
if args.rollout_global_dataset:
9359
ray.get(rollout_manager.save.remote(rollout_id))
9460

9561
if args.offload:
9662
if args.use_critic:
97-
ray.get(critic_model.async_offload())
63+
critic_model.offload()
9864
if rollout_id >= args.num_critic_only_steps:
99-
ray.get(actor_model.async_offload())
65+
actor_model.offload()
10066
else:
101-
ray.get(actor_model.async_offload())
67+
actor_model.offload()
10268

10369
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
10470

105-
ray.get(actor_model.async_update_weights())
71+
actor_model.update_weights()
10672

10773
if args.offload:
10874
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))

train_async.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import ray
22

3-
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_group
4-
from slime.ray.registry import register_actor
3+
from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
54
from slime.utils.arguments import parse_args
65
from slime.utils.wandb_utils import init_wandb_primary
76

@@ -12,56 +11,16 @@ def train(args):
1211
pgs = create_placement_groups(args)
1312
wandb_run_id = init_wandb_primary(args)
1413

15-
actor_model = create_training_group(args, pgs["actor"], wandb_run_id=wandb_run_id)
16-
if args.use_critic:
17-
critic_model = create_training_group(args, pgs["critic"], wandb_run_id=wandb_run_id)
14+
# create the actor and critic models
15+
actor_model, critic_model = create_training_models(args, pgs, wandb_run_id=wandb_run_id)
1816

1917
# create the rollout manager, with sglang engines inside.
20-
rollout_manager = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
18+
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
2119

22-
# TODO: extract this to single function
23-
rollout_engines, rollout_engine_lock = ray.get(rollout_manager.get_rollout_engines_and_lock.remote())
24-
for i, rollout_engine in enumerate(rollout_engines):
25-
register_actor("rollout", i, rollout_engine)
26-
register_actor("rollout_lock", 0, rollout_engine_lock)
27-
for i, actor in enumerate(actor_model._actor_handlers):
28-
register_actor("actor", i, actor)
29-
if args.use_critic:
30-
for i, critic in enumerate(critic_model._actor_handlers):
31-
register_actor("critic", i, critic)
32-
33-
# calculate num_rollout from num_epoch
34-
num_rollout_per_epoch = None
35-
if args.num_rollout is None:
36-
num_rollout_per_epoch = ray.get(rollout_manager.get_num_rollout_per_epoch.remote())
37-
args.num_rollout = num_rollout_per_epoch * args.num_epoch
38-
assert args.num_rollout > 0
39-
40-
# sync the initialization (model initalization, load checkpoint, etc.)
41-
if args.use_critic:
42-
critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
43-
44-
start_rollout_ids = ray.get(
45-
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
46-
)
47-
48-
assert len(set(start_rollout_ids)) == 1
49-
if args.start_rollout_id is None:
50-
args.start_rollout_id = start_rollout_ids[0]
51-
52-
if args.rollout_global_dataset:
53-
ray.get(rollout_manager.load.remote(args.start_rollout_id - 1))
54-
55-
if args.use_critic:
56-
ray.get(critic_init_handle)
57-
ray.get(actor_model.async_connect(critic_model))
58-
59-
if args.use_critic:
60-
ray.get(critic_init_handle)
61-
ray.get(actor_model.async_connect(critic_model))
20+
actor_model.set_rollout_manager(rollout_manager)
6221

6322
# always update weight first so that sglang has the loaded weights from training.
64-
ray.get(actor_model.async_update_weights())
23+
actor_model.update_weights()
6524

6625
# async train loop.
6726
rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id)
@@ -86,15 +45,17 @@ def train(args):
8645
(rollout_id + 1) % args.save_interval == 0
8746
or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
8847
):
89-
ray.get(actor_model.async_save_model(rollout_id))
48+
actor_model.save_model(rollout_id)
49+
if args.use_critic:
50+
critic_model.save_model(rollout_id)
9051
if args.rollout_global_dataset:
9152
ray.get(rollout_manager.save.remote(rollout_id))
9253

9354
if (rollout_id + 1) % args.update_weights_interval == 0:
9455
# sync generate before update weights to prevent update weight in the middle of generation
9556
rollout_data_curr_ref = ray.get(rollout_data_next_future)
9657
rollout_data_next_future = None
97-
ray.get(actor_model.async_update_weights())
58+
actor_model.update_weights()
9859

9960
if args.eval_interval is not None and (
10061
(rollout_id + 1) % args.eval_interval == 0

0 commit comments

Comments
 (0)