@@ -1,10 +1,4 @@
 import ray
-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
-
-try:
-    from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
-except ImportError:
-    GPU_MEMORY_TYPE_CUDA_GRAPH = None
 
 from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models
 from slime.utils.arguments import parse_args
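The `onload(tags=...)` calls and the try/except probe for `GPU_MEMORY_TYPE_CUDA_GRAPH` removed above are presumably folded into the rollout manager's new `onload_weights` / `onload_kv` methods. A minimal sketch of what those wrappers could look like, reconstructed from the deleted tag logic; the `RolloutManager` class shape and method bodies here are assumptions, not slime's actual implementation:

```python
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

try:
    from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH
except ImportError:
    GPU_MEMORY_TYPE_CUDA_GRAPH = None  # older sglang versions lack this tag


class RolloutManager:  # hypothetical host class for the new wrappers
    def onload_weights(self):
        # bring only the model weights back onto the GPU before weight sync
        self.onload(tags=[GPU_MEMORY_TYPE_WEIGHTS])

    def onload_kv(self):
        # restore CUDA graphs first when the tag exists, then the KV cache,
        # mirroring the order of the deleted inline calls
        if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
            self.onload(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH])
        self.onload(tags=[GPU_MEMORY_TYPE_KV_CACHE])
```

This keeps the sglang version probe in one place instead of repeating it at every onload site in the training script.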
@@ -27,7 +21,7 @@ def train(args):
     actor_model, critic_model = create_training_models(args, pgs, rollout_manager)
 
     if args.offload_rollout:
-        ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
+        ray.get(rollout_manager.onload_weights.remote())
 
     # always update weights first so that sglang has the loaded weights from training.
     actor_model.update_weights()
@@ -36,9 +30,7 @@ def train(args):
         ray.get(rollout_manager.check_weights.remote(action="compare"))
 
     if args.offload_rollout:
-        if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
-            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
-        ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
+        ray.get(rollout_manager.onload_kv.remote())
 
     # special case for eval-only
     if args.num_rollout == 0 and args.eval_interval is not None:
@@ -55,9 +47,19 @@ def offload_train():
         else:
             actor_model.clear_memory()
 
-    def onload_rollout():
-        if args.offload_rollout:
-            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
+    def save(rollout_id):
+        if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps):
+            actor_model.save_model(
+                rollout_id,
+                force_sync=rollout_id == args.num_rollout - 1,
+            )
+        if args.use_critic:
+            critic_model.save_model(
+                rollout_id,
+                force_sync=rollout_id == args.num_rollout - 1,
+            )
+        if args.rollout_global_dataset:
+            ray.get(rollout_manager.save.remote(rollout_id))
 
     # train loop.
     # note that for async training, one can change the position of the sync operation (ray.get).
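The note above says the sync `ray.get` can be repositioned for async training. A hedged sketch of one such rearrangement, overlapping step N's rollout generation with step N-1's training; `start_rollout_id` and `rollout_manager.generate` are assumed names, and the real loop in this file may be shaped differently:

```python
# Illustrative only: defer ray.get on the training future so generation
# of the next rollout runs while the previous training step is in flight.
train_future = None
for rollout_id in range(start_rollout_id, args.num_rollout):
    rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
    if train_future is not None:
        ray.get(train_future)  # sync with the previous training step
    train_future = actor_model.async_train(rollout_id, rollout_data_ref)
if train_future is not None:
    ray.get(train_future)  # drain the final step
```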
@@ -79,27 +81,14 @@ def onload_rollout():
         ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
 
         if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout):
-            if (not args.use_critic) or (rollout_id >= args.num_critic_only_steps):
-                actor_model.save_model(
-                    rollout_id,
-                    force_sync=rollout_id == args.num_rollout - 1,
-                )
-            if args.use_critic:
-                critic_model.save_model(
-                    rollout_id,
-                    force_sync=rollout_id == args.num_rollout - 1,
-                )
-            if args.rollout_global_dataset:
-                ray.get(rollout_manager.save.remote(rollout_id))
+            save(rollout_id)
 
         offload_train()
-        onload_rollout()
+        if args.offload_rollout:
+            ray.get(rollout_manager.onload_weights.remote())
         actor_model.update_weights()
-
         if args.offload_rollout:
-            if GPU_MEMORY_TYPE_CUDA_GRAPH is not None:
-                ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_CUDA_GRAPH]))
-            ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_KV_CACHE]))
+            ray.get(rollout_manager.onload_kv.remote())
 
         if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch):
             ray.get(rollout_manager.eval.remote(rollout_id))
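`should_run_periodic_action` is called above with an optional fourth argument. A hypothetical reconstruction of its contract, inferred only from these two call sites; the real helper in slime may differ:

```python
def should_run_periodic_action(rollout_id, interval, num_rollout_per_epoch, num_rollout=None):
    # Hypothetical: fire every `interval` rollouts, at epoch boundaries,
    # and always on the final rollout when `num_rollout` is given.
    if num_rollout is not None and rollout_id == num_rollout - 1:
        return True
    if interval is None:
        return False
    if (rollout_id + 1) % interval == 0:
        return True
    return num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0
```

The final-rollout short-circuit would explain why `args.num_rollout` is passed only at the save call site: the last checkpoint should always be written.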