
Checkpoint save/load FSDP and ShardTensor support #1472

Open
pzharrington wants to merge 14 commits into NVIDIA:main from pzharrington:fsdp-shardtensor-ckpt

Conversation

@pzharrington
Collaborator

@pzharrington pzharrington commented Mar 5, 2026

PhysicsNeMo Pull Request

Description

Summary

  • FSDP/ShardTensor-aware checkpoint save and load: save_checkpoint and load_checkpoint now automatically detect FSDP-wrapped and DTensor/ShardTensor-distributed models and use PyTorch's Distributed Checkpoint (DCP) state-dict APIs to gather/scatter model and optimizer state. In distributed mode all ranks call the functions collectively, while only rank 0 performs file I/O. This eliminates the need for manual parameter gathering/scattering that recipe code (e.g. StormCast) previously had to implement.
  • New load_model_weights utility: A convenience function for loading a single .mdlus or .pt file directly into a (potentially distributed) model, handling FSDP + DTensor redistribution automatically.
  • StormCast recipe simplification: Removed ~200 lines of manual checkpoint gather/scatter logic from parallel.py (gather_training_state, scatter_optimizer_state, shard_state_dict, scatter_object, get_state_dict_shard) and ~50 lines of rank-0 CPU model/optimizer bookkeeping from trainer.py. All ranks now participate symmetrically in _resume_or_init, calling load_checkpoint / save_checkpoint directly.
  • physicsnemo.core.Module.save: Added an optional state_dict parameter so save_checkpoint can pass a pre-gathered full state dictionary for FSDP/DTensor models without calling self.state_dict() on the distributed module.
  • Minimum torch version bump 2.4 → 2.5: Required because StateDictOptions.broadcast_from_rank0 (used in the pure-FSDP load path) was introduced in PyTorch 2.5. This option enables rank 0 to broadcast the full state dict to all other ranks without manual scatter, which is the standard non-DTensor distributed load mechanism.


Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington
Collaborator Author

@greptileai

@greptile-apps
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR introduces FSDP- and DTensor/ShardTensor-aware checkpoint save/load by wiring PyTorch's Distributed Checkpoint (DCP) state-dict APIs into save_checkpoint, load_checkpoint, and a new load_model_weights utility. It also removes ~250 lines of manual gather/scatter bookkeeping from the StormCast recipe and bumps the minimum PyTorch version to 2.5.0.

The overall architecture is sound and the distributed mechanics (collective state-dict gathering, broadcast_from_rank0, per-rank seeding) are correctly implemented. The new test suite in test_checkpoint_distributed.py is thorough and covers the key combinations of use_orig_params, sharding_strategy, and sync_module_states.

Issues found:

  • Race condition in save_checkpoint for mixed distributed/non-distributed model lists (physicsnemo/utils/checkpoint.py lines 548–554): When is_distributed=True (because at least one model in the list is FSDP-wrapped), _get_checkpoint_filename is called with distributed=True for all models, forcing model_parallel_rank=0 and therefore the same filename for all ranks. The else branch (non-distributed model) has no if should_write: guard, so every rank simultaneously writes to that file. The distributed load path explicitly comments that "a mix of distributed and non-distributed models is valid," so the save path needs a matching should_write guard.
  • All ranks download remote checkpoint in load_model_weights (physicsnemo/utils/checkpoint.py line 829): _is_mdlus_archive calls _cache_if_needed, which downloads remote files to a local cache. This executes on all ranks before the if not _is_distributed_model(model): branch, so for large models on object storage every rank incurs the full download cost even though only rank 0 reads the file in the distributed path.
  • Full optimizer state held in GPU memory on non-rank-0 ranks (physicsnemo/utils/checkpoint.py lines 568–585): get_optimizer_state_dict(full_state_dict=True) materialises the complete optimizer state on every rank, but _cpu_offload_state_dict is applied only on should_write (rank 0). Non-rank-0 ranks therefore unnecessarily hold a full copy in device memory until GC.

Important Files Changed

  • physicsnemo/utils/checkpoint.py: Core file of this PR; adds FSDP/DTensor-aware save/load logic. Two issues found: (1) non-distributed models in a mixed distributed+non-distributed save call are written by all ranks to the same file (missing should_write guard in the else branch); (2) load_model_weights calls _is_mdlus_archive on all ranks, which downloads remote checkpoints unnecessarily. Also, get_optimizer_state_dict materialises the full optimizer state on every rank but only CPU-offloads it on rank 0.
  • physicsnemo/core/module.py: Adds an optional state_dict parameter to Module.save() so that save_checkpoint can pass a pre-gathered full state dict for distributed models without calling self.state_dict() on the sharded module. Change is minimal and well-contained.
  • examples/weather/stormcast/utils/trainer.py: Removes ~50 lines of rank-0-only CPU bookkeeping (net_full, optimizer_full, etc.) and replaces them with symmetric all-rank participation. Seeding is now deterministic per (step, rank) rather than using a multiplied offset. All changes look correct and aligned with the new DCP-based checkpoint helpers.
  • examples/weather/stormcast/utils/parallel.py: Removes ~200 lines of manual gather/scatter helpers (gather_training_state, scatter_optimizer_state, shard_state_dict, scatter_object, get_state_dict_shard). Also switches FSDP to sync_module_states=True, which is now safe since all ranks initialize with the same seed.
  • test/utils/test_checkpoint_distributed.py: New distributed test file with comprehensive coverage: plain FSDP (NO_SHARD + FULL_SHARD), FSDP+ShardTensor on a 2-D mesh, load_model_weights, mixed distributed/non-distributed fallback, and GradScaler preservation. All tests look correct and thorough.
  • test/utils/test_checkpoint.py: Refactors the moto mock to use the context-manager form (more idiomatic) and adds a new non-distributed test for load_model_weights covering both physicsnemo.Module and plain nn.Module. No issues found.

Last reviewed commit: a5e258e

@coreyjadams
Collaborator

@pzharrington Does this need review from shard tensor side too?

@coreyjadams
Collaborator

I took a look through this PR. Overall, I think this is much-needed (and long overdue) functionality; thank you for finally taking action when no one else would!

I have one concern to discuss: over the next release I think it's important to decouple the ShardTensor(DTensor) inheritance structure (I know that's not a surprise; just raising it in this context). Will that break anything you've implemented here? We will probably still need the distributed tooling, but there could be weird behavior introduced if I make that change.

What do you think?

@pzharrington
Collaborator Author

Copying response from Slack here for posterity

Yes, pivoting away from ShardTensor(DTensor) inheritance will break some stuff, but the level of refactoring needed would depend on how far ShardTensor ends up drifting from DTensor. I wouldn't let the checkpoint functionality weigh much in the "to DTensor or not to DTensor" decision, as I assume that whatever we end up with, there will be a non-horrible pathway to keeping the same user-facing checkpoint functionality by shuffling things around under the hood. Either way, it will eventually need some refactoring to support FSDP2 instead of FSDP (which is deprecating NO_SHARD).

@pzharrington pzharrington marked this pull request as ready for review March 10, 2026 00:43
@pzharrington pzharrington requested a review from ktangsali as a code owner March 10, 2026 00:43
@pzharrington
Collaborator Author

/blossom-ci

@pzharrington
Collaborator Author

/blossom-ci
