@@ -464,7 +464,7 @@ def param_init_fn_tied_param(module: torch.nn.Module):
464464 return param_init_fn_tied_param
465465
466466
467- def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
467+ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict, cpu_offload: bool = False):
468468 """
469469 Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
470470 parameters from rank 0 to all other ranks. This function modifies the model in-place.
@@ -474,6 +474,8 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
474474 model (`torch.nn.Module`):
475475 The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
476476 full_sd (`dict`): The full state dict to load, can only be on rank 0
477+ cpu_offload (`bool`, defaults to `False`):
478+ If True, move sharded parameters to CPU after distribution. Required when FSDP CPU offloading is enabled.
477479 """
478480 import torch.distributed as dist
479481 from torch.distributed.tensor import DTensor, distribute_tensor
@@ -525,6 +527,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
525527 full_param,
526528 )
527529 sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
530+ # When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU
531+ if cpu_offload:
532+ sharded_tensor = sharded_tensor.to("cpu")
528533 sharded_sd[param_name] = sharded_tensor
529534 # We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
530535 else :
@@ -539,6 +544,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
539544 full_tensor,
540545 )
541546 sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
547+ # When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU
548+ if cpu_offload:
549+ sharded_tensor = sharded_tensor.to("cpu")
542550 sharded_sd[param_name] = sharded_tensor
543551
544552 # we set `assign=True` because our params are on meta device
@@ -686,7 +694,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
686694 if fsdp2_plugin .cpu_ram_efficient_loading :
687695 # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights
688696 # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly
689- fsdp2_load_full_state_dict(accelerator, model, original_sd)
697+ # When CPU offloading is enabled, parameters need to stay on CPU after distribution
698+ fsdp2_load_full_state_dict(accelerator, model, original_sd, cpu_offload=bool(fsdp2_plugin.cpu_offload))
690699
691700 if fsdp2_plugin .cpu_ram_efficient_loading and not model_has_params4bit :
692701 # We re-register the buffers, as they may not be in the state_dict
0 commit comments