Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __init__(
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)

self.num_moe_experts = getattr(config, "num_moe_experts", None)

self.ddp_config = ddp_config
log_single_rank(
logger,
Expand Down Expand Up @@ -279,7 +281,7 @@ def _init_dist_index(self, pg_collection):
expt_tp_group = single_rank_group

if enable_hsdp:
if expt_dp_group is not None:
if self.num_moe_experts is not None:
expt_mesh = _get_hsdp_tp_mesh(
outer_fsdp_group, expt_dp_group, expt_tp_group, ep_size=ep_group.size()
)
Expand Down Expand Up @@ -308,7 +310,7 @@ def _init_dist_index(self, pg_collection):
expt_device_mesh=expt_device_mesh,
)
else:
if ep_group is not None:
if self.num_moe_experts is not None:
expt_mesh = _get_dp_tp_mesh(expt_dp_group, expt_tp_group, ep_size=ep_group.size())
expt_device_mesh = DeviceMesh.from_group(
[expt_dp_group, expt_tp_group],
Expand Down
4 changes: 0 additions & 4 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,6 @@ def __init__(
self.hsdp_outer_dp_shard = hsdp_outer_dp_shard
self.expt_device_mesh = expt_device_mesh

# Handling the situation where M-Core MoE EP=1
if self.expt_device_mesh is None:
self.expt_device_mesh = device_mesh

# Hybrid FSDP Process Groups
# Retrieve the FSDP process group from the DeviceMesh.
self.fsdp_group = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,52 @@ def train_step(model, optimizer, inputs):
msg=f"Parameters for {name1} don't match",
)

def test_fsdp_expt_device_mesh(self):
    """Test that expt_device_mesh is None for dense models and not None for MoE models."""
    # Megatron FSDP depends on torch DeviceMesh APIs introduced in 2.4.
    if not is_torch_min_version("2.4.0"):
        pytest.skip("Megatron FSDP requires torch >= 2.4.0")

    # Shared DDP/FSDP configuration for both the dense and the MoE case.
    ddp_cfg = DistributedDataParallelConfig(
        data_parallel_sharding_strategy="optim_grads_params",
        overlap_grad_reduce=True,
        overlap_param_gather=True,
        bucket_size=10000,
        use_megatron_fsdp=True,
    )
    in_features, out_features = 13, 17

    def _wrap(transformer_cfg):
        # Build a fresh toy model on GPU and shard it with Megatron FSDP.
        module = TestModel(input_dim=in_features, output_dim=out_features).cuda()
        return FullyShardedDataParallel(
            config=transformer_cfg,
            ddp_config=ddp_cfg,
            module=module,
            fsdp_unit_modules=[torch.nn.Linear],
        )

    # Dense model: expt_device_mesh should not be built without MoE config
    fsdp_dense = _wrap(
        TransformerConfig(num_attention_heads=1, num_layers=1, context_parallel_size=1)
    )
    assert (
        fsdp_dense.megatron_fsdp_dist_index.expt_device_mesh is None
    ), "Dense model: expt_device_mesh should be None"
    fsdp_dense.stop_communication()

    # MoE model: expt_device_mesh should be built when num_moe_experts is set
    fsdp_moe = _wrap(
        TransformerConfig(
            num_attention_heads=1,
            num_layers=1,
            context_parallel_size=1,
            num_moe_experts=4,
        )
    )
    assert (
        fsdp_moe.megatron_fsdp_dist_index.expt_device_mesh is not None
    ), "MoE model: expt_device_mesh should not be None"
    fsdp_moe.stop_communication()

# Testing fsdp_double_buffer with and without nccl_ub
@pytest.mark.parametrize(
("dp_size", "nccl_ub", "fsdp_double_buffer", "fsdp_manual_registration"),
Expand Down
Loading