Skip to content

Commit 6bc208b

Browse files
authored
[bugfix] initialize rollout manager first to calculate num_rollout (#473)
1 parent 7e77f61 commit 6bc208b

File tree

3 files changed

+11 -7 lines changed

3 files changed

+11 -7 lines changed

slime/ray/placement_group.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,7 @@ def create_rollout_manager(args, pg, wandb_run_id):
182182
args.num_rollout = num_rollout_per_epoch * args.num_epoch
183183
assert args.num_rollout > 0
184184

185+
if args.offload:
186+
ray.get(rollout_manager.offload.remote())
187+
185188
return rollout_manager, num_rollout_per_epoch

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ def train(args):
1616

1717
_TensorboardAdapter(args)
1818

19-
# create the actor and critic models
20-
actor_model, critic_model = create_training_models(args, pgs, wandb_run_id=wandb_run_id)
21-
2219
# create the rollout manager, with sglang engines inside.
20+
# need to initialize rollout manager first to calculate num_rollout
2321
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
2422

23+
# create the actor and critic models
24+
actor_model, critic_model = create_training_models(args, pgs, wandb_run_id=wandb_run_id)
25+
2526
actor_model.set_rollout_manager(rollout_manager)
2627

2728
if args.offload:
28-
ray.get(rollout_manager.offload.remote())
2929
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
3030

3131
# always update weight first so that sglang has the loaded weights from training.

train_async.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ def train(args):
1111
pgs = create_placement_groups(args)
1212
wandb_run_id = init_wandb_primary(args)
1313

14-
# create the actor and critic models
15-
actor_model, critic_model = create_training_models(args, pgs, wandb_run_id=wandb_run_id)
16-
1714
# create the rollout manager, with sglang engines inside.
15+
# need to initialize rollout manager first to calculate num_rollout
1816
rollout_manager, num_rollout_per_epoch = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
1917

18+
# create the actor and critic models
19+
actor_model, critic_model = create_training_models(args, pgs, wandb_run_id=wandb_run_id)
20+
2021
actor_model.set_rollout_manager(rollout_manager)
2122

2223
# always update weight first so that sglang has the loaded weights from training.

0 commit comments

Comments
 (0)