Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 90 additions & 38 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import random
from typing import List, Optional
from typing import Dict, List, Optional

try:
import einops
Expand All @@ -26,6 +26,7 @@
import numpy as np
import torch
import torch.distributed as dist
from torch import nn

try:
from torch.distributed import DeviceMesh
Expand All @@ -38,7 +39,6 @@
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.distributed.data_parallel_base import _BaseDataParallel
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.extensions.transformer_engine import TELinear
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
Expand All @@ -64,6 +64,32 @@ class FullyShardedDataParallel(_BaseDataParallel):
Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
"""

# Module type registry (forked from Megatron-Bridge param_mapping utilities).
# Maps a tensor-parallelism mode ("column" / "row" / "replicated") to the set of
# module class names whose directly-owned parameters use that mode. Looked up by
# `type(module).__name__` in _detect_parallelism_type; fused modules such as
# TELayerNormColumnParallelLinear are special-cased there before this table.
_MODULE_TYPE_REGISTRY: Dict[str, set] = {
    "column": {
        "ColumnParallelLinear",
        "TEColumnParallelLinear",
        "TELayerNormColumnParallelLinear",
        "TEColumnParallelGroupedLinear",
        "VocabParallelEmbedding",
        "DotProductAttention",  # for attention sink only
        "TEDotProductAttention",  # for attention sink only
    },
    "row": {"RowParallelLinear", "TERowParallelLinear", "TERowParallelGroupedLinear"},
    "replicated": {
        # Normalization layers
        "TENorm",
        "FusedLayerNorm",
        "WrappedTorchNorm",
        "LayerNorm",
        "RMSNorm",
        "L2Norm",
        # Other non-parallel modules
        "IdentityOp",
        "TopKRouter",
    },
}

def __init__(
self,
config: TransformerConfig,
Expand Down Expand Up @@ -122,7 +148,7 @@ def __init__(
else:
self.fsdp_unit_modules = []

self._fix_tensor_parallel_attributes(module)
self._annotate_tensor_parallelism(module)

super().__init__(
config=config,
Expand Down Expand Up @@ -177,43 +203,69 @@ def load_state_dict(self, state_dict, strict=True):

self.module.load_state_dict(custom_state_dict, strict=strict)

def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Optional[str]:
    """
    Infer tensor-parallelism type for a parameter under a given module
    (forked from Megatron-Bridge).

    Args:
        param_name: Name of a parameter directly owned by ``module``.
        module: The module that directly owns the parameter.

    Returns:
        "column", "row", or "replicated" if a type can be inferred, else None.
    """
    module_type = type(module).__name__

    # Handle fused modules like TELayerNormColumnParallelLinear.
    # These modules have both column-parallel weights (weight, bias)
    # and replicated layer norm weights (layer_norm_weight, layer_norm_bias),
    # so the parameter name decides which mode applies.
    if module_type == "TELayerNormColumnParallelLinear":
        if param_name.endswith("layer_norm_weight") or param_name.endswith("layer_norm_bias"):
            return "replicated"
        # All other parameters (weight, bias) are column-parallel.
        return "column"

    # Check the class-name registry first.
    for parallelism, types in self._MODULE_TYPE_REGISTRY.items():
        if module_type in types:
            return parallelism

    # Fallback to inspecting Megatron module attributes: a module that
    # explicitly declares itself non-tensor-parallel is replicated.
    if hasattr(module, "tensor_model_parallel"):
        if not module.tensor_model_parallel:
            return "replicated"

    # Check partition dimension: 0 => column-parallel, 1 => row-parallel.
    # NOTE(review): in the original diff paste the indentation is lost; this
    # check may belong inside the `hasattr` branch above — confirm upstream.
    partition_dim = getattr(module, "partition_dim", None)
    if partition_dim == 0:
        return "column"
    elif partition_dim == 1:
        return "row"

    # Fallback for normalization layers not covered by the registry.
    if any(norm in module_type for norm in ["Norm", "Normalization"]):
        return "replicated"

    # Check parallel_mode for TELinear; a parallel_mode other than
    # "column"/"row" (including None) means the weight is replicated.
    if module_type == "TELinear":
        if module.parallel_mode == "column":
            return "column"
        elif module.parallel_mode == "row":
            return "row"
        else:
            return "replicated"

    return None

def _annotate_tensor_parallelism(self, root_module: nn.Module) -> None:
    """Tag parameters under ``root_module`` with inferred tensor-parallel metadata.

    Walks every submodule and, for each directly-owned parameter whose
    parallelism type can be inferred, sets a `_tensor_parallel_mode`
    attribute to one of: "column", "row", or "replicated". Parameters
    that cannot be classified are left untouched.
    """
    for child in root_module.modules():
        for param_name, param in child.named_parameters(recurse=False):
            mode = self._detect_parallelism_type(param_name, child)
            if mode is None:
                continue
            param._tensor_parallel_mode = mode

def _init_dist_index(self, pg_collection):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,9 @@ def _register_pre_backward_param_unshard_hook(module):
]

for param in grad_acc_param_list:
# Only register grad acc hook for parameters that require gradients.
if not param.requires_grad:
continue
self.grad_acc_hooks[f"grad_acc and reduce for {self.param_to_name[param]}"] = (
param.register_post_accumulate_grad_hook(
lambda p: _process_post_backward_gradients([p])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
FSDPDistributedIndex,
get_global_memory_buffer,
get_mcore_tensor_parallel_partition_dim,
is_mcore_tensor_model_parallel,
is_mcore_tensor_parallel_duplicated,
using_tensor_parallel,
log_single_rank,
)

Expand Down Expand Up @@ -2627,8 +2627,7 @@ def _reset_parameters(self, old_params, new_params):
del self.param_to_direct_module[old_param]

new_param.requires_grad_(old_param.requires_grad)

for tp_attr in ["_mcore_tp", "_tp_partition_dim", "_tp_duplicated"]:
for tp_attr in ["_tensor_parallel_mode"]:
if getattr(old_param, tp_attr, None) is not None:
setattr(new_param, tp_attr, getattr(old_param, tp_attr))

Expand Down Expand Up @@ -2781,9 +2780,7 @@ def set_param_attribute():
"partition_stride",
"is_embedding_or_output_parameter",
"is_embedding_parameter",
"_mcore_tp",
"_tp_duplicated",
"_tp_partition_dim",
"_tensor_parallel_mode",
]:
if hasattr(orig_param, attr_name):
setattr(param, attr_name, getattr(orig_param, attr_name))
Expand Down Expand Up @@ -4425,43 +4422,44 @@ def make_fsdp_dtensor(
orig_param = param

# Handle tensor model parallel specific logic
if is_mcore_tensor_model_parallel(param):
if not isinstance(param, DTensor) and using_tensor_parallel(
dist_index, is_expert_parallel=is_expert_param
):
# Ensure parameter is not already a DTensor
assert not isinstance(param, DTensor), (
"[Megatron-FSDP] Parameter is already a DTensor, yet tensor_model_parallel " "is True."
)

tp_mesh = dist_index.get_submesh(dist_index.tp_dim, is_expert_parallel=is_expert_param)
global_shape = list(param.shape)
if tp_mesh.mesh.numel() > 1:
if is_mcore_tensor_parallel_duplicated(param):
placements = [Replicate()]
if force_sync_tp_duplicated_param:
if local_tensor.numel() > 0:
torch.distributed.broadcast(
local_tensor, group=tp_mesh.get_group(), group_src=0
)
elif run_check:
# TODO: Implement consistency check for duplicated TP parameters
pass
else:
tp_dim = get_mcore_tensor_parallel_partition_dim(param)
assert tp_dim is not None, (
"[Megatron-FSDP] Parameter is not tensor model parallel, "
"yet tensor_model_parallel is True."
)
placements = [Shard(tp_dim)]
global_shape[tp_dim] *= tp_mesh.mesh.numel()

# Construct TP-sharded DTensor using Megatron-style placement
param = DTensor.from_local(
local_tensor=local_tensor,
device_mesh=tp_mesh,
placements=placements,
run_check=run_check,
shape=tuple(global_shape),
stride=torch.empty(global_shape).stride(),
if is_mcore_tensor_parallel_duplicated(param):
placements = [Replicate()]
if force_sync_tp_duplicated_param:
if local_tensor.numel() > 0:
torch.distributed.broadcast(
local_tensor, group=tp_mesh.get_group(), group_src=0
)
elif run_check:
# TODO: Implement consistency check for duplicated TP parameters
pass
else:
tp_dim = get_mcore_tensor_parallel_partition_dim(param)
assert tp_dim is not None, (
"[Megatron-FSDP] Parameter is not tensor model parallel, "
"yet tensor_model_parallel is True."
)
placements = [Shard(tp_dim)]
global_shape[tp_dim] *= tp_mesh.mesh.numel()

# Construct TP-sharded DTensor using Megatron-style placement
param = DTensor.from_local(
local_tensor=local_tensor,
device_mesh=tp_mesh,
placements=placements,
run_check=run_check,
shape=tuple(global_shape),
stride=torch.empty(global_shape).stride(),
)

# Get FSDP-configured mesh and placements from provided param
device_mesh, placements = _get_fsdp_tensor_spec(
Expand Down
Loading