diff --git a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
index ff55cf038d4..d92eb79c15e 100644
--- a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
+++ b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
@@ -83,6 +83,8 @@ def __init__(
         if has_config_logger_enabled(config):
             log_config_to_disk(config, locals(), prefix=type(self).__name__)
 
+        self.num_moe_experts = getattr(config, "num_moe_experts", None)
+
         self.ddp_config = ddp_config
         log_single_rank(
             logger,
@@ -279,7 +281,7 @@ def _init_dist_index(self, pg_collection):
             expt_tp_group = single_rank_group
 
         if enable_hsdp:
-            if expt_dp_group is not None:
+            if self.num_moe_experts is not None:
                 expt_mesh = _get_hsdp_tp_mesh(
                     outer_fsdp_group, expt_dp_group, expt_tp_group, ep_size=ep_group.size()
                 )
@@ -308,7 +310,7 @@ def _init_dist_index(self, pg_collection):
                 expt_device_mesh=expt_device_mesh,
             )
         else:
-            if ep_group is not None:
+            if self.num_moe_experts is not None:
                 expt_mesh = _get_dp_tp_mesh(expt_dp_group, expt_tp_group, ep_size=ep_group.size())
                 expt_device_mesh = DeviceMesh.from_group(
                     [expt_dp_group, expt_tp_group],
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
index b961a449d3e..6dd5ab6d342 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
@@ -514,10 +514,6 @@ def __init__(
         self.hsdp_outer_dp_shard = hsdp_outer_dp_shard
         self.expt_device_mesh = expt_device_mesh
 
-        # Handling the situation where M-Core MoE EP=1
-        if self.expt_device_mesh is None:
-            self.expt_device_mesh = device_mesh
-
         # Hybrid FSDP Process Groups
         # Retrieve the FSDP process group from the DeviceMesh.
         self.fsdp_group = (
diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
index d4c664cda9c..d88abb20514 100644
--- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
+++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py
@@ -225,6 +225,52 @@ def train_step(model, optimizer, inputs):
                 msg=f"Parameters for {name1} don't match",
             )
 
+    def test_fsdp_expt_device_mesh(self):
+        """Test that expt_device_mesh is None for dense models and not None for MoE models."""
+        if not is_torch_min_version("2.4.0"):
+            pytest.skip("Megatron FSDP requires torch >= 2.4.0")
+
+        fsdp_config = DistributedDataParallelConfig(
+            data_parallel_sharding_strategy="optim_grads_params",
+            overlap_grad_reduce=True,
+            overlap_param_gather=True,
+            bucket_size=10000,
+            use_megatron_fsdp=True,
+        )
+        input_dim, output_dim = 13, 17
+
+        # Dense model: expt_device_mesh should not be built without MoE config
+        dense_config = TransformerConfig(
+            num_attention_heads=1, num_layers=1, context_parallel_size=1
+        )
+        dense_model = TestModel(input_dim=input_dim, output_dim=output_dim).cuda()
+        fsdp_dense = FullyShardedDataParallel(
+            config=dense_config,
+            ddp_config=fsdp_config,
+            module=dense_model,
+            fsdp_unit_modules=[torch.nn.Linear],
+        )
+        assert (
+            fsdp_dense.megatron_fsdp_dist_index.expt_device_mesh is None
+        ), "Dense model: expt_device_mesh should be None"
+        fsdp_dense.stop_communication()
+
+        # MoE model: expt_device_mesh should be built when num_moe_experts is set
+        moe_config = TransformerConfig(
+            num_attention_heads=1, num_layers=1, context_parallel_size=1, num_moe_experts=4
+        )
+        moe_model = TestModel(input_dim=input_dim, output_dim=output_dim).cuda()
+        fsdp_moe = FullyShardedDataParallel(
+            config=moe_config,
+            ddp_config=fsdp_config,
+            module=moe_model,
+            fsdp_unit_modules=[torch.nn.Linear],
+        )
+        assert (
+            fsdp_moe.megatron_fsdp_dist_index.expt_device_mesh is not None
+        ), "MoE model: expt_device_mesh should not be None"
+        fsdp_moe.stop_communication()
+
     # Testing fsdp_double_buffer with and without nccl_ub
     @pytest.mark.parametrize(
         ("dp_size", "nccl_ub", "fsdp_double_buffer", "fsdp_manual_registration"),