Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Adds GLOBE model (`physicsnemo.experimental.models.globe.model.GLOBE`)
- Adds GLOBE AirFRANS example case (`examples/cfd/external_aerodynamics/globe/airfrans`)
- Adds automatic support for `FSDP` and/or `ShardTensor` models in checkpoint save/load
functionality

### Changed

Expand All @@ -25,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Dependencies

- Increases the minimum supported PyTorch version to `torch>=2.5.0` for improved FSDP support

## [2.0.0] - 2026-XX-YY

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import torch
import time

from physicsnemo.distributed import DistributedManager, scatter_tensor
from physicsnemo.distributed import DistributedManager
from physicsnemo.domain_parallel import scatter_tensor
from torch.distributed.tensor.placement_types import Shard

# Another really big tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import torch.distributed as dist
import time

from physicsnemo.distributed import DistributedManager, scatter_tensor, ShardTensor
from physicsnemo.distributed import DistributedManager
from physicsnemo.domain_parallel import scatter_tensor, ShardTensor
from torch.distributed.tensor.placement_types import Shard, Replicate


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@
)


from physicsnemo.distributed import (
DistributedManager,
ShardTensor,
scatter_tensor,
)
from physicsnemo.distributed import DistributedManager
from physicsnemo.domain_parallel import ShardTensor, scatter_tensor

DistributedManager.initialize()
dm = DistributedManager()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP

# Imports for Domain Parallelism
from physicsnemo.distributed import DistributedManager, scatter_tensor
from physicsnemo.domain_parallel import scatter_tensor
from torch.distributed.tensor import distribute_module, distribute_tensor

# FSDP instead of DDP
Expand Down
5 changes: 5 additions & 0 deletions examples/weather/stormcast/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.mdlus
*.png
*.pt
*.tfevents*
*wandb/
2 changes: 1 addition & 1 deletion examples/weather/stormcast/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ These models can make longer forecasts (more than one timestep) during inference
## Getting started

