
Commit 1a7e5a5

Cherry-pick bug fixes into 0.15.X.
Signed-off-by: Cory Ye <[email protected]>
1 parent 1221b91 commit 1a7e5a5

7 files changed: +474 additions, -282 deletions


megatron/core/distributed/fsdp/src/README.md

Lines changed: 99 additions & 32 deletions
Large diffs are not rendered by default.

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 220 additions & 166 deletions
Large diffs are not rendered by default.

megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py

Lines changed: 6 additions & 0 deletions

@@ -283,8 +283,14 @@ def __init__(
         self._register_fsdp_hooks(self.module)
         self.microbatch_count = 0

+        # Add a reference from the distributed parameters to self for API
+        # accessibility, e.g. when attaching MegatronFSDP scheduled ops
+        # to the distributed optimizer.step() and optimizer.zero_grad().
         self.is_param_fsdp_distributed = False
         self._replace_param_with_distributed_if_needed()
+        for param in self.module.parameters():
+            # Attach MegatronFSDP reference to the parameter.
+            setattr(param, "_megatron_fsdp_model", self)

     def _check_module_parameter_types(self):
         """

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 1 addition & 16 deletions

@@ -31,7 +31,6 @@
 import torch
 from torch.distributed import _coalescing_manager
 from torch.distributed.tensor import DTensor, Replicate, Shard
-from torch.distributed.tensor.device_mesh import _mesh_resources

 from .uneven_dtensor import update_uneven_dtensor_chunk_metadata, validate_uneven_dtensor
 from .utils import _MODEL_PARALLEL_RNG_TRACKER_NAME, FSDPDistributedIndex, get_global_memory_buffer

@@ -3525,20 +3524,6 @@ def _get_fsdp_tensor_spec(param, dist_index: FSDPDistributedIndex, is_sharded_pa
     if isinstance(param, DTensor) and cast(DTensor, param)._spec.num_shards > 1:
         # Retrieve original DTensorSpec (for TP).
         dtensor_spec = cast(DTensor, param)._spec
-        dtensor_mesh = getattr(dtensor_spec, "mesh", None)
-
-        # Validate that the DTensor root mesh is identical to the Megatron-FSDP device mesh.
-        megatron_fsdp_global_mesh = dist_index.get_root_mesh()
-        dtensor_global_mesh = _mesh_resources.get_root_mesh(dtensor_mesh)
-        # FIXME(boxiangw): add or megatron_fsdp_global_mesh != dtensor_global_mesh:
-        # _mesh_resources.get_root_mesh(dtensor_mesh) is not getting the correct root mesh
-        if dtensor_global_mesh is None:
-            raise ValueError(
-                f"When utilizing DTensor-based modules with Megatron-FSDP, the DTensor root "
-                f"device mesh must be identical to the Megatron-FSDP root device mesh.\n"
-                f"DTensor Root Mesh: {dtensor_global_mesh} / Megatron-FSDP "
-                f"Root Mesh: {megatron_fsdp_global_mesh}"
-            )

         # Get the placements for the parameter.
         assert len(dtensor_spec.placements) == 1, (

@@ -3724,7 +3709,7 @@ def make_fsdp_dtensor(
         device_mesh=tp_mesh,
         placements=[Shard(tp_dim)],
         run_check=run_check,
-        shape=global_shape,
+        shape=tuple(global_shape),
         stride=torch.empty(global_shape).stride(),
     )
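
The only functional change in the last hunk is normalizing global_shape to a plain tuple before it is passed as the DTensor shape keyword. A minimal, standalone illustration of that normalization follows; the shape values are invented for the example and this is not code from the commit.

import torch

global_shape = torch.Size([8, 16])            # may also arrive as a list, e.g. [8, 16]
shape = tuple(global_shape)                   # -> (8, 16), a plain tuple
stride = torch.empty(global_shape).stride()   # contiguous stride for that shape: (16, 1)
print(shape, stride)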

megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py

Lines changed: 21 additions & 1 deletion

@@ -25,6 +25,8 @@
 from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
 from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard

+from .utils import get_mesh_names
+

 def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata:
     """

@@ -272,7 +274,25 @@ def gather_uneven_dtensor_to_full_tensor(
     if not device_mesh.mesh_dim_names:
         process_group = device_mesh.get_group()
     else:
-        process_group = device_mesh._flatten().get_group()
+        # Check if the fully-flattened mesh exists first.
+        full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names)
+        if full_flattened_mesh_dim_name in get_mesh_names(device_mesh):
+            # Retrieve the existing flattened DeviceMesh ProcessGroup.
+            try:
+                # Two Cases: Name is a root dimension, or using the old DeviceMesh
+                # API which allows us to get flattened dimensions.
+                process_group = device_mesh[full_flattened_mesh_dim_name].get_group()
+            except:
+                # Name is a flattened dimension that cannot be retrieved from the
+                # DeviceMesh.__getitem__, so fall-back to new DeviceMesh API.
+                process_group = (
+                    device_mesh._get_root_mesh()
+                    ._flatten_mapping[full_flattened_mesh_dim_name]
+                    .get_group()
+                )
+        else:
+            # Create the _-separated flattened DeviceMesh ProcessGroup.
+            process_group = device_mesh._flatten().get_group()

     # Collect chunk metadata for uneven shards (update if missing)
     if not hasattr(dtensor._local_tensor, "__create_chunk_list__"):
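
The new branch above reuses an existing flattened ProcessGroup when one is available, keyed on the underscore-joined name of all mesh dimensions, which matches the default name of a flattened DeviceMesh; it only calls _flatten() when no such mesh exists yet. The sketch below shows that naming convention in isolation. It assumes an initialized torch.distributed job with 4 CUDA ranks and, like the diff itself, uses private DeviceMesh APIs, so treat it as an illustration rather than a supported recipe.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is already initialized with 4 ranks.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "cp"))

# The name the new lookup checks for: "_".join(("dp", "cp")) == "dp_cp".
flat_name = "_".join(mesh.mesh_dim_names)

# If a "dp_cp" flattened mesh was already registered, its ProcessGroup is reused;
# otherwise the code falls back to creating it on demand, as below.
dp_cp_group = mesh._flatten().get_group()
print(flat_name, dist.get_world_size(dp_cp_group))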

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 48 additions & 28 deletions

@@ -34,7 +34,6 @@
 from torch.cuda import _lazy_call, _lazy_init
 from torch.cuda import device as device_ctx_manager
 from torch.distributed import DeviceMesh, ProcessGroup
-from torch.distributed.device_mesh import _mesh_resources

 logger = logging.getLogger(__name__)

@@ -150,30 +149,50 @@ def is_float8tensor(tensor: torch.Tensor) -> bool:
     return HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS)


-def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]:
+def get_mesh_names(
+    device_mesh: Optional[DeviceMesh] = None, only_submesh_dims: bool = False
+) -> list[str]:
     """
-    Get all the sub-mesh names in the DeviceMesh.
+    Get all the sub-mesh ("dp", "cp", etc.) and flattened-mesh ("dp_cp", etc.) names
+    in the DeviceMesh. When only_submesh_dims=True, only checks for sub-mesh dimensions.
     """
     if device_mesh is None:
         # Device mesh does not exist.
         return []
-    # Order of the returned list of mesh dimension names must match the order / index
-    # of the root mesh dimension names followed by children / flattened sub-meshes:
-    # [<root mesh dimension names>, <child mesh dimension names>]
-    mesh_dim_names = (
+
+    # Sub-mesh dimension names.
+    submesh_dim_names = (
         list(device_mesh.mesh_dim_names) if device_mesh.mesh_dim_names is not None else []
     )
-    submesh_dim_names = [
-        submesh_dim_name
-        for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
-        for submesh_dim_name in (child_mesh.mesh_dim_names or [])
-        if root_mesh == device_mesh
-    ]
-    # Combine without duplicate dimensions.
-    for dim_name in submesh_dim_names:
-        if dim_name not in mesh_dim_names:
-            mesh_dim_names.append(dim_name)
-    return mesh_dim_names
+
+    # Flattened mesh dimension names.
+    try:
+        # Retrieve all flattened meshes associated with DeviceMesh.
+        # The flattened DeviceMesh are all located in the _flatten_mapping
+        # dictionary of the root DeviceMesh.
+        flatten_mesh_names = [
+            flat_dim
+            for flat_dim, flat_mesh in device_mesh._get_root_mesh()._flatten_mapping.items()
+        ]
+    except AttributeError:
+        # Fallback to the DeviceMesh global state to retrieve flattened
+        # meshes associated with the DeviceMesh.
+        from torch.distributed.device_mesh import _mesh_resources
+
+        flatten_mesh_names = [
+            child_mesh_dim_name
+            for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
+            for child_mesh_dim_name in (child_mesh.mesh_dim_names or [])
+            if root_mesh == device_mesh and child_mesh_dim_name not in submesh_dim_names
+        ]
+
+    # Order of the returned list of mesh dimension names must match the index
+    # of the root mesh dimension names followed by flattened sub-meshes:
+    # [<root mesh dimension names>, <flattened mesh dimension names>]
+    if only_submesh_dims:
+        return submesh_dim_names
+    else:
+        return submesh_dim_names + flatten_mesh_names


 def contains_submesh(

@@ -720,16 +739,14 @@ def __init__(
         self.hybrid_fsdp_group = hybrid_fsdp_group

         """
-        Store a persistent reference to the core device meshes that back Megatron-FSDP.
-        This is necessary because _MeshEnv (_mesh_resources) may not persist:
-        - _mesh_resources.child_to_root_mapping
-        - _mesh_resources.root_to_flatten_mapping
-        - _mesh_resources.flatten_name_to_root_dims
-        - ...
-        during Torch Autograd, so child and flattened sub-meshes may be cleared.
-        For example, this breaks Megatron-FSDP when self.dp_shard_dim is the flattened
-        sub-mesh of the DP and CP root mesh dimensions.
-        FIXME(@cspades): Identify the root cause of this behavior.
+        Megatron-FSDP is responsible for storing all required DeviceMesh
+        as per best practices recommended by the DeviceMesh API.
+
+        NOTE(@cspades): In PyTorch 2.11, retrieving flattened mesh dimensions
+        will be impossible via the device_mesh[...] API. We will require all
+        users to correctly _unflatten() their DeviceMesh such that all
+        dimensions used by Megatron-FSDP are sub-meshes of the DeviceMesh.
+        contains_submesh(...) -> get_mesh_names(only_submesh_dims=True).
         """
         self.mesh_library = {}
         # TP Mesh

@@ -825,6 +842,9 @@ def get_outer_fsdp_group(self) -> ProcessGroup:

     def get_root_mesh(self, is_expert_parallel: bool = False) -> DeviceMesh:
         """Get the device mesh."""
+        # NOTE(@cspades): This is FSDPDistributedIndex's root mesh, NOT the actual
+        # root mesh that the DeviceMesh or expert DeviceMesh was un-flattened from.
+        # To get the root mesh, use: DeviceMesh._get_root_mesh().
         if is_expert_parallel:
             raise NotImplementedError("Expert parallel is not supported in Megatron-FSDP.")
         return self.device_mesh
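
To make the updated get_mesh_names() behavior concrete, here is a hypothetical usage; the mesh shape, dimension names, and 8-rank setup are assumptions for illustration and are not mandated by this commit.

from torch.distributed.device_mesh import init_device_mesh

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import get_mesh_names

# Assumes an initialized torch.distributed job with 8 CUDA ranks.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))
mesh["dp", "cp"]._flatten("dp_cp")  # register a flattened dp_cp dimension

# Sub-mesh dimensions followed by flattened dimensions, in that order.
print(get_mesh_names(mesh))                          # expected: ["dp", "cp", "tp", "dp_cp"]
# Sub-mesh dimensions only.
print(get_mesh_names(mesh, only_submesh_dims=True))  # expected: ["dp", "cp", "tp"]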

tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py

Lines changed: 79 additions & 39 deletions

@@ -1,5 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import logging
 import shutil
-from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path

@@ -12,6 +14,8 @@

 from tests.unit_tests.test_utilities import Utils

+logger = logging.getLogger(__name__)
+
 HSDP = "hsdp"
 DP = "dp"
 DP_SHARD = "dp_shard"

@@ -36,15 +40,22 @@


 def destroy_device_mesh(device_mesh):
-    from torch.distributed.device_mesh import _mesh_resources

     # Teardown device mesh.
     del device_mesh
-    _mesh_resources.mesh_stack.clear()
-    _mesh_resources.child_to_root_mapping.clear()
-    _mesh_resources.root_to_flatten_mapping.clear()
-    _mesh_resources.flatten_name_to_root_dims.clear()
-    _mesh_resources.mesh_dim_group_options.clear()
+    try:
+        from torch.distributed.device_mesh import _mesh_resources
+
+        _mesh_resources.child_to_root_mapping.clear()
+        _mesh_resources.root_to_flatten_mapping.clear()
+        _mesh_resources.mesh_stack.clear()
+        _mesh_resources.mesh_dim_group_options.clear()
+        _mesh_resources.flatten_name_to_root_dims.clear()
+    except Exception as e:
+        # Global _MeshEnv is on a convoluted deprecation path.
+        # Attempt to clean the global state, otherwise skip.
+        logger.warning(f"Did not clean the deprecated DeviceMesh global state. Skipping...\n{e}")
+        pass


 class ToyCNN(torch.nn.Module):

@@ -127,9 +138,9 @@ def forward(self, x):
         return x


-def build_toy_model_and_optimizer(model_type: str, init_model_with_meta_device: bool, seed=None):
+def build_toy_model(model_type: str, init_model_with_meta_device: bool, seed=None):
     """
-    Helper function to build a toy model and optimizer for testing Megatron-FSDP.
+    Helper function to build a toy model for testing Megatron-FSDP.
     """
     # Set the seed to make sure the same model is initialized on all ranks.
     if seed is not None:

@@ -158,10 +169,9 @@
            model_dim=DIM_SIZE, num_heads=2, num_layers=NUM_LAYERS, output_dim=DIM_SIZE
        )
        fsdp_unit_modules = [te.pytorch.TransformerLayer]
-    toy_adam = Adam(params=toy_model.parameters(), lr=0.01)

     # Return the toy model, optimizer, and FSDP unit modules.
-    return toy_model, toy_adam, fsdp_unit_modules
+    return toy_model, fsdp_unit_modules


 def build_distributed_environment(mesh_dim_config: tuple):

@@ -264,9 +274,8 @@ def test_fully_shard(
         device_mesh = build_distributed_environment(mesh_dim_config)

         # Construct toy model.
-        toy_model, toy_adam, fsdp_unit_modules = build_toy_model_and_optimizer(
-            model_type, init_model_with_meta_device
-        )
+        toy_model, fsdp_unit_modules = build_toy_model(model_type, init_model_with_meta_device)
+        toy_adam = Adam(params=toy_model.parameters(), lr=0.01)

         # Wrap in fully_shard.
         model, optimizer = fully_shard(

@@ -315,7 +324,7 @@ def test_fully_shard(
             # Validate gradients exist in the Torch Module, i.e. non-None and non-zero.
             grads_exist = any(
                 isinstance(p.grad, torch.Tensor) and p.grad.to_local().count_nonzero().item() > 0
-                for p in model.module.parameters()
+                for p in model.parameters()
             )
             sharding_group = (
                 device_mesh[HSDP].get_group()
@@ -326,27 +335,19 @@ def test_fully_shard(
             # Because of uneven sharding, we need to gather the result from all ranks
             # to verify if any gradients exist or not at this step of training.
             grads_exist_gathered = [None] * sharding_group.size()
-            torch.distributed.gather_object(
-                grads_exist,
-                object_gather_list=grads_exist_gathered if sharding_group.rank() == 0 else None,
-                group=sharding_group,
-                group_dst=0,
+            torch.distributed.all_gather_object(
+                object_list=grads_exist_gathered, obj=grads_exist, group=sharding_group
             )
-            if sharding_group.rank() == 0:
-                # Gradients exist on at least one of the optimizer sharding ranks.
-                # Update grads_exist on Rank 0 only.
-                grads_exist = any(grads_exist_gathered)
-            torch.distributed.barrier()
+            # Gradients exist on at least one of the optimizer sharding ranks.
+            grads_exist = any(grads_exist_gathered)

             # Gradients do not exist until synchronization is activated.
-            # Use collected result on Rank 0 only.
-            if sharding_group.rank() == 0:
-                if step == NUM_STEPS - 1:
-                    assert grads_exist, "Root module gradients should exist on final microbatch."
-                else:
-                    assert (
-                        not grads_exist
-                    ), "Root module gradients should not exist prior to optimization step."
+            if step == NUM_STEPS - 1:
+                assert grads_exist, "Root module gradients should exist on final microbatch."
+            else:
+                assert (
+                    not grads_exist
+                ), "Root module gradients should not exist prior to optimization step."
             torch.distributed.barrier()

             # Optimizer step. Apply accumulated gradients to the model weights.
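
The hunk above replaces a rank-0-only gather_object with all_gather_object, so every rank receives the same gathered list and can evaluate the assertions locally, removing the rank-0 gating and the extra barrier. A standalone sketch of that collective, assuming an already-initialized process group (not code from the test):

import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
local_flag = True                             # e.g. "this rank observed nonzero gradients"
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_flag)  # every rank receives the full list of flags
any_rank_has_grads = any(gathered)            # identical result on all ranks
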
@@ -403,9 +404,8 @@ def test_dcp_checkpoint_save_and_load(
        accuracy tests are non-trivial, i.e. don't just use the initialized weights.
        """
        # Test model.
-       toy_model, toy_adam, fsdp_unit_modules = build_toy_model_and_optimizer(
-           model_type, False, seed=0
-       )
+       toy_model, fsdp_unit_modules = build_toy_model(model_type, False, seed=0)
+       toy_adam = Adam(params=toy_model.parameters(), lr=0.01)

        # Wrap in fully_shard.
        model, optimizer = fully_shard(

@@ -484,9 +484,8 @@ def test_dcp_checkpoint_save_and_load(
        """
        # Initialize a new model for checkpoint loading. Set a different seed to force a different model init,
        # to ensure the checkpoint loading is accurate and non-trivial.
-       toy_model, toy_adam, fsdp_unit_modules = build_toy_model_and_optimizer(
-           model_type, False, seed=1
-       )
+       toy_model, fsdp_unit_modules = build_toy_model(model_type, False, seed=1)
+       toy_adam = Adam(params=toy_model.parameters(), lr=0.01)

        # Wrap in fully_shard.
        model, optimizer = fully_shard(

@@ -598,3 +597,44 @@ def test_dcp_checkpoint_save_and_load(

        # Destroy device mesh.
        destroy_device_mesh(device_mesh)
+
+    @pytest.mark.parametrize("shard_strategy", [OPTIM_GRADS_PARAMS, OPTIM_GRADS, OPTIM, NO_SHARD])
+    def test_fully_shard_ez(self, shard_strategy):
+        """
+        Test fully_shard(device_mesh=None). Represents the easiest entrypoint to Megatron-FSDP.
+        """
+        from megatron.core.distributed.fsdp.src.megatron_fsdp.fully_shard import (
+            fully_shard_model,
+            fully_shard_optimizer,
+        )
+
+        # Construct toy model.
+        toy_model, fsdp_unit_modules = build_toy_model(TRANSFORMER, False)
+
+        # Fully-shard the model.
+        mfsdp_model = fully_shard_model(
+            module=toy_model, fsdp_unit_modules=fsdp_unit_modules, zero_dp_strategy=shard_strategy
+        )
+
+        # Initialize the distributed optimizer on the MegatronFSDP model.
+        toy_adam = Adam(params=mfsdp_model.parameters(), lr=0.01)
+        optimizer = fully_shard_optimizer(optimizer=toy_adam)
+
+        # Mock input and target.
+        toy_input = torch.randn(1, DIM_SIZE, DIM_SIZE).to("cuda")
+        toy_target = torch.randn(1, DIM_SIZE, DIM_SIZE).to("cuda")
+
+        for step in range(NUM_STEPS):
+
+            # Forward pass.
+            output = mfsdp_model(toy_input, toy_input)
+
+            # Loss.
+            loss = mse_loss(output, toy_target)
+
+            # Backward pass.
+            loss.backward()
+
+            # Optimizer step.
+            optimizer.step()
+            optimizer.zero_grad()
