@@ -581,22 +581,21 @@ def _build_model_optimizer(
581581 def update_weighs_by_checkpoint_engine (
582582 self ,
583583 weights : Generator [tuple [str , torch .Tensor ], None , None ],
584- req_func : Callable [[list [tuple [str , str ]]], None ]
584+ req_func : Callable [[list [tuple [str , str ]]], None ],
585585 ):
586586 named_tensors = {}
587587 for tensor_idx , (name , tensor ) in enumerate (weights ):
588588 if tensor_idx % self .world_size == self .rank :
589589 named_tensors [name ] = tensor
590590
591- checkpoint_name = f "checkpoint_engine"
591+ checkpoint_name = "checkpoint_engine"
592592 self .parameter_server .register_checkpoint (checkpoint_name , named_tensors = named_tensors )
593593 named_tensors = {}
594594 dist .barrier ()
595595 self .parameter_server .gather_metas (checkpoint_name )
596596 self .parameter_server .update (checkpoint_name , req_func )
597597 self .parameter_server .unregister_checkpoint (checkpoint_name )
598598
599-
600599 def _build_rollout (self , trust_remote_code = False ):
601600 from torch .distributed .device_mesh import init_device_mesh
602601
@@ -744,16 +743,18 @@ async def rollout_mode(self):
744743 )
745744 if self .config .rollout .enable_checkpoint_engine :
746745 req_func = await self .rollout .checkpoint_engine_req_func (self .infer_world_size )
747- self .update_weighs_by_checkpoint_engine (per_tensor_param , req_func )
746+ self .update_weighs_by_checkpoint_engine (per_tensor_base_params , req_func )
748747 else :
749748 await self .rollout .update_weights (per_tensor_base_params , base_sync_done = False )
750749 del base_model_params , per_tensor_base_params
751-
750+
752751 if self .config .rollout .enable_checkpoint_engine :
753752 req_func = await self .rollout .checkpoint_engine_req_func (self .infer_world_size )
754753 self .update_weighs_by_checkpoint_engine (per_tensor_param , req_func )
755754 else :
756- await self .rollout .update_weights (per_tensor_param , peft_config = peft_config , base_sync_done = self .base_sync_done )
755+ await self .rollout .update_weights (
756+ per_tensor_param , peft_config = peft_config , base_sync_done = self .base_sync_done
757+ )
757758 log_gpu_memory_usage ("After update_weights" , logger = logger )
758759 del params , per_tensor_param
759760 aggressive_empty_cache (force_sync = True )
0 commit comments