### Preliminaries
Start by installing PhysicsNeMo (if not already installed) and copying this folder (`examples/weather/stormcast`) to a system with a GPU available. Also, prepare a combined HRRR/ERA5 dataset in the form specified in `datasets/data_loader_hrrr_era5.py` or implement a custom dataset class as shown below under [Adding custom datasets](#adding-custom-datasets). (**Note: subsequent versions of this example will include more detailed dataset preparation instructions**)
Start by installing PhysicsNeMo (if not already installed) with the `datapipes-extras`, `nn-extras`, and `utils-extras` optional dependency groups, and copy this folder (`examples/weather/stormcast`) to a system with a GPU available. Also, prepare a combined HRRR/ERA5 dataset in the form specified in `datasets/data_loader_hrrr_era5.py` or implement a custom dataset class as shown below under [Adding custom datasets](#adding-custom-datasets).

### Testing

Expand Down
1 change: 1 addition & 0 deletions examples/weather/stormcast/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pyproj
pydantic
torch-optimi[optimi]
tensorboard
2 changes: 1 addition & 1 deletion examples/weather/stormcast/utils/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_preconditioned_natten_dit(
patch_size: int = 4,
attn_kernel_size: int = 31,
lead_time_steps: int = 0,
layernorm_backend: Literal["torch", "apex"] = "apex",
layernorm_backend: Literal["torch", "apex"] = "torch",
conditioning_embedder: Literal["dit", "edm", "zero"] = "dit",
**model_kwargs,
) -> EDMPreconditioner:
Expand Down
198 changes: 0 additions & 198 deletions examples/weather/stormcast/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,10 @@

import numpy as np
import torch
from torch.distributed.checkpoint.state_dict import (
get_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
BackwardPrefetch,
OptimStateKeyType,
)
from torch.distributed.tensor import distribute_module, distribute_tensor
from torch.distributed.tensor.placement_types import Replicate, Shard
Expand Down Expand Up @@ -257,146 +251,6 @@ def distribute_model(self, model: torch.nn.Module) -> FSDP:
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Backward prefetching for overlap
)

def scatter_object(self, x: Any | None) -> Any:
    """Distribute a picklable Python object from rank 0 to every rank.

    This is a collective operation: all ranks must call it together.

    Parameters
    ----------
    x : Any or None
        The object to distribute; only the value supplied on rank 0 is used.

    Returns
    -------
    Any
        The copy of the object delivered to this rank.
    """
    # Rank 0 provides one copy of the object per destination rank;
    # every other rank passes None as the input list.
    if self.dist.rank == 0:
        send_list = [x for _ in range(self.dist.world_size)]
    else:
        send_list = None
    received = [None]
    torch.distributed.barrier()
    torch.distributed.scatter_object_list(received, send_list, src=0)
    return received[0]

def shard_state_dict(self, state_dict: dict[str, Any] | None) -> dict[str, Any]:
    """Split a full state dict along the domain mesh and deliver each rank its shard.

    This is a collective operation: all ranks must call it together.

    Parameters
    ----------
    state_dict : dict[str, Any] or None
        Complete state dict, supplied on rank 0 only (None elsewhere).

    Returns
    -------
    dict[str, Any]
        The shard of the state dict belonging to this rank.
    """
    send_list = None
    if self.dist.rank == 0:
        # Build one shard per domain-parallel rank ...
        per_domain = [
            self.get_state_dict_shard(state_dict, domain_rank=r)
            for r in range(self.domain_parallel_size)
        ]
        # ... then replicate the shards so that each global rank receives
        # the shard matching its position within its domain group.
        send_list = [
            per_domain[r % self.domain_parallel_size]
            for r in range(self.dist.world_size)
        ]
    received = [None]
    torch.distributed.barrier()
    torch.distributed.scatter_object_list(received, send_list, src=0)
    return received[0]

def scatter_optimizer_state(
    self,
    model_full: torch.nn.Module | None,
    optimizer_full: torch.optim.Optimizer | None,
    scheduler_full: torch.optim.lr_scheduler.LRScheduler | None,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None,
) -> None:
    """Scatter optimizer (and optional scheduler) state from rank 0 to all ranks.

    On rank 0, the full optimizer state dict is extracted and — if the local
    model is wrapped in FSDP — rekeyed to parameter names so it can be matched
    against the distributed model. The state is then scattered to every rank
    (sharded along the domain mesh when ``self.use_shard_tensor`` is set,
    otherwise sent whole) and loaded into the local optimizer. Scheduler state,
    if present, is scattered whole and loaded last.

    This is a collective operation: all ranks must call it together, and in the
    same order relative to other collectives.

    Parameters
    ----------
    model_full : torch.nn.Module or None
        Full model on rank 0 (used for rekeying the optimizer state).
    optimizer_full : torch.optim.Optimizer or None
        Full optimizer on rank 0 whose state is distributed.
    scheduler_full : torch.optim.lr_scheduler.LRScheduler or None
        Full scheduler on rank 0, if any.
    model : torch.nn.Module
        Local (distributed) model instance.
    optimizer : torch.optim.Optimizer
        Local optimizer instance to receive the scattered state.
    scheduler : torch.optim.lr_scheduler.LRScheduler or None
        Local scheduler instance to receive the scattered state, if any.
    """
    if self.dist.rank == 0:
        optim_state_dict = optimizer_full.state_dict()
        # NOTE(review): this checks the *local* `model` for FSDP but rekeys
        # against `model_full` — presumably intentional, since rekeying to
        # PARAM_NAME needs the unsharded module's parameter names; confirm.
        if isinstance(model, FSDP):
            optim_state_dict = FSDP.rekey_optim_state_dict(
                optim_state_dict, OptimStateKeyType.PARAM_NAME, model_full
            )

    # Non-zero ranks pass None; the conditional expression never evaluates
    # the (unbound there) `optim_state_dict` name on those ranks.
    if self.use_shard_tensor:
        # Shard tensors (e.g. positional embeddings) along the domain mesh
        # before scattering.
        optim_state_dict = self.shard_state_dict(
            optim_state_dict if self.dist.rank == 0 else None
        )
    else:
        optim_state_dict = self.scatter_object(
            optim_state_dict if self.dist.rank == 0 else None
        )

    # full_state_dict=True: the received dict is an unsharded ("full") view
    # that torch.distributed.checkpoint maps onto the local optimizer.
    options = StateDictOptions(full_state_dict=True)
    set_optimizer_state_dict(model, optimizer, optim_state_dict, options=options)

    if scheduler is not None:
        sched_state_dict_full = (
            None if scheduler_full is None else scheduler_full.state_dict()
        )
        sched_state_dict_full = self.scatter_object(sched_state_dict_full)
        scheduler.load_state_dict(sched_state_dict_full)

def gather_training_state(
    self,
    model: FSDP,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None,
    model_full: torch.nn.Module | None,
    optimizer_full: torch.optim.Optimizer | None,
    scheduler_full: torch.optim.lr_scheduler.LRScheduler | None,
):
    """Collect model and optimizer state from all ranks onto rank 0.

    This is a collective operation: every rank must call it, but only rank 0
    populates the ``*_full`` objects.

    Parameters
    ----------
    model : torch.distributed.fsdp.FullyShardedDataParallel
        Distributed model wrapper.
    optimizer : torch.optim.Optimizer
        Local optimizer.
    scheduler : torch.optim.lr_scheduler.LRScheduler or None
        Local scheduler.
    model_full : torch.nn.Module or None
        Full model to populate on rank 0, or None on other ranks.
    optimizer_full : torch.optim.Optimizer or None
        Full optimizer to populate on rank 0, or None on other ranks.
    scheduler_full : torch.optim.lr_scheduler.LRScheduler or None
        Full scheduler to populate on rank 0, or None on other ranks.
    """
    # TODO: cpu_offload=True would be preferable here, but it appears to hang.
    gather_options = StateDictOptions(full_state_dict=True)
    model_state, optim_state = get_state_dict(
        model, optimizer, options=gather_options
    )
    if self.dist.rank != 0:
        return
    model_full.load_state_dict(model_state)
    optimizer_full.load_state_dict(optim_state)
    if scheduler is not None:
        scheduler_full.load_state_dict(scheduler.state_dict())

def nested_scatter(
self,
x: torch.Tensor | Mapping | list | tuple | Any,
Expand Down Expand Up @@ -454,58 +308,6 @@ def nested_scatter(

return x

def get_state_dict_shard(
    self,
    x: Any,
    domain_rank: int | None = None,
    _key: str = "",
) -> Any:
    """Recursively extract one domain rank's shard of a nested state dict.

    Containers (mappings, lists, tuples) are traversed while a dotted key
    path is accumulated; at each tensor leaf, ``shard_dim_selector`` decides
    from the key path which dimension (if any) to slice.

    Parameters
    ----------
    x : Any
        State dict or nested structure to shard.
    domain_rank : int or None, optional
        Domain rank to extract the shard for; defaults to this rank's own.
    _key : str
        Internal accumulator for the dotted key path; do not pass explicitly.

    Returns
    -------
    Any
        The structure with shardable tensors sliced for the target rank.
        Note that tuples are returned as lists, matching the original layout
        only up to sequence type.
    """
    if domain_rank is None:
        domain_rank = self.domain_rank

    if isinstance(x, Mapping):
        return {
            name: self.get_state_dict_shard(
                value, domain_rank=domain_rank, _key=_key + "." + name
            )
            for name, value in x.items()
        }
    if isinstance(x, (list, tuple)):
        return [
            self.get_state_dict_shard(
                value, domain_rank=domain_rank, _key=_key + "." + str(idx)
            )
            for idx, value in enumerate(x)
        ]

    # Leaf value: slice tensors along the dimension selected for this key;
    # anything else (or a tensor with no/invalid shard dim) passes through.
    shard_dim = shard_dim_selector(_key)
    if (
        not isinstance(x, torch.Tensor)
        or shard_dim is None
        or shard_dim >= x.ndim
    ):
        return x
    width = x.shape[shard_dim] // self.domain_parallel_size
    start = domain_rank * width
    index = tuple(
        slice(start, start + width) if d == shard_dim else slice(None)
        for d in range(x.ndim)
    )
    return x[index]


def shard_dim_selector(param_name: str) -> int | None:
"""
Expand Down
Loading