Llama3.2 vision does not run with distributed state dict #2277
Open
Description
When running the Llama3.2 vision full finetune distributed recipe with the distributed state dict, I run into an error that comes from set_model_state_dict.
Full error log:
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 911, in <module>
[rank2]: sys.exit(recipe_main())
[rank2]: ^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]: sys.exit(recipe_main(conf))
[rank2]: ^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 905, in recipe_main
[rank2]: recipe.setup(cfg=cfg)
[rank2]: File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 253, in setup
[rank2]: self._model = self._setup_model(
[rank2]: ^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 546, in _setup_model
[rank2]: training.load_from_full_model_state_dict(
[rank2]: File "/home/jessicazhong/torchtune/torchtune/training/_distributed.py", line 216, in load_from_full_model_state_dict
[rank2]: return set_model_state_dict(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py", line 1218, in set_model_state_dict
[rank2]: return _load_model_state_dict(model, model_state_dict, info)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py", line 591, in _load_model_state_dict
[rank2]: _state_dict_fn(model, "load_state_dict")(
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2561, in load_state_dict
[rank2]: load(self, state_dict)
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2549, in load
[rank2]: load(child, child_state_dict, child_prefix) # noqa: F821
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2549, in load
[rank2]: load(child, child_state_dict, child_prefix) # noqa: F821
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2549, in load
[rank2]: load(child, child_state_dict, child_prefix) # noqa: F821
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2554, in load
[rank2]: out = hook(module, incompatible_keys)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 251, in <lambda>
[rank2]: lambda *args, **kwargs: self.reset_sharded_param()
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 839, in reset_sharded_param
[rank2]: local_tensor = new_param._local_tensor
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: AttributeError: 'Parameter' object has no attribute '_local_tensor'
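The last frame fails on `new_param._local_tensor`, which suggests that at least one parameter is still a plain `Parameter` at load time rather than a sharded one. A minimal diagnostic sketch for narrowing down which parameters those are is below. It assumes that sharded parameters expose `_local_tensor` (the private attribute the failing line reads), and it duck-types on that attribute instead of importing torch. The helper name `find_params_without_local_tensor`, the stand-in classes, and the parameter names are all hypothetical and for illustration only.

```python
from typing import Any, Iterable, List, Tuple


def find_params_without_local_tensor(
    named_params: Iterable[Tuple[str, Any]],
) -> List[str]:
    """Return the names of parameters that lack a `_local_tensor` attribute.

    Assumption: sharded parameters expose `_local_tensor`, which is the
    attribute read by the failing line in the traceback. Plain parameters
    do not have it.
    """
    return [
        name
        for name, param in named_params
        if not hasattr(param, "_local_tensor")
    ]


# Hypothetical stand-ins so the helper can be exercised without torch,
# a GPU, or a process group. They are not real torch classes.
class FakeShardedParam:
    def __init__(self) -> None:
        self._local_tensor = object()


class FakePlainParam:
    pass


example_params = [
    ("sharded.weight", FakeShardedParam()),
    ("plain.weight", FakePlainParam()),
]

suspects = find_params_without_local_tensor(example_params)
print(suspects)
```

In the recipe, the same helper could be called with `model.named_parameters()` right before `load_from_full_model_state_dict` and the result logged per rank. This only reports which parameters look unsharded; it does not say why they were left that way.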