Checkpoint save/load FSDP and ShardTensor support #1472
pzharrington wants to merge 14 commits into NVIDIA:main from
Conversation
Greptile Summary: This PR introduces FSDP- and DTensor/ShardTensor-aware checkpoint save/load by wiring PyTorch's Distributed Checkpoint (DCP) state-dict APIs into save_checkpoint and load_checkpoint. The overall architecture is sound and the distributed mechanics (collective state-dict gathering, rank-0-only file I/O) are correct. Issues found:
Important Files Changed
Last reviewed commit: a5e258e
@pzharrington Does this need review from the ShardTensor side too?
examples/minimal/ShardTensorExamples/1_vector_addition/vector_add_sharded.py
I took a look through this PR. Overall I think this is much-needed (and long-overdue) functionality, thank you for finally taking action when no one else would! I have one concern to discuss: over the next release I think it's important to decouple the ShardTensor(DTensor) inheritance structure (I know that's not a surprise, just raising it in this context...). Will that break anything you've implemented here? We will probably still need to use the distributed tooling, but there could be weird behavior introduced if I do that. What do you think?
Copying response from Slack here for posterity
/blossom-ci |
/blossom-ci |
PhysicsNeMo Pull Request
Description
Summary
- `save_checkpoint` and `load_checkpoint` now automatically detect FSDP-wrapped and DTensor/ShardTensor-distributed models and use PyTorch's Distributed Checkpoint (DCP) state-dict APIs to gather/scatter model and optimizer state. In distributed mode all ranks call the functions collectively, while only rank 0 performs file I/O. This eliminates the need for manual parameter gathering/scattering that recipe code (e.g. StormCast) previously had to implement.
- `load_model_weights` utility: a convenience function for loading a single `.mdlus` or `.pt` file directly into a (potentially distributed) model, handling FSDP + DTensor redistribution automatically.
- Removed helpers from `parallel.py` (`gather_training_state`, `scatter_optimizer_state`, `shard_state_dict`, `scatter_object`, `get_state_dict_shard`) and ~50 lines of rank-0 CPU model/optimizer bookkeeping from `trainer.py`. All ranks now participate symmetrically in `_resume_or_init`, calling `load_checkpoint`/`save_checkpoint` directly.
- `physicsnemo.core.Module.save`: added an optional `state_dict` parameter so `save_checkpoint` can pass a pre-gathered full state dictionary for FSDP/DTensor models without calling `self.state_dict()` on the distributed module.
- `StateDictOptions.broadcast_from_rank0` (used in the pure-FSDP load path) was introduced in PyTorch 2.5. This option enables rank 0 to broadcast the full state dict to all other ranks without a manual scatter, which is the standard non-DTensor distributed load mechanism.
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.