@@ -1,5 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+import logging
 import shutil
-from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
 
@@ -12,6 +14,8 @@
 
 from tests.unit_tests.test_utilities import Utils
 
+logger = logging.getLogger(__name__)
+
 HSDP = "hsdp"
 DP = "dp"
 DP_SHARD = "dp_shard"
 
@@ -36,15 +40,22 @@
 
 
 def destroy_device_mesh(device_mesh):
-    from torch.distributed.device_mesh import _mesh_resources
 
     # Teardown device mesh.
     del device_mesh
-    _mesh_resources.mesh_stack.clear()
-    _mesh_resources.child_to_root_mapping.clear()
-    _mesh_resources.root_to_flatten_mapping.clear()
-    _mesh_resources.flatten_name_to_root_dims.clear()
-    _mesh_resources.mesh_dim_group_options.clear()
+    try:
+        from torch.distributed.device_mesh import _mesh_resources
+
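+        # NOTE: _mesh_resources is a private PyTorch registry of global DeviceMesh state;
+        # clearing these caches keeps a mesh created in one test from leaking into the next.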
+        _mesh_resources.child_to_root_mapping.clear()
+        _mesh_resources.root_to_flatten_mapping.clear()
+        _mesh_resources.mesh_stack.clear()
+        _mesh_resources.mesh_dim_group_options.clear()
+        _mesh_resources.flatten_name_to_root_dims.clear()
+    except Exception as e:
+        # Global _MeshEnv is on a convoluted deprecation path.
+        # Attempt to clean the global state, otherwise skip.
+        logger.warning(f"Did not clean the deprecated DeviceMesh global state. Skipping...\n{e}")
 
 
 class ToyCNN(torch.nn.Module):
@@ -127,9 +138,9 @@ def forward(self, x):
         return x
 
 
-def build_toy_model_and_optimizer(model_type: str, init_model_with_meta_device: bool, seed=None):
+def build_toy_model(model_type: str, init_model_with_meta_device: bool, seed=None):
     """
-    Helper function to build a toy model and optimizer for testing Megatron-FSDP.
+    Helper function to build a toy model for testing Megatron-FSDP.
     """
     # Set the seed to make sure the same model is initialized on all ranks.
     if seed is not None:
@@ -158,10 +169,9 @@ def build_toy_model_and_optimizer(model_type: str, init_model_with_meta_device:
             model_dim=DIM_SIZE, num_heads=2, num_layers=NUM_LAYERS, output_dim=DIM_SIZE
         )
         fsdp_unit_modules = [te.pytorch.TransformerLayer]
-        toy_adam = Adam(params=toy_model.parameters(), lr=0.01)
 
-    # Return the toy model, optimizer, and FSDP unit modules.
+    # Return the toy model and FSDP unit modules.
-    return toy_model, toy_adam, fsdp_unit_modules
+    return toy_model, fsdp_unit_modules
 
 
 def build_distributed_environment(mesh_dim_config):
@@ -264,9 +274,8 @@ def test_fully_shard(
         device_mesh = build_distributed_environment(mesh_dim_config)
 
         # Construct toy model.
-        toy_model, toy_adam, fsdp_unit_modules = build_toy_model_and_optimizer(
-            model_type, init_model_with_meta_device
-        )
+        toy_model, fsdp_unit_modules = build_toy_model(model_type, init_model_with_meta_device)
+        toy_adam = Adam(params=toy_model.parameters(), lr=0.01)
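+        # The optimizer is built on the unsharded parameters here; fully_shard() below is
+        # expected to re-point its param groups at the sharded Megatron-FSDP parameters.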
 
         # Wrap in fully_shard.
         model, optimizer = fully_shard(
@@ -315,7 +324,7 @@ def test_fully_shard(
             # Validate gradients exist in the Torch Module, i.e. non-None and non-zero.
             grads_exist = any(
                 isinstance(p.grad, torch.Tensor) and p.grad.to_local().count_nonzero().item() > 0
-                for p in model.module.parameters()
+                for p in model.parameters()
             )
             sharding_group = (
                 device_mesh[HSDP].get_group()
@@ -326,27 +335,19 @@ def test_fully_shard(
             # Because of uneven sharding, we need to gather the result from all ranks
             # to verify if any gradients exist or not at this step of training.
             grads_exist_gathered = [None] * sharding_group.size()
-            torch.distributed.gather_object(
-                grads_exist,
-                object_gather_list=grads_exist_gathered if sharding_group.rank() == 0 else None,
-                group=sharding_group,
-                group_dst=0,
+            torch.distributed.all_gather_object(
+                object_list=grads_exist_gathered, obj=grads_exist, group=sharding_group
             )
-            if sharding_group.rank() == 0:
-                # Gradients exist on at least one of the optimizer sharding ranks.
-                # Update grads_exist on Rank 0 only.
-                grads_exist = any(grads_exist_gathered)
-            torch.distributed.barrier()
+            # Gradients exist on at least one of the optimizer sharding ranks.
+            grads_exist = any(grads_exist_gathered)
 
             # Gradients do not exist until synchronization is activated.
-            # Use collected result on Rank 0 only.
-            if sharding_group.rank() == 0:
-                if step == NUM_STEPS - 1:
-                    assert grads_exist, "Root module gradients should exist on final microbatch."
-                else:
-                    assert (
-                        not grads_exist
-                    ), "Root module gradients should not exist prior to optimization step."
+            if step == NUM_STEPS - 1:
+                assert grads_exist, "Root module gradients should exist on final microbatch."
+            else:
+                assert (
+                    not grads_exist
+                ), "Root module gradients should not exist prior to optimization step."
             torch.distributed.barrier()
 
             # Optimizer step. Apply accumulated gradients to the model weights.
@@ -403,9 +404,8 @@ def test_dcp_checkpoint_save_and_load(
         accuracy tests are non-trivial, i.e. don't just use the initialized weights.
         """
         # Test model.
-        toy_model, toy_adam, fsdp_unit_modules = build_toy_model_and_optimizer(
-            model_type, False, seed=0
-        )
+        toy_model, fsdp_unit_modules = build_toy_model(model_type, False, seed=0)
+        toy_adam = Adam(params=toy_model.parameters(), lr=0.01)
 
         # Wrap in fully_shard.
         model, optimizer = fully_shard(
@@ -484,9 +484,8 @@ def test_dcp_checkpoint_save_and_load(
484484 """
485485 # Initialize a new model for checkpoint loading. Set a different seed to force a different model init,
486486 # to ensure the checkpoint loading is accurate and non-trivial.
487- toy_model , toy_adam , fsdp_unit_modules = build_toy_model_and_optimizer (
488- model_type , False , seed = 1
489- )
487+ toy_model , fsdp_unit_modules = build_toy_model (model_type , False , seed = 1 )
488+ toy_adam = Adam (params = toy_model .parameters (), lr = 0.01 )
490489
491490 # Wrap in fully_shard.
492491 model , optimizer = fully_shard (
@@ -598,3 +597,44 @@ def test_dcp_checkpoint_save_and_load(
 
         # Destroy device mesh.
         destroy_device_mesh(device_mesh)
+
+    @pytest.mark.parametrize("shard_strategy", [OPTIM_GRADS_PARAMS, OPTIM_GRADS, OPTIM, NO_SHARD])
+    def test_fully_shard_ez(self, shard_strategy):
+        """
+        Test Megatron-FSDP with device_mesh=None via the fully_shard_model and
+        fully_shard_optimizer entrypoints. Represents the easiest entrypoint to Megatron-FSDP.
+        """
+        from megatron.core.distributed.fsdp.src.megatron_fsdp.fully_shard import (
+            fully_shard_model,
+            fully_shard_optimizer,
+        )
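+        # fully_shard_model / fully_shard_optimizer are the per-component counterparts of
+        # fully_shard(), exercised here without an explicit DeviceMesh.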
+
+        # Construct toy model.
+        toy_model, fsdp_unit_modules = build_toy_model(TRANSFORMER, False)
+
+        # Fully-shard the model.
+        mfsdp_model = fully_shard_model(
+            module=toy_model, fsdp_unit_modules=fsdp_unit_modules, zero_dp_strategy=shard_strategy
+        )
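+        # With no device_mesh argument, Megatron-FSDP is expected to construct a default
+        # data-parallel mesh spanning all ranks.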
+
+        # Initialize the distributed optimizer on the MegatronFSDP model.
+        toy_adam = Adam(params=mfsdp_model.parameters(), lr=0.01)
+        optimizer = fully_shard_optimizer(optimizer=toy_adam)
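+        # fully_shard_optimizer wraps the Adam instance so that step() / zero_grad() are
+        # expected to act on Megatron-FSDP's sharded parameter and gradient state.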
+
+        # Mock input and target.
+        toy_input = torch.randn(1, DIM_SIZE, DIM_SIZE).to("cuda")
+        toy_target = torch.randn(1, DIM_SIZE, DIM_SIZE).to("cuda")
+
+        for step in range(NUM_STEPS):
+
+            # Forward pass.
+            output = mfsdp_model(toy_input, toy_input)
+
+            # Loss.
+            loss = mse_loss(output, toy_target)
+
+            # Backward pass.
+            loss.backward()
+
+            # Optimizer step.
+            optimizer.step()
+            optimizer.zero_grad()