File tree: 3 files changed, +24 -16 lines changed
slime/backends/megatron_utils: 3 files changed, +24 -16 lines changed
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,10 @@ def init(
9191
9292 self .weights_backuper = TensorBackuper .create (
9393 source_getter = lambda : named_params_and_buffers (
94- self .args , self .model , convert_to_global_name = args .megatron_to_hf_mode == "raw"
94+ self .args ,
95+ self .model ,
96+ convert_to_global_name = args .megatron_to_hf_mode == "raw" ,
97+ translate_gpu_to_cpu = not self .args .enable_weights_backuper ,
9598 ),
9699 single_tag = None if args .enable_weights_backuper else "actor" ,
97100 )
Original file line number | Diff line number | Diff line change
@@ -117,11 +117,26 @@ def named_params_and_buffers(
117117 args : Namespace ,
118118 model : Sequence [torch .nn .Module ],
119119 convert_to_global_name : bool = True ,
120+ translate_gpu_to_cpu : bool = False ,
120121) -> Iterator [tuple [str , torch .Tensor ]]:
121122 if convert_to_global_name :
122- return _named_params_and_buffers_global (args , model )
123+ ans = _named_params_and_buffers_global (args , model )
123124 else :
124- return _named_params_and_buffers_vanilla (model )
125+ ans = _named_params_and_buffers_vanilla (model )
126+
127+ if translate_gpu_to_cpu :
128+ ans = ((name , _maybe_get_cpu_backup (tensor )) for name , tensor in ans )
129+
130+ return ans
131+
132+
133+ def _maybe_get_cpu_backup (x : torch .Tensor ):
134+ from torch_memory_saver import torch_memory_saver
135+
136+ if (cpu_tensor := torch_memory_saver .get_cpu_backup (x )) is not None :
137+ return cpu_tensor
138+
139+ return x
125140
126141
127142def _named_params_and_buffers_vanilla (model : Sequence [torch .nn .Module ]) -> Iterator [tuple [str , torch .Tensor ]]:
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,8 @@ def train(args):
2828 if args .offload_rollout :
2929 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_WEIGHTS ]))
3030
31- if args .offload_train and not args .enable_weights_backuper :
32- actor_model .onload ()
3331 # always update weight first so that sglang has the loaded weights from training.
3432 actor_model .update_weights ()
35- if args .offload_train and not args .enable_weights_backuper :
36- actor_model .offload ()
3733
3834 if args .check_weight_update_equal :
3935 ray .get (rollout_manager .check_weights .remote (action = "compare" ))
@@ -93,15 +89,9 @@ def onload_rollout():
9389 if args .rollout_global_dataset :
9490 ray .get (rollout_manager .save .remote (rollout_id ))
9591
96- if args .enable_weights_backuper :
97- offload_train ()
98- onload_rollout ()
99- actor_model .update_weights ()
100- else :
101- actor_model .clear_memory ()
102- onload_rollout ()
103- actor_model .update_weights ()
104- offload_train ()
92+ offload_train ()
93+ onload_rollout ()
94+ actor_model .update_weights ()
10595
10696 if args .offload_rollout :
10797 if GPU_MEMORY_TYPE_CUDA_GRAPH is not None :
You can’t perform that action at this time.
0 commit comments