11import ray
22from sglang .srt .constants import GPU_MEMORY_TYPE_KV_CACHE , GPU_MEMORY_TYPE_WEIGHTS
33
4- from slime .ray .placement_group import create_placement_groups , create_rollout_manager , create_training_group
5- from slime .ray .registry import register_actor
4+ from slime .ray .placement_group import create_placement_groups , create_rollout_manager , create_training_models
65from slime .utils .arguments import parse_args
76from slime .utils .wandb_utils import init_wandb_primary
87
@@ -12,54 +11,19 @@ def train(args):
1211 pgs = create_placement_groups (args )
1312 wandb_run_id = init_wandb_primary (args )
1413
15- actor_model = create_training_group (args , pgs ["actor" ], wandb_run_id = wandb_run_id )
16- if args .use_critic :
17- critic_model = create_training_group (args , pgs ["critic" ], wandb_run_id = wandb_run_id )
14+ # create the actor and critic models
15+ actor_model , critic_model = create_training_models (args , pgs , wandb_run_id = wandb_run_id )
1816
1917 # create the rollout manager, with sglang engines inside.
20- rollout_manager = create_rollout_manager (args , pgs ["rollout" ], wandb_run_id = wandb_run_id )
21-
22- # TODO: extract this to single function
23- rollout_engines , rollout_engine_lock = ray .get (rollout_manager .get_rollout_engines_and_lock .remote ())
24- for i , rollout_engine in enumerate (rollout_engines ):
25- register_actor ("rollout" , i , rollout_engine )
26- register_actor ("rollout_lock" , 0 , rollout_engine_lock )
27- for i , actor in enumerate (actor_model ._actor_handlers ):
28- register_actor ("actor" , i , actor )
29- if args .use_critic :
30- for i , critic in enumerate (critic_model ._actor_handlers ):
31- register_actor ("critic" , i , critic )
32-
33- # calculate num_rollout from num_epoch
34- num_rollout_per_epoch = None
35- if args .num_rollout is None :
36- num_rollout_per_epoch = ray .get (rollout_manager .get_num_rollout_per_epoch .remote ())
37- args .num_rollout = num_rollout_per_epoch * args .num_epoch
38- assert args .num_rollout > 0
39-
40- # sync the initialization (model initalization, load checkpoint, etc.)
41- if args .use_critic :
42- critic_init_handle = critic_model .async_init (args , role = "critic" , with_ref = False )
43-
44- start_rollout_ids = ray .get (
45- actor_model .async_init (args , role = "actor" , with_ref = args .kl_coef != 0 or args .use_kl_loss )
46- )
47- assert len (set (start_rollout_ids )) == 1
48- if args .start_rollout_id is None :
49- args .start_rollout_id = start_rollout_ids [0 ]
50-
51- if args .rollout_global_dataset :
52- ray .get (rollout_manager .load .remote (args .start_rollout_id - 1 ))
53-
54- if args .use_critic :
55- ray .get (critic_init_handle )
56- ray .get (actor_model .async_connect (critic_model ))
18+ rollout_manager , num_rollout_per_epoch = create_rollout_manager (args , pgs ["rollout" ], wandb_run_id = wandb_run_id )
19+
20+ actor_model .set_rollout_manager (rollout_manager )
5721
5822 if args .offload :
5923 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_WEIGHTS ]))
6024
6125 # always update weight first so that sglang has the loaded weights from training.
62- ray . get ( actor_model .async_update_weights () )
26+ actor_model .update_weights ( )
6327
6428 if args .offload :
6529 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_KV_CACHE ]))
@@ -88,21 +52,23 @@ def train(args):
8852 (rollout_id + 1 ) % args .save_interval == 0
8953 or (num_rollout_per_epoch is not None and (rollout_id + 1 ) % num_rollout_per_epoch == 0 )
9054 ):
91- ray .get (actor_model .async_save_model (rollout_id ))
55+ actor_model .save_model (rollout_id )
56+ if args .use_critic :
57+ critic_model .save_model (rollout_id )
9258 if args .rollout_global_dataset :
9359 ray .get (rollout_manager .save .remote (rollout_id ))
9460
9561 if args .offload :
9662 if args .use_critic :
97- ray . get ( critic_model .async_offload () )
63+ critic_model .offload ( )
9864 if rollout_id >= args .num_critic_only_steps :
99- ray . get ( actor_model .async_offload () )
65+ actor_model .offload ( )
10066 else :
101- ray . get ( actor_model .async_offload () )
67+ actor_model .offload ( )
10268
10369 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_WEIGHTS ]))
10470
105- ray . get ( actor_model .async_update_weights () )
71+ actor_model .update_weights ( )
10672
10773 if args .offload :
10874 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_KV_CACHE ]))
0 commit comments