@@ -73,7 +73,7 @@ def init(
7373 )
7474
7575 if role == "critic" :
76- if self .args .offload :
76+ if self .args .offload_train :
7777 self .sleep (("model" ))
7878 Timer ().start ("train_wait" )
7979 return
@@ -106,7 +106,7 @@ def init(
106106 # empty cache after initialization
107107 clear_memory ()
108108
109- if self .args .offload :
109+ if self .args .offload_train :
110110 # recover to actor in the end.
111111 self .update_gpu_params_dict (self .weights ["actor" ])
112112 self .sleep (("model" ))
@@ -156,7 +156,7 @@ def update_gpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
156156
157157 @timer
158158 def sleep (self , tags : Union [str , Tuple [str , ...]]) -> None :
159- assert self .args .offload
159+ assert self .args .offload_train
160160 assert "model" in tags
161161 if isinstance (tags , str ):
162162 tags = (tags ,)
@@ -171,7 +171,7 @@ def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
171171
172172 @timer
173173 def wake_up (self , tags : Union [str , Tuple [str , ...]]) -> None :
174- assert self .args .offload
174+ assert self .args .offload_train
175175
176176 # there are weird times when sglang is not offloaded immediately, so we wait here.
177177 mem_fraction_static = self .args .sglang_mem_fraction_static or 0.8
@@ -243,7 +243,7 @@ def compute_log_prob(
243243 def train (self , rollout_id : int , rollout_data_ref : Box ) -> None :
244244 Timer ().end ("train_wait" )
245245
246- if self .args .offload :
246+ if self .args .offload_train :
247247 self .wake_up (("model" ))
248248
249249 with timer ("data_preprocess" ):
@@ -408,7 +408,7 @@ def update_weights(self) -> None:
408408 if self .args .debug_train_only or self .args .debug_rollout_only :
409409 return
410410
411- if self .args .offload :
411+ if self .args .offload_train :
412412 reload_process_groups ()
413413
414414 rollout_engines , rollout_engine_lock , num_new_engines = ray .get (
@@ -418,7 +418,7 @@ def update_weights(self) -> None:
418418 self .weight_updater .connect_rollout_engines (rollout_engines , rollout_engine_lock )
419419 dist .barrier (group = get_gloo_group ())
420420
421- with torch_memory_saver .disable () if self .args .offload else nullcontext ():
421+ with torch_memory_saver .disable () if self .args .offload_train else nullcontext ():
422422 print_memory ("before update_weights" )
423423 self .weight_updater .update_weights ()
424424 print_memory ("after update_weights" )
@@ -435,7 +435,7 @@ def update_weights(self) -> None:
435435 else :
436436 self .update_cpu_params_dict (self .weights ["old_actor" ])
437437
438- if self .args .offload :
438+ if self .args .offload_train :
439439 destroy_process_groups ()
440440
441441 def load_other_checkpoint (self , model_tag : str , path : str ) -> None :
0 commit comments