Skip to content

Commit 3719218

Browse files
committed
fix bug when both cpu_ram_efficient_loading and cpu_offload are enabled
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 38dadd9 commit 3719218

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/accelerate/utils/fsdp_utils.py

Lines changed: 11 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -464,7 +464,7 @@ def param_init_fn_tied_param(module: torch.nn.Module):
464464
return param_init_fn_tied_param
465465

466466

467-
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
467+
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict, cpu_offload: bool = False):
468468
"""
469469
Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
470470
parameters from rank 0 to all other ranks. This function modifies the model in-place.
@@ -474,6 +474,8 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
474474
model (`torch.nn.Module`):
475475
The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
476476
full_sd (`dict`): The full state dict to load, can only be on rank 0
477+
cpu_offload (`bool`, defaults to `False`):
478+
If True, move sharded parameters to CPU after distribution. Required when FSDP CPU offloading is enabled.
477479
"""
478480
import torch.distributed as dist
479481
from torch.distributed.tensor import DTensor, distribute_tensor
@@ -525,6 +527,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
525527
full_param,
526528
)
527529
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
530+
# When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU
531+
if cpu_offload:
532+
sharded_tensor = sharded_tensor.to("cpu")
528533
sharded_sd[param_name] = sharded_tensor
529534
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
530535
else:
@@ -539,6 +544,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
539544
full_tensor,
540545
)
541546
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
547+
# When CPU offloading is enabled, FSDP2's lazy_init expects parameters on CPU
548+
if cpu_offload:
549+
sharded_tensor = sharded_tensor.to("cpu")
542550
sharded_sd[param_name] = sharded_tensor
543551

544552
# we set `assign=True` because our params are on meta device
@@ -686,7 +694,8 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
686694
if fsdp2_plugin.cpu_ram_efficient_loading:
687695
# If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights
688696
# Other ranks have an empty model on `meta` device, so we need to distribute the weights properly
689-
fsdp2_load_full_state_dict(accelerator, model, original_sd)
697+
# When CPU offloading is enabled, parameters need to stay on CPU after distribution
698+
fsdp2_load_full_state_dict(accelerator, model, original_sd, cpu_offload=bool(fsdp2_plugin.cpu_offload))
690699

691700
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
692701
# We re-register the buffers, as they may not be in the state_dict

0 commit comments

Comments (0)