@@ -55,6 +55,11 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
5555 if self .args .debug_rollout_only :
5656 return 0
5757
58+ self .fsdp_cpu_offload = getattr (self .args , "fsdp_cpu_offload" , False )
59+ # Offload train and fsdp cpu offload cannot be used together, fsdp_cpu_offload is more aggressive
60+ if self .args .offload_train and self .fsdp_cpu_offload :
61+ self .args .offload_train = False
62+
5863 self ._enable_true_on_policy_optimizations (args )
5964 if dist .get_rank () == 0 :
6065 init_tracking (args , primary = False )
@@ -73,20 +78,29 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
7378 if self .args .multimodal_keys :
7479 self .vlm_processor = AutoProcessor .from_pretrained (self .args .hf_checkpoint , trust_remote_code = True )
7580
76- # Load model
77- model = AutoModelForCausalLM .from_pretrained (
78- self .args .hf_checkpoint ,
79- trust_remote_code = True ,
80- attn_implementation = self .args .attn_implementation ,
81- )
81+ init_context = self ._get_init_weight_context_manager ()
82+
83+ with init_context ():
84+ model = AutoModelForCausalLM .from_pretrained (
85+ self .args .hf_checkpoint ,
86+ trust_remote_code = True ,
87+ attn_implementation = self .args .attn_implementation ,
88+ )
89+
8290 model .train ()
8391
84- if args .gradient_checkpointing :
85- model .gradient_checkpointing_enable ()
92+ full_state = model .state_dict ()
93+
94+ model = apply_fsdp2 (model , mesh = self .dp_mesh , cpu_offload = self .fsdp_cpu_offload )
95+
96+ model = self ._fsdp2_load_full_state_dict (
97+ model , full_state , self .dp_mesh , cpu_offload = True if self .fsdp_cpu_offload else None
98+ )
8699
87- # Apply FSDP with DP mesh and CPU offload policy if requested
88- cpu_offload = getattr (args , "fsdp_cpu_offload" , False )
89- self .model = apply_fsdp2 (model , mesh = self .dp_mesh , cpu_offload = cpu_offload )
100+ self .model = model
101+
102+ if args .gradient_checkpointing :
103+ self .model .gradient_checkpointing_enable ()
90104
91105 if args .optimizer == "adam" :
92106 self .optimizer = torch .optim .AdamW (
@@ -188,6 +202,69 @@ def setup_device_mesh(self) -> None:
188202 else :
189203 logger .info (f"[Rank { rank } ] Pure DP mode (cp_size=1)" )
190204
205+ def _get_init_weight_context_manager (self ):
206+ """Get context manager for model initialization.
207+
208+ Returns a callable that creates a context manager.
209+ Uses meta device (no memory allocation) for non-rank-0 processes,
210+ UNLESS tie_word_embeddings=True (which causes hangs with meta tensors).
211+
212+ Ref: verl/utils/fsdp_utils.py::get_init_weight_context_manager
213+ NOTE: tie_word_embedding causes meta_tensor init to hang
214+ """
215+ from accelerate import init_empty_weights
216+
217+ # Check if model uses tied word embeddings (which doesn't work with meta tensors)
218+ use_meta_tensor = not self .hf_config .tie_word_embeddings
219+
220+ cpu_init_weights = lambda : torch .device ("cpu" )
221+
222+ if use_meta_tensor :
223+ # Rank 0: CPU, others: meta device (memory efficient for large models)
224+ return init_empty_weights if dist .get_rank () != 0 else cpu_init_weights
225+ else :
226+ logger .info (f"[Rank { dist .get_rank ()} ] tie_word_embeddings=True, loading full model to CPU on all ranks" )
227+ return cpu_init_weights
228+
    def _fsdp2_load_full_state_dict(self, model, full_state, device_mesh, cpu_offload):
        """Load a full (unsharded) state dict into an FSDP2 model, broadcasting from rank 0.

        Rank 0 supplies the real weights; ``set_model_state_dict`` with
        ``broadcast_from_rank0=True`` distributes them, so non-rank-0 processes
        never need to read the checkpoint from disk.

        Args:
            model: FSDP2-wrapped model (non-rank-0 copies may hold meta tensors
                from the init context manager).
            full_state: Full state dict; only rank 0 needs real weights.
            device_mesh: Device mesh the model was sharded over (currently unused
                here; kept for signature symmetry with the verl reference).
            cpu_offload: If not None, enables ``StateDictOptions(cpu_offload=...)``
                and moves the loaded model back to CPU afterwards.

        Returns:
            The model with weights loaded (on CPU if ``cpu_offload`` was set,
            with buffers kept on GPU; otherwise on the current CUDA device).

        Ref: verl/utils/fsdp_utils.py::fsdp2_load_full_state_dict
        """
        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

        # Rank 0 moves its real weights to GPU; all other ranks materialize
        # uninitialized storage on GPU (to_empty allocates without copying data,
        # which is required to escape meta tensors before loading).
        if dist.get_rank() == 0:
            model = model.to(device=torch.cuda.current_device(), non_blocking=True)
        else:
            model = model.to_empty(device=torch.cuda.current_device())

        is_cpu_offload = cpu_offload is not None
        options = StateDictOptions(full_state_dict=True, cpu_offload=is_cpu_offload, broadcast_from_rank0=True)

        # Collective: every rank must reach this call for the broadcast to complete.
        set_model_state_dict(model, full_state, options=options)

        # set_model_state_dict will not broadcast buffers, so we need to broadcast
        # them manually (e.g. rotary-embedding caches, batch-norm running stats).
        for name, buf in model.named_buffers():
            dist.broadcast(buf, src=0)

        if is_cpu_offload:
            # Parameters go back to CPU (FSDP2 CPUOffloadPolicy expects them
            # there); buffers are pinned to GPU for use during forward.
            # NOTE(review): non_blocking=True on a GPU->CPU copy is only safe if
            # a sync happens before the host reads the data — presumably the
            # subsequent collectives provide that; confirm against callers.
            model.to("cpu", non_blocking=True)
            for buf in model.buffers():
                buf.data = buf.data.to(torch.cuda.current_device())

        return model
267+
191268 @timer
192269 def sleep (self ) -> None :
193270 """Pause CUDA memory for all tracked tensors."""
@@ -246,14 +323,11 @@ def compute_log_prob(
246323 """
247324 # Select which model to use
248325 if model_tag == "ref" and self .ref_model is not None :
249- # Offload actor model to CPU to save GPU memory
250- logger .info ("[Rank {}] Offloading actor model to CPU" .format (dist .get_rank ()))
251- self .model .cpu ()
252- torch .cuda .empty_cache ()
253-
254- # Load ref model to GPU
255- logger .info ("[Rank {}] Loading ref model to GPU" .format (dist .get_rank ()))
256- self .ref_model .cuda ()
326+ if not self .fsdp_cpu_offload :
327+ self .model .cpu ()
328+ torch .cuda .empty_cache ()
329+ dist .barrier (group = get_gloo_group ())
330+
257331 active_model = self .ref_model
258332 active_model .eval ()
259333 else :
@@ -285,11 +359,14 @@ def compute_log_prob(
285359 return rollout_data
286360
287361 finally :
288- # Offload ref model back to CPU
362+ # Restore actor model if it was offloaded
289363 if model_tag == "ref" and self .ref_model is not None :
290- self .ref_model .cpu ()
291364 torch .cuda .empty_cache ()
292- self .model .cuda ()
365+ dist .barrier (group = get_gloo_group ())
366+
367+ if not self .fsdp_cpu_offload :
368+ self .model .cuda ()
369+ dist .barrier (group = get_gloo_group ())
293370
294371 def packed_data (
295372 self , rollout_data : dict [str , list [torch .Tensor ]]
@@ -472,7 +549,7 @@ def _train_core(self, rollout_id: int, rollout_data) -> None:
472549 # Copy actor model state to ref model
473550 actor_state = self .model .state_dict ()
474551 self .ref_model .load_state_dict (actor_state )
475- self .ref_model .cpu () # Keep ref in CPU
552+ self .ref_model .cpu ()
476553
477554 def _train_step (self , packed_batch , reported_accum , mbs_id , grad_accum ):
478555 # Prepare model inputs
@@ -672,19 +749,19 @@ def update_weights(self) -> None: # type: ignore[override]
672749 clear_memory ()
673750
674751 def create_ref_model (self , ref_load_path : str | None ):
675- """Create and initialize a separate reference model (kept in CPU) .
752+ """Create and initialize a separate reference model with FSDP2 CPUOffloadPolicy .
676753
677754 Parameters:
678755 ref_load_path: Path to a directory containing a HF checkpoint. If
679756 None, a ValueError is raised.
680757
681758 Returns:
682- FSDP -wrapped ref model in CPU memory
759+ FSDP2 -wrapped ref model with CPU offload enabled
683760
684761 Note:
685- Creates a separate FSDP model instance for the reference model.
686- This model is kept in CPU and loaded to GPU only when needed in
687- compute_log_prob(). This approach is cleaner than weight swapping .
762+ Creates a separate FSDP2 model instance for the reference model.
763+ ALWAYS uses CPUOffloadPolicy for the reference model to save memory,
764+ regardless of the actor model's CPU offload setting .
688765 """
689766 if ref_load_path is None :
690767 raise ValueError ("ref_load_path must be provided when loading reference model" )
@@ -694,17 +771,22 @@ def create_ref_model(self, ref_load_path: str | None):
694771 if os .path .isdir (ref_load_path ):
695772 logger .info (f"[Rank { dist .get_rank ()} ] Creating separate ref model from { ref_load_path } " )
696773
697- # Load model same way as actor model
698- ref_model = AutoModelForCausalLM .from_pretrained (
699- ref_load_path ,
700- trust_remote_code = True ,
701- attn_implementation = self .args .attn_implementation ,
702- )
774+ init_context = self ._get_init_weight_context_manager ()
775+
776+ with init_context ():
777+ ref_model = AutoModelForCausalLM .from_pretrained (
778+ ref_load_path ,
779+ trust_remote_code = True ,
780+ attn_implementation = self .args .attn_implementation ,
781+ )
782+
783+ full_state = ref_model .state_dict ()
703784
704- ref_model = apply_fsdp2 (ref_model , mesh = self .dp_mesh )
705- ref_model .cpu ()
785+ # Always use CPUOffloadPolicy for reference, let FSDP2 handle the offload. It is faster than model.cpu().
786+ ref_model = apply_fsdp2 (ref_model , mesh = self .dp_mesh , cpu_offload = True )
787+ ref_model = self ._fsdp2_load_full_state_dict (ref_model , full_state , self .dp_mesh , cpu_offload = True )
706788
707- logger .info (f"[Rank { dist .get_rank ()} ] Reference model created and offloaded to CPU " )
789+ logger .info (f"[Rank { dist .get_rank ()} ] Reference model created with FSDP2 CPUOffloadPolicy " )
708790 return ref_model
709791 else :
710792 raise NotImplementedError (f"Loading from checkpoint file { ref_load_path } not yet implemented" )
0 commit comments