22from sglang .srt .constants import GPU_MEMORY_TYPE_KV_CACHE , GPU_MEMORY_TYPE_WEIGHTS
33
44from slime .ray .placement_group import create_placement_groups , create_rollout_manager , create_training_group
5+ from slime .ray .registry import register_actor
56from slime .utils .arguments import parse_args
67from slime .utils .wandb_utils import init_wandb_primary
78
@@ -18,6 +19,17 @@ def train(args):
1819 # create the rollout manager, with sglang engines inside.
1920 rollout_manager = create_rollout_manager (args , pgs ["rollout" ], wandb_run_id = wandb_run_id )
2021
22+ # TODO: extract this actor registration logic into a single function
23+ rollout_engines , rollout_engine_lock = ray .get (rollout_manager .get_rollout_engines_and_lock .remote ())
24+ for i , rollout_engine in enumerate (rollout_engines ):
25+ register_actor ("rollout" , i , rollout_engine )
26+ register_actor ("rollout_lock" , 0 , rollout_engine_lock )
27+ for i , actor in enumerate (actor_model ._actor_handlers ):
28+ register_actor ("actor" , i , actor )
29+ if args .use_critic :
30+ for i , critic in enumerate (critic_model ._actor_handlers ):
31+ register_actor ("critic" , i , critic )
32+
2133 # calculate num_rollout from num_epoch
2234 num_rollout_per_epoch = None
2335 if args .num_rollout is None :
@@ -32,17 +44,13 @@ def train(args):
3244 start_rollout_ids = ray .get (
3345 actor_model .async_init (args , role = "actor" , with_ref = args .kl_coef != 0 or args .use_kl_loss )
3446 )
35-
3647 assert len (set (start_rollout_ids )) == 1
3748 if args .start_rollout_id is None :
3849 args .start_rollout_id = start_rollout_ids [0 ]
3950
4051 if args .rollout_global_dataset :
4152 ray .get (rollout_manager .load .remote (args .start_rollout_id - 1 ))
4253
43- # initialize the connection for weight update during training
44- ray .get (actor_model .async_init_weight_update_connections (rollout_manager ))
45-
4654 if args .use_critic :
4755 ray .get (critic_init_handle )
4856 ray .get (actor_model .async_connect (critic_model ))
0 commit comments