Skip to content

Commit d7f60be

Browse files
authored
[FSDP] Optimize FSDP2 Model Loading with Rank-0 Broadcast (#915)
1 parent bbb67c0 commit d7f60be

File tree

1 file changed

+119
-37
lines changed

1 file changed

+119
-37
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 119 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
5555
if self.args.debug_rollout_only:
5656
return 0
5757

58+
self.fsdp_cpu_offload = getattr(self.args, "fsdp_cpu_offload", False)
59+
# Offload train and fsdp cpu offload cannot be used together; fsdp_cpu_offload is more aggressive
60+
if self.args.offload_train and self.fsdp_cpu_offload:
61+
self.args.offload_train = False
62+
5863
self._enable_true_on_policy_optimizations(args)
5964
if dist.get_rank() == 0:
6065
init_tracking(args, primary=False)
@@ -73,20 +78,29 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
7378
if self.args.multimodal_keys:
7479
self.vlm_processor = AutoProcessor.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
7580

76-
# Load model
77-
model = AutoModelForCausalLM.from_pretrained(
78-
self.args.hf_checkpoint,
79-
trust_remote_code=True,
80-
attn_implementation=self.args.attn_implementation,
81-
)
81+
init_context = self._get_init_weight_context_manager()
82+
83+
with init_context():
84+
model = AutoModelForCausalLM.from_pretrained(
85+
self.args.hf_checkpoint,
86+
trust_remote_code=True,
87+
attn_implementation=self.args.attn_implementation,
88+
)
89+
8290
model.train()
8391

84-
if args.gradient_checkpointing:
85-
model.gradient_checkpointing_enable()
92+
full_state = model.state_dict()
93+
94+
model = apply_fsdp2(model, mesh=self.dp_mesh, cpu_offload=self.fsdp_cpu_offload)
95+
96+
model = self._fsdp2_load_full_state_dict(
97+
model, full_state, self.dp_mesh, cpu_offload=True if self.fsdp_cpu_offload else None
98+
)
8699

87-
# Apply FSDP with DP mesh and CPU offload policy if requested
88-
cpu_offload = getattr(args, "fsdp_cpu_offload", False)
89-
self.model = apply_fsdp2(model, mesh=self.dp_mesh, cpu_offload=cpu_offload)
100+
self.model = model
101+
102+
if args.gradient_checkpointing:
103+
self.model.gradient_checkpointing_enable()
90104

91105
if args.optimizer == "adam":
92106
self.optimizer = torch.optim.AdamW(
@@ -188,6 +202,69 @@ def setup_device_mesh(self) -> None:
188202
else:
189203
logger.info(f"[Rank {rank}] Pure DP mode (cp_size=1)")
190204

205+
def _get_init_weight_context_manager(self):
206+
"""Get context manager for model initialization.
207+
208+
Returns a callable that creates a context manager.
209+
Uses meta device (no memory allocation) for non-rank-0 processes,
210+
UNLESS tie_word_embeddings=True (which causes hangs with meta tensors).
211+
212+
Ref: verl/utils/fsdp_utils.py::get_init_weight_context_manager
213+
NOTE: tie_word_embeddings causes meta_tensor init to hang
214+
"""
215+
from accelerate import init_empty_weights
216+
217+
# Check if model uses tied word embeddings (which doesn't work with meta tensors)
218+
use_meta_tensor = not self.hf_config.tie_word_embeddings
219+
220+
cpu_init_weights = lambda: torch.device("cpu")
221+
222+
if use_meta_tensor:
223+
# Rank 0: CPU, others: meta device (memory efficient for large models)
224+
return init_empty_weights if dist.get_rank() != 0 else cpu_init_weights
225+
else:
226+
logger.info(f"[Rank {dist.get_rank()}] tie_word_embeddings=True, loading full model to CPU on all ranks")
227+
return cpu_init_weights
228+
229+
def _fsdp2_load_full_state_dict(self, model, full_state, device_mesh, cpu_offload):
230+
"""Load full state dict into FSDP2 model with efficient broadcast from rank 0.
231+
232+
This function loads weights from rank 0 and broadcasts to all other ranks,
233+
avoiding the need for each rank to load the full model from disk.
234+
235+
Args:
236+
model: FSDP2-wrapped model
237+
full_state: State dict (only rank 0 has real weights, others have empty dict)
238+
device_mesh: Device mesh for FSDP
239+
cpu_offload: If not None, enables StateDictOptions cpu_offload
240+
241+
Ref: verl/utils/fsdp_utils.py::fsdp2_load_full_state_dict
242+
"""
243+
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
244+
245+
# Rank 0: move with weights, others: allocate empty tensors on device
246+
if dist.get_rank() == 0:
247+
model = model.to(device=torch.cuda.current_device(), non_blocking=True)
248+
else:
249+
# to_empty creates tensors on device without initializing memory
250+
model = model.to_empty(device=torch.cuda.current_device())
251+
252+
is_cpu_offload = cpu_offload is not None
253+
options = StateDictOptions(full_state_dict=True, cpu_offload=is_cpu_offload, broadcast_from_rank0=True)
254+
255+
set_model_state_dict(model, full_state, options=options)
256+
257+
# set_model_state_dict will not broadcast buffers, so we need to broadcast them manually.
258+
for name, buf in model.named_buffers():
259+
dist.broadcast(buf, src=0)
260+
261+
if is_cpu_offload:
262+
model.to("cpu", non_blocking=True)
263+
for buf in model.buffers():
264+
buf.data = buf.data.to(torch.cuda.current_device())
265+
266+
return model
267+
191268
@timer
192269
def sleep(self) -> None:
193270
"""Pause CUDA memory for all tracked tensors."""
@@ -246,14 +323,11 @@ def compute_log_prob(
246323
"""
247324
# Select which model to use
248325
if model_tag == "ref" and self.ref_model is not None:
249-
# Offload actor model to CPU to save GPU memory
250-
logger.info("[Rank {}] Offloading actor model to CPU".format(dist.get_rank()))
251-
self.model.cpu()
252-
torch.cuda.empty_cache()
253-
254-
# Load ref model to GPU
255-
logger.info("[Rank {}] Loading ref model to GPU".format(dist.get_rank()))
256-
self.ref_model.cuda()
326+
if not self.fsdp_cpu_offload:
327+
self.model.cpu()
328+
torch.cuda.empty_cache()
329+
dist.barrier(group=get_gloo_group())
330+
257331
active_model = self.ref_model
258332
active_model.eval()
259333
else:
@@ -285,11 +359,14 @@ def compute_log_prob(
285359
return rollout_data
286360

287361
finally:
288-
# Offload ref model back to CPU
362+
# Restore actor model if it was offloaded
289363
if model_tag == "ref" and self.ref_model is not None:
290-
self.ref_model.cpu()
291364
torch.cuda.empty_cache()
292-
self.model.cuda()
365+
dist.barrier(group=get_gloo_group())
366+
367+
if not self.fsdp_cpu_offload:
368+
self.model.cuda()
369+
dist.barrier(group=get_gloo_group())
293370

294371
def packed_data(
295372
self, rollout_data: dict[str, list[torch.Tensor]]
@@ -472,7 +549,7 @@ def _train_core(self, rollout_id: int, rollout_data) -> None:
472549
# Copy actor model state to ref model
473550
actor_state = self.model.state_dict()
474551
self.ref_model.load_state_dict(actor_state)
475-
self.ref_model.cpu() # Keep ref in CPU
552+
self.ref_model.cpu()
476553

477554
def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
478555
# Prepare model inputs
@@ -672,19 +749,19 @@ def update_weights(self) -> None: # type: ignore[override]
672749
clear_memory()
673750

674751
def create_ref_model(self, ref_load_path: str | None):
675-
"""Create and initialize a separate reference model (kept in CPU).
752+
"""Create and initialize a separate reference model with FSDP2 CPUOffloadPolicy.
676753
677754
Parameters:
678755
ref_load_path: Path to a directory containing a HF checkpoint. If
679756
None, a ValueError is raised.
680757
681758
Returns:
682-
FSDP-wrapped ref model in CPU memory
759+
FSDP2-wrapped ref model with CPU offload enabled
683760
684761
Note:
685-
Creates a separate FSDP model instance for the reference model.
686-
This model is kept in CPU and loaded to GPU only when needed in
687-
compute_log_prob(). This approach is cleaner than weight swapping.
762+
Creates a separate FSDP2 model instance for the reference model.
763+
ALWAYS uses CPUOffloadPolicy for the reference model to save memory,
764+
regardless of the actor model's CPU offload setting.
688765
"""
689766
if ref_load_path is None:
690767
raise ValueError("ref_load_path must be provided when loading reference model")
@@ -694,17 +771,22 @@ def create_ref_model(self, ref_load_path: str | None):
694771
if os.path.isdir(ref_load_path):
695772
logger.info(f"[Rank {dist.get_rank()}] Creating separate ref model from {ref_load_path}")
696773

697-
# Load model same way as actor model
698-
ref_model = AutoModelForCausalLM.from_pretrained(
699-
ref_load_path,
700-
trust_remote_code=True,
701-
attn_implementation=self.args.attn_implementation,
702-
)
774+
init_context = self._get_init_weight_context_manager()
775+
776+
with init_context():
777+
ref_model = AutoModelForCausalLM.from_pretrained(
778+
ref_load_path,
779+
trust_remote_code=True,
780+
attn_implementation=self.args.attn_implementation,
781+
)
782+
783+
full_state = ref_model.state_dict()
703784

704-
ref_model = apply_fsdp2(ref_model, mesh=self.dp_mesh)
705-
ref_model.cpu()
785+
# Always use CPUOffloadPolicy for reference, let FSDP2 handle the offload. It is faster than model.cpu().
786+
ref_model = apply_fsdp2(ref_model, mesh=self.dp_mesh, cpu_offload=True)
787+
ref_model = self._fsdp2_load_full_state_dict(ref_model, full_state, self.dp_mesh, cpu_offload=True)
706788

707-
logger.info(f"[Rank {dist.get_rank()}] Reference model created and offloaded to CPU")
789+
logger.info(f"[Rank {dist.get_rank()}] Reference model created with FSDP2 CPUOffloadPolicy")
708790
return ref_model
709791
else:
710792
raise NotImplementedError(f"Loading from checkpoint file {ref_load_path} not yet implemented")

0 commit comments

Comments
 (0)