
Issue loading FSDP-wrapped module using FULL_STATE_DICT type #141

@hbikki

Description


🐛 Describe the bug

Hello, I am training a pretrained Hugging Face model, "t5-small". Using the torchsnapshot examples provided in the documentation, I can save and load checkpoints with the LOCAL_STATE_DICT type, and I can also save a model checkpoint with FULL_STATE_DICT. However, loading the FULL_STATE_DICT checkpoint fails with the error below.
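To make the LOCAL vs FULL distinction concrete: with FULL_STATE_DICT, each parameter is materialized as one unsharded tensor gathered across ranks, while LOCAL_STATE_DICT keeps only each rank's slice of the flattened parameter. The stdlib-only sketch below models this with plain lists. `shard_flat` and `gather_full` are hypothetical helpers, and the "flatten, pad, split evenly" scheme is a simplification of FSDP's flat-parameter sharding, not its actual code.

```python
from typing import Dict, List


def shard_flat(values: List[float], world_size: int) -> List[List[float]]:
    """Split a flattened parameter into world_size equal chunks, zero-padding the tail.

    Simplified model of even sharding; real FSDP flat parameters carry more detail.
    """
    chunk = -(-len(values) // world_size)  # ceiling division
    padded = values + [0.0] * (chunk * world_size - len(values))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather_full(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding to recover the full parameter."""
    flat = [v for shard in shards for v in shard]
    return flat[:numel]


# A "full" state dict: every parameter name maps to its complete (unsharded) values.
full_state: Dict[str, List[float]] = {"encoder.weight": [1.0, 2.0, 3.0, 4.0, 5.0]}
world_size = 2

# "Local" state dicts: each rank keeps only its own padded slice.
local_states = [
    {name: shard_flat(vals, world_size)[rank] for name, vals in full_state.items()}
    for rank in range(world_size)
]
print(local_states)  # [{'encoder.weight': [1.0, 2.0, 3.0]}, {'encoder.weight': [4.0, 5.0, 0.0]}]

# Gathering the slices (and trimming padding) reproduces the full state dict.
restored = {
    name: gather_full([s[name] for s in local_states], len(vals))
    for name, vals in full_state.items()
}
print(restored == full_state)  # True
```

The practical consequence is that a FULL_STATE_DICT checkpoint and a LOCAL_STATE_DICT checkpoint store differently shaped data under the same model, so the layout expected at load time has to match the one used at save time.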

Versions:
pytorch==2.0.0+cu117
torchx-nightly>=2023.3.15
torchsnapshot==0.1.0

Host Details:
The training below was tested on a single node with NPROC_PER_NODE=8.

Code:

Model training code:

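import functools

import torch
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

def train() -> None:
    init_process_group(backend="nccl")
    torch.cuda.empty_cache()
    # local_rank() and load_model() are helpers from the author's code (not shown here)
    torch.cuda.set_device(local_rank())
    model = load_model("t5-small")

    # Wrap each T5Block in its own FSDP unit; shard within the node, replicate across nodes
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
        ),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_id=local_rank(),
    )
    # ... training loop ...
    # ... save_checkpoint() ...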
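Because the run uses ShardingStrategy.HYBRID_SHARD on a single node, it may help to see which ranks actually share a shard. Assuming FSDP's default behavior when no process group is passed (an intra-node group for sharding and an inter-node group, made of ranks with the same local rank, for replication), the hypothetical helper `hybrid_shard_groups` below computes both layouts. It is an illustration of the grouping, not FSDP's own code.

```python
from typing import List, Tuple


def hybrid_shard_groups(
    world_size: int, nproc_per_node: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) for a HYBRID_SHARD-style layout.

    Parameters are sharded among the ranks of one node and replicated across
    nodes among ranks that share the same local rank.
    """
    if world_size % nproc_per_node != 0:
        raise ValueError("world_size must be a multiple of nproc_per_node")
    num_nodes = world_size // nproc_per_node
    shard_groups = [
        list(range(node * nproc_per_node, (node + 1) * nproc_per_node))
        for node in range(num_nodes)
    ]
    replicate_groups = [
        list(range(local, world_size, nproc_per_node))
        for local in range(nproc_per_node)
    ]
    return shard_groups, replicate_groups


# The setup in this issue: one node with NPROC_PER_NODE=8.
shards, replicas = hybrid_shard_groups(world_size=8, nproc_per_node=8)
print(shards)    # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(replicas)  # [[0], [1], [2], [3], [4], [5], [6], [7]]
```

On a single 8-GPU node, HYBRID_SHARD therefore behaves like full sharding across all 8 ranks, with no cross-node replication; with two such nodes, ranks 0 and 8 (and so on) would hold replicas of the same shard.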

stateDictType = StateDictType.FULL_STATE_DICT
Related saving/loading code:

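from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchsnapshot import Snapshot

# Methods of the author's trainer class; checkpoint, app_state, save_dir and
# load_dir are defined elsewhere in that code and are not shown in this issue.
def save_checkpoint(self) -> None:
    with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
        Snapshot.take(path=str(save_dir), app_state=app_state)

def load_checkpoint(self) -> None:
    with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
        Snapshot(path=str(load_dir)).restore(app_state=app_state)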
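One way to reason about a save-succeeds/load-fails pattern: Snapshot.take persists the state_dict() of each app_state entry, and restore needs the state_dict() produced at load time, under the same FSDP.state_dict_type context, to line up with what was saved. The pure-Python toy below is a simplified, hypothetical model of that round trip, assuming restore uses each entry's current state_dict() as the target layout before calling load_state_dict(). It is not torchsnapshot's implementation; `ToyStateful`, `toy_take` and `toy_restore` are illustrative names.

```python
from typing import Any, Dict


class ToyStateful:
    """Hypothetical stand-in for an app_state entry (e.g. a model or optimizer)."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)

    def state_dict(self) -> Dict[str, float]:
        return dict(self.weights)

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.weights = dict(state)


def toy_take(app_state: Dict[str, Any]) -> Dict[str, Dict[str, float]]:
    """Persist every entry's state_dict() (here: into an in-memory dict)."""
    return {key: obj.state_dict() for key, obj in app_state.items()}


def toy_restore(snapshot: Dict[str, Dict[str, float]], app_state: Dict[str, Any]) -> None:
    """Use each entry's current state_dict() as the target layout, fill it, then load it."""
    for key, obj in app_state.items():
        target = obj.state_dict()
        saved = snapshot[key]
        if set(target) != set(saved):
            raise KeyError(f"{key}: saved keys {sorted(saved)} != expected keys {sorted(target)}")
        target.update(saved)
        obj.load_state_dict(target)


# Matching layouts at save and load time: restore succeeds.
snapshot = toy_take({"model": ToyStateful({"w": 1.0, "b": 2.0})})
restored_model = ToyStateful({"w": 0.0, "b": 0.0})
toy_restore(snapshot, {"model": restored_model})
print(restored_model.weights)  # {'w': 1.0, 'b': 2.0}

# A different layout at load time (e.g. a flattened shard instead of full params) fails.
try:
    toy_restore(snapshot, {"model": ToyStateful({"flat_param": 0.0})})
except KeyError as err:
    print("restore failed:", err)
```

In the real setup, comparing the keys and shapes recorded in .snapshot_metadata against what checkpoint.model.state_dict() returns inside the FULL_STATE_DICT context on each rank may help narrow down where the two diverge.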
   

Error stack trace:
https://pastebin.com/ih9qSbwR

Contents of .snapshot_metadata for the model on the local rank:
https://pastebin.com/t6grkKyX

Does anyone know how to resolve this? Thanks!
