diff --git a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py index ff55cf038d4..82cb5fce506 100644 --- a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +++ b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py @@ -14,7 +14,7 @@ import logging import random -from typing import List, Optional +from typing import Dict, List, Optional try: import einops @@ -26,6 +26,7 @@ import numpy as np import torch import torch.distributed as dist +from torch import nn try: from torch.distributed import DeviceMesh @@ -38,7 +39,6 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.distributed.data_parallel_base import _BaseDataParallel from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig -from megatron.core.extensions.transformer_engine import TELinear from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer @@ -64,6 +64,32 @@ class FullyShardedDataParallel(_BaseDataParallel): Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model. """ + # Module type registry (forked from Megatron-Bridge param_mapping utilities). 
+ _MODULE_TYPE_REGISTRY: Dict[str, set] = { + "column": { + "ColumnParallelLinear", + "TEColumnParallelLinear", + "TELayerNormColumnParallelLinear", + "TEColumnParallelGroupedLinear", + "VocabParallelEmbedding", + "DotProductAttention", # for attention sink only + "TEDotProductAttention", # for attention sink only + }, + "row": {"RowParallelLinear", "TERowParallelLinear", "TERowParallelGroupedLinear"}, + "replicated": { + # Normalization layers + "TENorm", + "FusedLayerNorm", + "WrappedTorchNorm", + "LayerNorm", + "RMSNorm", + "L2Norm", + # Other non-parallel modules + "IdentityOp", + "TopKRouter", + }, + } + def __init__( self, config: TransformerConfig, @@ -122,7 +148,7 @@ def __init__( else: self.fsdp_unit_modules = [] - self._fix_tensor_parallel_attributes(module) + self._annotate_tensor_parallelism(module) super().__init__( config=config, @@ -177,43 +203,69 @@ def load_state_dict(self, state_dict, strict=True): self.module.load_state_dict(custom_state_dict, strict=strict) - def _fix_tensor_parallel_attributes(self, module): - is_expert_param = lambda n, p: ".experts." in n - is_router_param = lambda n, p: ".router.weight" in n + def _detect_parallelism_type(self, param_name: str, module: nn.Module) -> Optional[str]: + """ + Infer tensor-parallelism type for a parameter under a given module + (forked from Megatron-Bridge). - if parallel_state.get_tensor_model_parallel_group(): - tp_size = parallel_state.get_tensor_model_parallel_group().size() - else: - tp_size = 1 + Returns: + "column", "row", or "replicated" if a type can be inferred, else None. 
+ """ + module_type = type(module).__name__ + + # Handle fused modules like TELayerNormColumnParallelLinear + # These modules have both column-parallel weights (weight, bias) + # and replicated layer norm weights (layer_norm_weight, layer_norm_bias) + if module_type == "TELayerNormColumnParallelLinear": + # Check the actual parameter name to determine the correct parallelism type + if param_name.endswith("layer_norm_weight") or param_name.endswith("layer_norm_bias"): + return "replicated" + # All other parameters (weight, bias) are column-parallel + return "column" + + # Check registry first + for parallelism, types in self._MODULE_TYPE_REGISTRY.items(): + if module_type in types: + return parallelism + + # Fallback to inspecting module attributes + if hasattr(module, "tensor_model_parallel"): + if not module.tensor_model_parallel: + return "replicated" + + # Check partition dimension + partition_dim = getattr(module, "partition_dim", None) + if partition_dim == 0: + return "column" + elif partition_dim == 1: + return "row" + + # Fallback for normalization layers + if any(norm in module_type for norm in ["Norm", "Normalization"]): + return "replicated" + + # Check parallel_mode for TELinear + if module_type == "TELinear": + if module.parallel_mode == "column": + return "column" + elif module.parallel_mode == "row": + return "row" + else: + return "replicated" - if parallel_state.get_expert_tensor_parallel_group(): - expt_tp_size = parallel_state.get_expert_tensor_parallel_group().size() - else: - expt_tp_size = 1 - - param_to_direct_module = {} - for name, m in module.named_modules(): - for p in m.parameters(recurse=False): - param_to_direct_module[p] = (name, m) - - for name, param in module.named_parameters(): - if is_expert_param(name, param) and expt_tp_size > 1: - setattr(param, "_mcore_tp", True) - if "linear_fc1.weight" in name: - setattr(param, "_tp_partition_dim", 0) - elif "linear_fc2.weight" in name: - setattr(param, "_tp_partition_dim", 1) - - if not 
is_expert_param(name, param) and tp_size > 1: - m_name, direct_module = param_to_direct_module[param] - if isinstance(direct_module, (TELinear,)): - parallel_mode = getattr(direct_module, "parallel_mode", None) - if parallel_mode is None: - setattr(param, "_mcore_tp", True) - setattr(param, "_tp_duplicated", True) - elif is_router_param(name, param): - setattr(param, "_mcore_tp", True) - setattr(param, "_tp_duplicated", True) + return None + + def _annotate_tensor_parallelism(self, root_module: nn.Module) -> None: + """Annotate parameters under root_module with inferred tensor-parallel metadata. + + Each parameter that can be classified will get a `_tensor_parallel_mode` attribute + set to one of: "column", "row", or "replicated". + """ + for submodule in root_module.modules(): + for name, param in submodule.named_parameters(recurse=False): + detected_type = self._detect_parallelism_type(name, submodule) + if detected_type is not None: + setattr(param, "_tensor_parallel_mode", detected_type) def _init_dist_index(self, pg_collection): """ diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 6987729ba8f..9fb20a8a0ef 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -1020,6 +1020,9 @@ def _register_pre_backward_param_unshard_hook(module): ] for param in grad_acc_param_list: + # Only register grad acc hook for parameters that require gradients. 
+ if not param.requires_grad: + continue self.grad_acc_hooks[f"grad_acc and reduce for {self.param_to_name[param]}"] = ( param.register_post_accumulate_grad_hook( lambda p: _process_post_backward_gradients([p]) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index 684cd7a99eb..a818445ecaa 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -51,8 +51,8 @@ FSDPDistributedIndex, get_global_memory_buffer, get_mcore_tensor_parallel_partition_dim, - is_mcore_tensor_model_parallel, is_mcore_tensor_parallel_duplicated, + using_tensor_parallel, log_single_rank, ) @@ -2627,8 +2627,7 @@ def _reset_parameters(self, old_params, new_params): del self.param_to_direct_module[old_param] new_param.requires_grad_(old_param.requires_grad) - - for tp_attr in ["_mcore_tp", "_tp_partition_dim", "_tp_duplicated"]: + for tp_attr in ["_tensor_parallel_mode"]: if getattr(old_param, tp_attr, None) is not None: setattr(new_param, tp_attr, getattr(old_param, tp_attr)) @@ -2781,9 +2780,7 @@ def set_param_attribute(): "partition_stride", "is_embedding_or_output_parameter", "is_embedding_parameter", - "_mcore_tp", - "_tp_duplicated", - "_tp_partition_dim", + "_tensor_parallel_mode", ]: if hasattr(orig_param, attr_name): setattr(param, attr_name, getattr(orig_param, attr_name)) @@ -4425,7 +4422,9 @@ def make_fsdp_dtensor( orig_param = param # Handle tensor model parallel specific logic - if is_mcore_tensor_model_parallel(param): + if not isinstance(param, DTensor) and using_tensor_parallel( + dist_index, is_expert_parallel=is_expert_param + ): # Ensure parameter is not already a DTensor assert not isinstance(param, DTensor), ( "[Megatron-FSDP] Parameter is already a DTensor, yet tensor_model_parallel " "is True." 
@@ -4433,35 +4432,34 @@ def make_fsdp_dtensor( tp_mesh = dist_index.get_submesh(dist_index.tp_dim, is_expert_parallel=is_expert_param) global_shape = list(param.shape) - if tp_mesh.mesh.numel() > 1: - if is_mcore_tensor_parallel_duplicated(param): - placements = [Replicate()] - if force_sync_tp_duplicated_param: - if local_tensor.numel() > 0: - torch.distributed.broadcast( - local_tensor, group=tp_mesh.get_group(), group_src=0 - ) - elif run_check: - # TODO: Implement consistency check for duplicated TP parameters - pass - else: - tp_dim = get_mcore_tensor_parallel_partition_dim(param) - assert tp_dim is not None, ( - "[Megatron-FSDP] Parameter is not tensor model parallel, " - "yet tensor_model_parallel is True." - ) - placements = [Shard(tp_dim)] - global_shape[tp_dim] *= tp_mesh.mesh.numel() - - # Construct TP-sharded DTensor using Megatron-style placement - param = DTensor.from_local( - local_tensor=local_tensor, - device_mesh=tp_mesh, - placements=placements, - run_check=run_check, - shape=tuple(global_shape), - stride=torch.empty(global_shape).stride(), + if is_mcore_tensor_parallel_duplicated(param): + placements = [Replicate()] + if force_sync_tp_duplicated_param: + if local_tensor.numel() > 0: + torch.distributed.broadcast( + local_tensor, group=tp_mesh.get_group(), group_src=0 + ) + elif run_check: + # TODO: Implement consistency check for duplicated TP parameters + pass + else: + tp_dim = get_mcore_tensor_parallel_partition_dim(param) + assert tp_dim is not None, ( + "[Megatron-FSDP] Parameter is not tensor model parallel, " + "yet tensor_model_parallel is True." 
) + placements = [Shard(tp_dim)] + global_shape[tp_dim] *= tp_mesh.mesh.numel() + + # Construct TP-sharded DTensor using Megatron-style placement + param = DTensor.from_local( + local_tensor=local_tensor, + device_mesh=tp_mesh, + placements=placements, + run_check=run_check, + shape=tuple(global_shape), + stride=torch.empty(global_shape).stride(), + ) # Get FSDP-configured mesh and placements from provided param device_mesh, placements = _get_fsdp_tensor_spec( diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py index 5df9c2e95c0..8b98514ae5c 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, Union import torch import torch.distributed as dist @@ -25,8 +25,6 @@ from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard -from .utils import get_mesh_names - def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata: """ @@ -250,148 +248,144 @@ def preprocess_state_dict_for_uneven_dtensor(state_dict: dict) -> dict: return state_dict -def gather_uneven_dtensor_to_full_tensor( - dtensor: DTensor, target_device: Optional[torch.device] = None -) -> DTensor: +def uneven_dtensor_to_full_tensor(dtensor: DTensor) -> torch.Tensor: """ - Gather an unevenly sharded DTensor distributed across multiple ranks, - reconstructing the full (unsharded) tensor on each rank. + Gather a DTensor with potentially uneven sharding across ranks into a full tensor. 
- This function handles uneven chunk sizes and offsets by collecting - chunk metadata from all ranks, performing all-gather operations, - and assembling the full tensor accordingly. The returned tensor - is fully replicated across the given device mesh. + This function handles DTensors with uneven shards (where different ranks may have + different-sized chunks) by gathering chunk metadata and local tensors across all + ranks, then reconstructing the complete tensor. Args: - dtensor (DTensor): Distributed tensor with uneven sharding across ranks. - target_device (Optional[torch.device]): If specified, move the resulting - full tensor to this device. Otherwise, use the original device. + dtensor (DTensor): The distributed tensor to gather. Must have chunk metadata + available (either pre-existing or will be computed). Returns: - DTensor: Fully replicated DTensor representing the reconstructed full tensor. + torch.Tensor: The fully reconstructed tensor with shape matching the original + DTensor's global shape. + + Raises: + TypeError: If input is not a DTensor. + ValueError: If chunk metadata is malformed (expected exactly one chunk per rank). + ValueError: If an unexpected placement type is encountered after processing + Shard placements. + + Note: + - This function performs collective operations (all_gather_object, all_gather) + across the device mesh, requiring synchronization across ranks. + - Works with Shard and _StridedShard placements, and expects Replicate placements + for non-sharded dimensions. + - The function modifies the DTensor in-place by adding chunk metadata if missing. 
+ + Example: + >>> mesh = DeviceMesh("cuda", [0, 1, 2, 3]) + >>> # Create a DTensor with uneven sharding + >>> dtensor = DTensor(..., placements=[Shard(0)]) + >>> full_tensor = uneven_dtensor_to_full_tensor(dtensor) + >>> assert full_tensor.shape == dtensor.shape + """ + # Validate input type if not isinstance(dtensor, DTensor): - raise TypeError("Input must be a DTensor.") - - device_mesh = dtensor.device_mesh - if not device_mesh.mesh_dim_names: - process_group = device_mesh.get_group() - else: - # Check if the fully-flattened mesh exists first. - full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names) - if full_flattened_mesh_dim_name in get_mesh_names(device_mesh): - # Retrieve the existing flattened DeviceMesh ProcessGroup. - try: - # Two Cases: Name is a root dimension, or using the old DeviceMesh - # API which allows us to get flattened dimensions. - process_group = device_mesh[full_flattened_mesh_dim_name].get_group() - except: - # Name is a flattened dimension that cannot be retrieved from the - # DeviceMesh.__getitem__, so fall-back to new DeviceMesh API. - process_group = ( - device_mesh._get_root_mesh() - ._flatten_mapping[full_flattened_mesh_dim_name] - .get_group() - ) - else: - # Create the _-separated flattened DeviceMesh ProcessGroup. - process_group = device_mesh._flatten().get_group() + raise TypeError(f"Input must be a DTensor, got {type(dtensor).__name__}.") - # Collect chunk metadata for uneven shards (update if missing) + # Ensure chunk metadata is available for uneven shards if not hasattr(dtensor._local_tensor, "__create_chunk_list__"): update_uneven_dtensor_chunk_metadata(dtensor) + # Retrieve and validate chunk metadata chunk_metadata_list = dtensor.__create_chunk_list__() if len(chunk_metadata_list) != 1: - raise ValueError(f"Expected exactly one chunk metadata, got {len(chunk_metadata_list)}.") - + raise ValueError( + f"Expected exactly one chunk metadata per rank, got {len(chunk_metadata_list)}." 
+ ) local_chunk_metadata = chunk_metadata_list[0] - world_size = process_group.size() - - # Prepare local chunk info dictionary - local_chunk_info = { - "shape": list(dtensor.to_local().shape), - "offset": getattr(local_chunk_metadata, "offsets", [0] * len(dtensor.shape)), - "rank": process_group.rank(), - } - - # Gather chunk info from all ranks - all_chunk_info = [None] * world_size - dist.all_gather_object(all_chunk_info, local_chunk_info, group=process_group) - - # Delegate to helper function - return _assemble_full_tensor_from_uneven_chunks( - dtensor, all_chunk_info, process_group, target_device - ) + # Prepare local chunk information for gathering + local_chunks_info = [ + { + "shape": dtensor.to_local().shape, + "offset": getattr(local_chunk_metadata, "offsets", [0] * len(dtensor.shape)), + } + ] + local_buffer = dtensor.to_local().contiguous().view(-1) + + # Iterate through device mesh dimensions and gather across sharded dimensions + for mesh_dim, placement in enumerate(dtensor.placements): + if isinstance(placement, (Shard, _StridedShard)): + # Get the process group for this mesh dimension + shard_group = dtensor.device_mesh.get_group(mesh_dim) + + # Gather chunk metadata from all ranks in this dimension + group_chunks_info = [None] * shard_group.size() + dist.all_gather_object(group_chunks_info, local_chunks_info, group=shard_group) + + # Prepare buffers for gathering tensors from all ranks + group_tensors = [ + torch.empty( + sum(chunk["shape"].numel() for chunk in chunks_info), + dtype=dtensor.dtype, + device=dtensor.device, + ) + for chunks_info in group_chunks_info + ] -def _assemble_full_tensor_from_uneven_chunks( - dtensor: DTensor, - all_chunk_info: List[dict], - process_group: torch.distributed.ProcessGroup, - target_device: Optional[torch.device], -) -> DTensor: - """ - Assemble the full tensor from unevenly sized chunks gathered from all ranks. 
+ # Gather actual tensor data from all ranks + dist.all_gather(group_tensors, local_buffer, group=shard_group) - Args: - dtensor (DTensor): The original distributed tensor. - all_chunk_info (List[Dict]): List of shard info dicts from all ranks, - including shapes and offsets. - process_group: Process group for collective communication. - target_device: Optional device to move the final full tensor onto. + # Flatten the gathered metadata and concatenate tensors + local_chunks_info = [item for sublist in group_chunks_info for item in sublist] + local_buffer = torch.cat(group_tensors) + elif not isinstance(placement, Replicate): + raise ValueError( + f"Unexpected placement {placement} at mesh dimension {mesh_dim}. " + f"Expected Shard, _StridedShard, or Replicate." + ) - Returns: - DTensor: Fully replicated tensor constructed by placing chunks at - the appropriate offsets. + # Split the gathered buffer back into individual chunks + all_local_chunks = [] + buffer_offset = 0 + for chunk_info in local_chunks_info: + chunk_shape = chunk_info["shape"] + chunk_numel = chunk_shape.numel() + chunk_tensor = local_buffer[buffer_offset : buffer_offset + chunk_numel].view(chunk_shape) + all_local_chunks.append(chunk_tensor) + buffer_offset += chunk_numel + + debug_slices = [] + for chunk_info, local_chunk in zip(local_chunks_info, all_local_chunks): + offset = chunk_info["offset"] + slices = tuple(slice(o, o + s) for o, s in zip(offset, local_chunk.shape)) + debug_slices.append(slices) + + # Reconstruct the full tensor by placing chunks at their correct offsets + full_tensor = torch.zeros(dtensor.shape, dtype=dtensor.dtype, device=dtensor.device) + for chunk_info, local_chunk in zip(local_chunks_info, all_local_chunks): + offset = chunk_info["offset"] + slices = tuple(slice(o, o + s) for o, s in zip(offset, local_chunk.shape)) + full_tensor[slices] = local_chunk + + return full_tensor + + +def redistribute_uneven_dtensor_to_replicated(dtensor: DTensor) -> DTensor: """ - 
local_tensor = dtensor.to_local() - - # Check if the DTensor has any shard placements - have_shard_placement = any( - isinstance(placement, Shard) or isinstance(placement, _StridedShard) - for placement in dtensor.placements - ) - - if not have_shard_placement: - # No sharding (replicated tensor), just clone and move if needed - full_tensor = local_tensor.clone() - if target_device: - full_tensor = full_tensor.to(target_device) - else: - # Prepare empty buffers to receive tensors from each rank - gathered_tensors = [ - torch.empty(rank_info["shape"], dtype=local_tensor.dtype, device=local_tensor.device) - for rank_info in all_chunk_info - ] + Redistribute an unevenly sharded DTensor to a fully replicated DTensor. - # Gather local tensors from all ranks - dist.all_gather(gathered_tensors, local_tensor, group=process_group) + This function first gathers the unevenly sharded DTensor into a full tensor + and then redistributes it as a replicated DTensor across all ranks. - # Allocate full tensor buffer - full_tensor = torch.empty( - dtensor.shape, dtype=local_tensor.dtype, device=local_tensor.device - ) - - # Copy each gathered shard into the full tensor at its offset - for rank_info, local_shard in zip(all_chunk_info, gathered_tensors): - offset = rank_info["offset"] - slices = tuple(slice(o, o + s) for o, s in zip(offset, local_shard.shape)) - full_tensor[slices] = local_shard - - # Optionally move to target device - if target_device is not None: - full_tensor = full_tensor.to(target_device) - - # Free memory of gathered shards as they are copied - del gathered_tensors - - # Wrap into a replicated DTensor and return - return DTensor.from_local( + Args: + dtensor (DTensor): The unevenly sharded DTensor to redistribute. + Returns: + DTensor: A replicated DTensor with the same data as the input DTensor. 
+ """ + full_tensor = uneven_dtensor_to_full_tensor(dtensor) + replicated_dtensor = DTensor.from_local( full_tensor, placements=[Replicate()] * len(dtensor.placements), device_mesh=dtensor.device_mesh, ) + return replicated_dtensor def _intersection(s1, s2): diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py index b961a449d3e..8911712642e 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py @@ -814,23 +814,31 @@ def is_mcore_tensor_model_parallel(param: torch.Tensor) -> bool: """ Check if the given parameter is Megatron-Core tensor model parallel. """ - return getattr(param, "_mcore_tp", False) or getattr(param, "tensor_model_parallel", False) + return get_mcore_tensor_parallel_partition_dim(param) is not None def is_mcore_tensor_parallel_duplicated(param: torch.Tensor) -> bool: """ Check if the given parameter is Megatron-Core tensor model parallel and duplicated. """ - return getattr(param, "_tp_duplicated", False) + return get_mcore_tensor_parallel_partition_dim(param) is None def get_mcore_tensor_parallel_partition_dim(param: torch.Tensor) -> Optional[int]: """ Get the partition dimension for a Megatron-Core tensor model parallel parameter. """ - if is_mcore_tensor_model_parallel(param): - if hasattr(param, "_tp_partition_dim"): - return param._tp_partition_dim - else: - return param.partition_dim + if hasattr(param, "_tensor_parallel_mode"): + if param._tensor_parallel_mode == "column": + return 0 + elif param._tensor_parallel_mode == "row": + return 1 return None + + +def using_tensor_parallel(dist_index, is_expert_parallel: bool = False) -> bool: + """ + Check if tensor parallelism is being used based on the distributed index. 
+ """ + tp_mesh = dist_index.get_submesh(dist_index.tp_dim, is_expert_parallel=is_expert_parallel) + return tp_mesh.mesh.numel() > 1 diff --git a/megatron/core/transformer/fsdp_dtensor_checkpoint.py b/megatron/core/transformer/fsdp_dtensor_checkpoint.py index e408209c778..a9d547ed5a1 100644 --- a/megatron/core/transformer/fsdp_dtensor_checkpoint.py +++ b/megatron/core/transformer/fsdp_dtensor_checkpoint.py @@ -31,7 +31,7 @@ make_fsdp_dtensor, ) from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import ( - gather_uneven_dtensor_to_full_tensor, + uneven_dtensor_to_full_tensor, ) from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import ( get_mcore_tensor_parallel_partition_dim, @@ -427,13 +427,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path): load_item_dict, storage_reader=reader, planner=default_planner.DefaultLoadPlanner() ) if isinstance(value, DTensor): - full_value = gather_uneven_dtensor_to_full_tensor(value) + full_tensor_v = uneven_dtensor_to_full_tensor(value) loaded_tensor = load_item_dict[key].redistribute( placements=[Replicate()] * len(value.placements) ) assert torch.allclose( - loaded_tensor._local_tensor, full_value._local_tensor, atol=1e-8, rtol=1e-5 - ), f"key: {key}; {loaded_tensor} {full_value}" + loaded_tensor._local_tensor, full_tensor_v, atol=1e-8, rtol=1e-5 + ), f"key: {key}; {loaded_tensor} {full_tensor_v}" else: assert torch.allclose( value, load_item_dict[key] diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py new file mode 100644 index 00000000000..97e7c9a740e --- /dev/null +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py @@ -0,0 +1,258 @@ +import types + +import torch +from torch import nn + +from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + +from 
megatron.core.distributed.fsdp.src.megatron_fsdp.utils import ( + using_tensor_parallel, + is_mcore_tensor_parallel_duplicated, + get_mcore_tensor_parallel_partition_dim, +) + + +class DummyMesh: + def __init__(self, numel: int): + # mimic DeviceMesh.mesh + self.mesh = torch.arange(numel) + + +class DummyDistIndex: + def __init__(self, tp_dim: str = "tp", numel: int = 1): + self.tp_dim = tp_dim + self._is_expert = {} + self._meshes = {(tp_dim, False): DummyMesh(numel)} + + def get_submesh(self, dim_name: str, is_expert_parallel: bool = False): + return self._meshes[(dim_name, is_expert_parallel)] + + +def test_get_mcore_tensor_parallel_partition_dim_column_row_and_none(): + # Column-parallel param -> partition_dim 0 + p_col = torch.nn.Parameter(torch.empty(4, 4)) + p_col._tensor_parallel_mode = "column" + assert get_mcore_tensor_parallel_partition_dim(p_col) == 0 + + # Row-parallel param -> partition_dim 1 + p_row = torch.nn.Parameter(torch.empty(4, 4)) + p_row._tensor_parallel_mode = "row" + assert get_mcore_tensor_parallel_partition_dim(p_row) == 1 + + # Replicated or unknown mode -> None + p_rep = torch.nn.Parameter(torch.empty(4, 4)) + p_rep._tensor_parallel_mode = "replicated" + assert get_mcore_tensor_parallel_partition_dim(p_rep) is None + + # No attribute at all -> None + p_plain = torch.nn.Parameter(torch.empty(4, 4)) + assert get_mcore_tensor_parallel_partition_dim(p_plain) is None + + +def test_is_mcore_tensor_parallel_duplicated_behaviour(): + # Column / row -> not duplicated (partition_dim not None) + p_col = torch.nn.Parameter(torch.empty(4, 4)) + p_col._tensor_parallel_mode = "column" + assert is_mcore_tensor_parallel_duplicated(p_col) is False + + p_row = torch.nn.Parameter(torch.empty(4, 4)) + p_row._tensor_parallel_mode = "row" + assert is_mcore_tensor_parallel_duplicated(p_row) is False + + # Replicated or no mode -> duplicated == True + p_rep = torch.nn.Parameter(torch.empty(4, 4)) + p_rep._tensor_parallel_mode = "replicated" + assert 
is_mcore_tensor_parallel_duplicated(p_rep) is True + + p_plain = torch.nn.Parameter(torch.empty(4, 4)) + assert is_mcore_tensor_parallel_duplicated(p_plain) is True + + +def test_using_tensor_parallel_true_when_mesh_size_gt_one(): + # Mesh with >1 element -> using tensor parallel + dist_index = DummyDistIndex(numel=4) + assert using_tensor_parallel(dist_index) is True + + +def test_using_tensor_parallel_false_when_mesh_size_one(): + # Mesh with 1 element -> no tensor parallel + dist_index = DummyDistIndex(numel=1) + assert using_tensor_parallel(dist_index) is False + + +class DummyConfig: + # Just enough attributes for __init__ to run if needed in future tests. + gradient_accumulation_fusion = False + calculate_per_token_loss = False + init_model_with_meta_device = False + fp8_recipe = None + fp8 = False + gated_linear_unit = False + + +class DummyDDPConfig: + # Minimal stub to avoid touching real Megatron FSDP in these tests. + bucket_size = 1 + grad_reduce_in_fp32 = False + data_parallel_sharding_strategy = "no_shard" + num_distributed_optimizer_instances = 1 + outer_dp_sharding_strategy = "no_shard" + fp8_param_gather = False + + +def _make_fsdp_for_unit_tests(): + """Construct a FullyShardedDataParallel with minimal stubs. + + We bypass its heavy __init__ by creating an instance via __new__ + and only setting the attributes that _detect_parallelism_type + and _annotate_tensor_parallelism actually use. + """ + fsdp = FullyShardedDataParallel.__new__(FullyShardedDataParallel) + + # Copy the registry from the real class. 
+ fsdp._MODULE_TYPE_REGISTRY = FullyShardedDataParallel._MODULE_TYPE_REGISTRY + + return fsdp + + +def test_detect_parallelism_telayernormcolumnparallellinear_layernorm_params(): + fsdp = _make_fsdp_for_unit_tests() + + class TELayerNormColumnParallelLinear(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(8, 8)) + self.bias = nn.Parameter(torch.empty(8)) + self.layer_norm_weight = nn.Parameter(torch.empty(8)) + self.layer_norm_bias = nn.Parameter(torch.empty(8)) + + module = TELayerNormColumnParallelLinear() + + # layer norm parameters should be replicated + assert ( + fsdp._detect_parallelism_type("layer_norm_weight", module) == "replicated" + ) + assert ( + fsdp._detect_parallelism_type("layer_norm_bias", module) == "replicated" + ) + + # non-layer-norm parameters should be column + assert fsdp._detect_parallelism_type("weight", module) == "column" + assert fsdp._detect_parallelism_type("bias", module) == "column" + + +def test_detect_parallelism_registry_column_row_replicated(): + fsdp = _make_fsdp_for_unit_tests() + + # Fabricate simple module classes whose __name__ matches the registry + class ColumnParallelLinear(nn.Module): + pass + + class RowParallelLinear(nn.Module): + pass + + class LayerNorm(nn.Module): + pass + + assert fsdp._detect_parallelism_type("weight", ColumnParallelLinear()) == "column" + assert fsdp._detect_parallelism_type("weight", RowParallelLinear()) == "row" + assert fsdp._detect_parallelism_type("weight", LayerNorm()) == "replicated" + + +def test_detect_parallelism_tensor_model_parallel_flag_and_partition_dim(): + fsdp = _make_fsdp_for_unit_tests() + + class DummyModule(nn.Module): + def __init__(self, tensor_model_parallel, partition_dim=None): + super().__init__() + self.tensor_model_parallel = tensor_model_parallel + if partition_dim is not None: + self.partition_dim = partition_dim + + # tensor_model_parallel = False -> replicated + m_rep = DummyModule(tensor_model_parallel=False) + 
assert fsdp._detect_parallelism_type("weight", m_rep) == "replicated" + + # tensor_model_parallel = True and partition_dim = 0 -> column + m_col = DummyModule(tensor_model_parallel=True, partition_dim=0) + assert fsdp._detect_parallelism_type("weight", m_col) == "column" + + # tensor_model_parallel = True and partition_dim = 1 -> row + m_row = DummyModule(tensor_model_parallel=True, partition_dim=1) + assert fsdp._detect_parallelism_type("weight", m_row) == "row" + + +def test_detect_parallelism_norm_fallback(): + fsdp = _make_fsdp_for_unit_tests() + + class MyNormalization(nn.Module): + pass + + class MyNorm(nn.Module): + pass + + assert fsdp._detect_parallelism_type("weight", MyNormalization()) == "replicated" + assert fsdp._detect_parallelism_type("weight", MyNorm()) == "replicated" + + +def test_detect_parallelism_teliner_parallel_mode_variants(): + fsdp = _make_fsdp_for_unit_tests() + + class TELinear(nn.Module): + def __init__(self, mode): + super().__init__() + self.parallel_mode = mode + + assert fsdp._detect_parallelism_type("weight", TELinear("column")) == "column" + assert fsdp._detect_parallelism_type("weight", TELinear("row")) == "row" + assert fsdp._detect_parallelism_type("weight", TELinear("none")) == "replicated" + + +def test_detect_parallelism_returns_none_when_cannot_infer(): + fsdp = _make_fsdp_for_unit_tests() + + class PlainModule(nn.Module): + pass + + assert fsdp._detect_parallelism_type("weight", PlainModule()) is None + + +def test_annotate_tensor_parallelism_sets_attribute_on_params(): + fsdp = _make_fsdp_for_unit_tests() + + class ColumnParallelLinear(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(4, 4)) + self.bias = nn.Parameter(torch.empty(4)) + + class RowParallelLinear(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.empty(4, 4)) + + class PlainModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = 
nn.Parameter(torch.empty(4, 4)) + + class Root(nn.Module): + def __init__(self): + super().__init__() + self.col = ColumnParallelLinear() + self.row = RowParallelLinear() + self.plain = PlainModule() + + root = Root() + + # Exercise _annotate_tensor_parallelism + fsdp._annotate_tensor_parallelism(root) + + # Check that known module types got annotated + assert root.col.weight._tensor_parallel_mode == "column" + assert root.col.bias._tensor_parallel_mode == "column" + assert root.row.weight._tensor_parallel_mode == "row" + + # For unknown module type, _detect_parallelism_type should return None + # and _annotate_tensor_parallelism must not set the attribute. + assert not hasattr(root.plain.weight, "_tensor_parallel_mode") diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py index a09f8f703c5..86ffcff41f1 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_fully_shard.py @@ -14,6 +14,10 @@ from torch.nn.functional import mse_loss from torch.optim import Adam +from megatron.core.distributed.fsdp.src.megatron_fsdp.fully_shard import ( + fully_shard_model, + fully_shard_optimizer, +) from tests.unit_tests.test_utilities import Utils logger = logging.getLogger(__name__) @@ -808,6 +812,60 @@ def test_fully_shard_te_quantized(self, init_model_with_meta_device, te_recipe): optimizer.step() optimizer.zero_grad() + @pytest.mark.parametrize("init_model_with_meta_device", [True, False]) + def test_model_with_frozen_param(self, init_model_with_meta_device): + """ + Test Megatron-FSDP with frozen parameters. + """ + # Build a toy TRANSFORMER model and identify FSDP unit modules. + toy_model, fsdp_unit_modules = build_toy_model( + model_type=TRANSFORMER, init_model_with_meta_device=init_model_with_meta_device + ) + + # Freeze a subset of parameters in the original model. 
+ original_params = list(toy_model.parameters()) + num_frozen = len(original_params) // 2 + for param in original_params[:num_frozen]: + param.requires_grad = False + + # Fully shard the model with Megatron-FSDP. + mfsdp_model = fully_shard_model( + module=toy_model, + fsdp_unit_modules=fsdp_unit_modules, + zero_dp_strategy=OPTIM_GRADS, + init_model_with_meta_device=init_model_with_meta_device, + ) + + # Validate that the corresponding parameters remain frozen. + sharded_params = list(mfsdp_model.parameters()) + assert len(sharded_params) == len( + original_params + ), "Megatron-FSDP changed parameter count unexpectedly." + for idx, param in enumerate(sharded_params[:num_frozen]): + assert not param.requires_grad, f"Parameter {idx} is not frozen in Megatron-FSDP model." + + # Initialize the distributed optimizer on the Megatron-FSDP model. + toy_adam = Adam(params=mfsdp_model.parameters(), lr=0.01) + optimizer = fully_shard_optimizer(optimizer=toy_adam) + + # Mock input and target. + toy_input = torch.randn(1, DIM_SIZE, DIM_SIZE).to("cuda") + toy_target = torch.randn(1, DIM_SIZE, DIM_SIZE).to("cuda") + + for _ in range(NUM_STEPS): + # Forward pass. + output = mfsdp_model(toy_input, toy_input) + + # Loss. + loss = mse_loss(output, toy_target) + + # Backward pass. + loss.backward() + + # Optimizer step. + optimizer.step() + optimizer.zero_grad() + @pytest.mark.skipif( version.parse(torch.__version__) < version.parse('2.4.0'), reason="Requires DTensor and DeviceMesh support in (approximately) PyTorch 2.4.0 or later.", diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_uneven_dtensor.py b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_uneven_dtensor.py new file mode 100644 index 00000000000..4df464215b7 --- /dev/null +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mfsdp_uneven_dtensor.py @@ -0,0 +1,814 @@ +""" +Unit tests for Megatron-FSDP uneven_dtensor functions. 
+ +Run with torchrun: + torchrun --nproc_per_node=4 pytest test_mfsdp_uneven_dtensor.py -v + torchrun --nproc_per_node=8 pytest test_mfsdp_uneven_dtensor.py -v +""" + +import os +import pytest +import torch +import torch.distributed as dist +from torch.distributed._tensor import ( + DeviceMesh, + DTensor, + Shard, + Replicate, + distribute_tensor, +) +from torch.distributed.tensor.placement_types import _StridedShard + +from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import ( + uneven_dtensor_to_full_tensor, + split_dtensor, +) + + +# Pytest fixtures for distributed setup +@pytest.fixture(scope="module") +def distributed_setup(): + """Setup distributed environment for pytest with proper CUDA device assignment.""" + # Check if running under torchrun + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + pytest.skip("Not running in distributed mode. Use torchrun to run this test.") + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + + # Determine device type and set CUDA device + if torch.cuda.is_available(): + device_type = "cuda" + # Set CUDA device for this process + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + backend = "nccl" + else: + device_type = "cpu" + device = torch.device("cpu") + backend = "gloo" + + # Initialize process group if not already initialized + if not dist.is_initialized(): + dist.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + + yield { + "rank": rank, + "world_size": world_size, + "local_rank": local_rank, + "device_type": device_type, + "device": device, + } + + # Cleanup + if dist.is_initialized(): + dist.destroy_process_group() + + +# ---------- Helper: distributed setup ---------- + +@pytest.fixture(scope="module") +def distributed_setup(): + """Setup torch.distributed and CUDA device for torchrun + pytest.""" + if "RANK" not in os.environ or 
"WORLD_SIZE" not in os.environ: + pytest.skip("Not running under torchrun. Use torchrun to run this test file.") + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + + if torch.cuda.is_available(): + device_type = "cuda" + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + backend = "nccl" + else: + device_type = "cpu" + device = torch.device("cpu") + backend = "gloo" + + if not dist.is_initialized(): + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + + yield { + "rank": rank, + "world_size": world_size, + "local_rank": local_rank, + "device_type": device_type, + "device": device, + } + + if dist.is_initialized(): + dist.destroy_process_group() + + +# ---------- Helper: broadcast-based global tensor creation ---------- + +def make_global_randn(shape, dtype=torch.float32, device=torch.device("cpu")): + """ + Create the same random tensor on all ranks by generating on rank 0 + and broadcasting to everyone. 
+ """ + rank = dist.get_rank() + world_size = dist.get_world_size() + + # Make sure shape is known on all ranks (it is, since passed as arg) + if rank == 0: + tensor = torch.randn(*shape, dtype=dtype, device=device) + else: + # allocate empty tensor, then broadcast into it + tensor = torch.empty(*shape, dtype=dtype, device=device) + + dist.broadcast(tensor, src=0) + return tensor + + +def make_global_arange(shape, dtype=torch.float32, device=torch.device("cpu")): + """Same idea as make_global_randn, but deterministic arange.""" + rank = dist.get_rank() + if rank == 0: + tensor = torch.arange(torch.prod(torch.tensor(shape)).item(), + dtype=dtype, + device=device).reshape(*shape) + else: + tensor = torch.empty(*shape, dtype=dtype, device=device) + dist.broadcast(tensor, src=0) + return tensor + + +# ---------- Tests ---------- + +# --------------------------------------------------------------------------- +# uneven_dtensor_to_full_tensor tests +# --------------------------------------------------------------------------- + +@pytest.mark.distributed +def test_basic_shard_gather(distributed_setup): + """Basic 1D shard gather, world_size-agnostic.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_arange((4, 3), dtype=torch.float32, device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor) + + +@pytest.mark.distributed +def test_replicated_dtensor(distributed_setup): + """Replicated placement should reconstruct the same tensor.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_randn((8, 4), device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Replicate()]) + + gathered = 
uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor) + + +@pytest.mark.distributed +def test_uneven_sharding_dim0(distributed_setup): + """Uneven sharding on dim 0 using manual split + DTensor.from_local.""" + setup = distributed_setup + world_size = setup["world_size"] + mesh = DeviceMesh(setup["device_type"], list(range(world_size))) + + # size intentionally not divisible by world_size + rows = world_size * 3 + 2 + global_tensor = make_global_arange((rows, 4), dtype=torch.float32, device=setup["device"]) + + shard = Shard(0) + local_list, _ = shard._split_tensor( + global_tensor, + world_size, + with_padding=False, + contiguous=True, + ) + + local = local_list[setup["rank"]] + + dtensor = DTensor.from_local( + local, + mesh, + (Shard(0),), + shape=global_tensor.size(), + stride=global_tensor.stride(), + ) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor) + + +@pytest.mark.distributed +def test_uneven_sharding_dim1(distributed_setup): + """Uneven sharding on dim 1 using manual split + DTensor.from_local.""" + setup = distributed_setup + world_size = setup["world_size"] + mesh = DeviceMesh(setup["device_type"], list(range(world_size))) + + cols = world_size * 2 + 1 + global_tensor = make_global_randn((8, cols), device=setup["device"]) + + shard = Shard(1) + local_list, _ = shard._split_tensor( + global_tensor, + world_size, + with_padding=False, + contiguous=True, + ) + local = local_list[setup["rank"]] + + dtensor = DTensor.from_local( + local, + mesh, + (Shard(1),), + shape=global_tensor.size(), + stride=global_tensor.stride(), + ) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor) + + +@pytest.mark.distributed +def test_2d_mesh_shard_and_replicate(distributed_setup): + """2D 
mesh with Shard + Replicate, for world_size=4 or 8.""" + setup = distributed_setup + world_size = setup["world_size"] + + if world_size == 4: + mesh_shape = (2, 2) + elif world_size == 8: + mesh_shape = (2, 4) + else: + pytest.skip(f"2D mesh test expects world_size 4 or 8, got {world_size}") + + mesh_ids = torch.arange(world_size).reshape(mesh_shape) + mesh = DeviceMesh(setup["device_type"], mesh_ids) + + global_tensor = make_global_randn((16, 12), device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Replicate()]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_multiple_sharded_dims_even(distributed_setup): + """Shard on two dimensions with even splits, 2D mesh.""" + setup = distributed_setup + world_size = setup["world_size"] + + if world_size == 4: + mesh_shape = (2, 2) + elif world_size == 8: + mesh_shape = (2, 4) + else: + pytest.skip(f"2D mesh test expects world_size 4 or 8, got {world_size}") + + mesh_ids = torch.arange(world_size).reshape(mesh_shape) + mesh = DeviceMesh(setup["device_type"], mesh_ids) + + global_tensor = make_global_randn((16, 24), device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Shard(1)]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_multiple_sharded_dims_uneven(distributed_setup): + """Shard on two dimensions with uneven sizes using manual splitting.""" + setup = distributed_setup + world_size = setup["world_size"] + + if world_size == 4: + mesh_shape = (2, 2) + dim0 = 13 + dim1 = 15 + elif world_size == 8: + mesh_shape = (2, 4) + dim0 = 13 + dim1 = 23 + else: + pytest.skip(f"2D mesh test expects world_size 4 or 8, got {world_size}") + + 
mesh_ids = torch.arange(world_size).reshape(mesh_shape) + mesh = DeviceMesh(setup["device_type"], mesh_ids) + + global_tensor = make_global_randn((dim0, dim1), device=setup["device"]) + + shard0 = Shard(0) + shard1 = Shard(1) + + list0, _ = shard0._split_tensor( + global_tensor, + mesh_shape[0], + with_padding=False, + contiguous=True, + ) + rank0 = setup["rank"] // mesh_shape[1] + intermediate = list0[rank0] + + list1, _ = shard1._split_tensor( + intermediate, + mesh_shape[1], + with_padding=False, + contiguous=True, + ) + rank1 = setup["rank"] % mesh_shape[1] + local = list1[rank1] + + dtensor = DTensor.from_local( + local, + mesh, + (Shard(0), Shard(1)), + shape=global_tensor.size(), + stride=global_tensor.stride(), + ) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_3d_tensor_two_shards(distributed_setup): + """3D tensor with sharding on two dims on 2D mesh.""" + setup = distributed_setup + world_size = setup["world_size"] + + if world_size == 4: + mesh_shape = (2, 2) + elif world_size == 8: + mesh_shape = (2, 4) + else: + pytest.skip(f"2D mesh test expects world_size 4 or 8, got {world_size}") + + mesh_ids = torch.arange(world_size).reshape(mesh_shape) + mesh = DeviceMesh(setup["device_type"], mesh_ids) + + global_tensor = make_global_randn((16, 8, 24), device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0), Shard(2)]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_different_dtypes(distributed_setup): + """Verify correctness across several dtypes.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + for dtype in (torch.float32, torch.float64, 
torch.int32, torch.int64): + global_tensor = make_global_arange((4, 4), dtype=dtype, device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.dtype == dtype + assert torch.equal(gathered, global_tensor) + + +@pytest.mark.distributed +def test_large_tensor(distributed_setup): + """Scalability: larger tensor, sharded on dim 0.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_randn((1024, 512), device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_3d_tensor_single_shard(distributed_setup): + """3D tensor with sharding on dim 0 only.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_randn((8, 6, 4), device=setup["device"]) + dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)]) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + + assert gathered.shape == global_tensor.shape + assert torch.allclose(gathered, global_tensor) + + +@pytest.mark.distributed +def test_error_on_invalid_input(distributed_setup): + """Non-DTensor input should raise TypeError.""" + x = torch.randn(4, 4) + with pytest.raises(TypeError): + uneven_dtensor_to_full_tensor(x) + + +@pytest.mark.distributed +def test_backward_compatibility(distributed_setup): + """Check gathered tensor can participate in autograd.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_randn((8, 4), device=setup["device"]) + global_tensor.requires_grad_(True) + + dtensor = 
distribute_tensor(global_tensor, mesh, [Shard(0)])
+    gathered = uneven_dtensor_to_full_tensor(dtensor)
+
+    loss = gathered.sum()
+    assert loss is not None
+
+
+@pytest.mark.distributed
+def test_strided_shard_2d_mesh(distributed_setup):
+    """
+    Test _StridedShard on a 2D mesh, sharding the same dimension across two mesh dims.
+    This is similar to TP + DP style strided sharding.
+    """
+    setup = distributed_setup
+    world_size = setup["world_size"]
+
+    if world_size == 4:
+        mesh_shape = (2, 2)
+    elif world_size == 8:
+        mesh_shape = (2, 4)
+    else:
+        pytest.skip(f"2D mesh test expects world_size 4 or 8, got {world_size}")
+
+    mesh_ids = torch.arange(world_size).reshape(mesh_shape)
+    mesh = DeviceMesh(setup["device_type"], mesh_ids)
+
+    rows = 8
+    cols = 8
+    global_tensor = make_global_randn((rows, cols), device=setup["device"])
+
+    # Shard dim 0 over both mesh dims; one of them is encoded as _StridedShard.
+    # Example pattern: [Shard(0), _StridedShard(0, split_factor=mesh_shape[0])]
+    # so that combined, dim 0 is split across mesh_shape[0] * mesh_shape[1] ranks.
+    placements = [
+        Shard(0),
+        _StridedShard(0, split_factor=mesh_shape[0]),
+    ]
+
+    dtensor = distribute_tensor(global_tensor, mesh, placements)
+
+    gathered = uneven_dtensor_to_full_tensor(dtensor)
+
+    assert gathered.shape == global_tensor.shape
+    assert torch.allclose(gathered, global_tensor, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.distributed
+def test_wild_random_uneven_shards(distributed_setup):
+    """
+    Wild random uneven sharding:
+    - We choose a global shape.
+    - We randomly assign each rank a slice length (including 0) along a sharded dimension.
+    - Each rank holds arbitrary local data of its length.
+    - We build a DTensor via from_local + explicit offsets metadata (through the metadata update helper),
+      and then use uneven_dtensor_to_full_tensor to reconstruct.
+ """ + setup = distributed_setup + rank = setup["rank"] + world_size = setup["world_size"] + mesh = DeviceMesh(setup["device_type"], list(range(world_size))) + + # Global logical shape on dim 1 (sharded dim) + rows = 2 + max_local = 8 + + if rank == 0: + # Random lengths per rank, allowing zero + lengths = torch.randint(low=0, high=max_local + 1, size=(world_size,), device=setup["device"]) + # Ensure at least one element globally to avoid degenerate zero-sized tensor + if lengths.sum().item() == 0: + lengths[0] = 1 + else: + lengths = torch.empty(world_size, dtype=torch.int64, device=setup["device"]) + + # Broadcast lengths and total from rank 0 + dist.broadcast(lengths, src=0) + + total = int(lengths.sum().item()) + + # Now we know global shape: rows x total + global_cols = total + + # Build a “reference” tensor on all ranks via broadcast from rank 0 + ref_global = make_global_arange((rows, global_cols), dtype=torch.float32, device=setup["device"]) + + # Local slice for this rank is a contiguous segment in dim 1 + start = int(lengths[:rank].sum().item()) + length = int(lengths[rank].item()) + end = start + length + + if length > 0: + local = ref_global[:, start:end].clone() + else: + # zero-sized local shard + local = torch.empty((rows, 0), dtype=ref_global.dtype, device=ref_global.device) + + # Construct DTensor with explicit shape/stride; offsets handled by your metadata helper + dtensor = DTensor.from_local( + local, + mesh, + (Shard(1),), + shape=(rows, global_cols), + stride=(global_cols, 1), + ) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + assert gathered.shape == ref_global.shape + assert torch.allclose(gathered, ref_global, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_wild_random_uneven_shards_multi_dim(distributed_setup): + """ + Wild uneven shards across a 2D mesh and 2 sharded dims, including zero-sized shards. 
+ This stresses: + - multiple mesh dims + - multiple sharded dims + - varying per-rank shapes + """ + setup = distributed_setup + rank = setup["rank"] + world_size = setup["world_size"] + + if world_size == 4: + mesh_shape = (2, 2) + elif world_size == 8: + mesh_shape = (2, 4) + else: + pytest.skip(f"2D mesh test expects world_size 4 or 8, got {world_size}") + + mesh_ids = torch.arange(world_size).reshape(mesh_shape) + mesh = DeviceMesh(setup["device_type"], mesh_ids) + + # Logical global shape + base_rows = 3 + base_cols = 5 + + # Each mesh row gets its own random row-count; each mesh col gets its own random col-count + if rank == 0: + row_chunks = torch.randint(low=0, high=base_rows + 2, size=(mesh_shape[0],), device=setup["device"]) + col_chunks = torch.randint(low=0, high=base_cols + 3, size=(mesh_shape[1],), device=setup["device"]) + if row_chunks.sum().item() == 0: + row_chunks[0] = 1 + if col_chunks.sum().item() == 0: + col_chunks[0] = 1 + else: + row_chunks = torch.empty(mesh_shape[0], dtype=torch.int64, device=setup["device"]) + col_chunks = torch.empty(mesh_shape[1], dtype=torch.int64, device=setup["device"]) + + dist.broadcast(row_chunks, src=0) + dist.broadcast(col_chunks, src=0) + + total_rows = int(row_chunks.sum().item()) + total_cols = int(col_chunks.sum().item()) + + # Global reference tensor + ref_global = make_global_arange( + (total_rows, total_cols), dtype=torch.float32, device=setup["device"] + ) + + # Determine which row/col block this rank owns + mesh_row = rank // mesh_shape[1] + mesh_col = rank % mesh_shape[1] + + row_start = int(row_chunks[:mesh_row].sum().item()) + row_len = int(row_chunks[mesh_row].item()) + row_end = row_start + row_len + + col_start = int(col_chunks[:mesh_col].sum().item()) + col_len = int(col_chunks[mesh_col].item()) + col_end = col_start + col_len + + if row_len > 0 and col_len > 0: + local = ref_global[row_start:row_end, col_start:col_end].clone() + else: + local = torch.empty( + (row_len, col_len), + 
dtype=ref_global.dtype, + device=ref_global.device, + ) + + dtensor = DTensor.from_local( + local, + mesh, + (Shard(0), Shard(1)), + shape=(total_rows, total_cols), + stride=(total_cols, 1), + ) + + gathered = uneven_dtensor_to_full_tensor(dtensor) + + assert gathered.shape == ref_global.shape + assert torch.allclose(gathered, ref_global, rtol=1e-5, atol=1e-5) + +# --------------------------------------------------------------------------- +# split_dtensor tests +# --------------------------------------------------------------------------- + +@pytest.mark.distributed +def test_split_dtensor_even_shard_dim0(distributed_setup): + """Even split along sharded dim 0, verify each split matches torch.split of global tensor.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_arange((16, 4), dtype=torch.float32, device=setup["device"]) + dt = distribute_tensor(global_tensor, mesh, [Shard(0)]) + + # Split evenly into size 4 along dim 0 + splits = list(split_dtensor(dt, 4, dim=0, update_uneven_dtensor_chunk_meta=True)) + + assert len(splits) == 4 + + # Reference splits on full tensor + ref_splits = torch.split(global_tensor, 4, dim=0) + for i, (chunk_dt, ref) in enumerate(zip(splits, ref_splits)): + gathered = uneven_dtensor_to_full_tensor(chunk_dt) + ref = ref.to(gathered.device) + assert gathered.shape == ref.shape, f"split {i} shape mismatch" + assert torch.allclose(gathered, ref), f"split {i} content mismatch" + + +@pytest.mark.distributed +def test_split_dtensor_uneven_sections_dim1(distributed_setup): + """List-of-sections split along dim 1 on a sharded DTensor.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_randn((8, 13), device=setup["device"]) + dt = distribute_tensor(global_tensor, mesh, [Shard(0)]) + + sections = [3, 5, 5] # sum == 13 + splits = list(split_dtensor(dt, sections, dim=1, 
update_uneven_dtensor_chunk_meta=True)) + assert len(splits) == len(sections) + + ref_splits = torch.split(global_tensor, sections, dim=1) + for i, (chunk_dt, ref) in enumerate(zip(splits, ref_splits)): + gathered = uneven_dtensor_to_full_tensor(chunk_dt) + ref = ref.to(gathered.device) + assert gathered.shape == ref.shape + assert torch.allclose(gathered, ref, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_split_dtensor_replicate_placement(distributed_setup): + """Splitting a replicated DTensor should behave like splitting the global tensor, no redistribution.""" + setup = distributed_setup + mesh = DeviceMesh(setup["device_type"], list(range(setup["world_size"]))) + + global_tensor = make_global_randn((6, 10), device=setup["device"]) + dt = distribute_tensor(global_tensor, mesh, [Replicate()]) + + splits = list(split_dtensor(dt, 4, dim=1, update_uneven_dtensor_chunk_meta=False)) + ref_splits = torch.split(global_tensor, 4, dim=1) + + assert len(splits) == len(ref_splits) + for i, (chunk_dt, ref) in enumerate(zip(splits, ref_splits)): + # Replicated placement: local == global slice + local = chunk_dt.to_local() + ref = ref.to(local.device) + assert local.shape == ref.shape + assert torch.allclose(local, ref, rtol=1e-5, atol=1e-5) + + +@pytest.mark.distributed +def test_split_dtensor_uneven_shard_with_metadata(distributed_setup): + """Split along dim 0 on an unevenly sharded DTensor and verify correctness.""" + setup = distributed_setup + world_size = setup["world_size"] + mesh = DeviceMesh(setup["device_type"], list(range(world_size))) + + rows = world_size * 3 + 1 # uneven vs world_size + global_tensor = make_global_arange((rows, 4), dtype=torch.float32, device=setup["device"]) + + shard = Shard(0) + local_list, _ = shard._split_tensor( + global_tensor, + world_size, + with_padding=False, + contiguous=True, + ) + local = local_list[setup["rank"]] + + dt = DTensor.from_local( + local, + mesh, + (Shard(0),), + shape=global_tensor.size(), + 
stride=global_tensor.stride(),
+    )
+
+    # Split into size 2 along dim 0
+    splits = list(split_dtensor(dt, 2, dim=0, update_uneven_dtensor_chunk_meta=True))
+    ref_splits = torch.split(global_tensor, 2, dim=0)
+
+    assert len(splits) == len(ref_splits)
+    for i, (chunk_dt, ref) in enumerate(zip(splits, ref_splits)):
+        gathered = uneven_dtensor_to_full_tensor(chunk_dt)
+        ref = ref.to(gathered.device)
+        assert gathered.shape == ref.shape
+        assert torch.allclose(gathered, ref, rtol=1e-5, atol=1e-5)
+
+
+@pytest.mark.distributed
+def test_split_dtensor_zero_local_shard(distributed_setup):
+    """
+    Split DTensor where some ranks have zero local data (after an uneven manual layout),
+    ensuring split_dtensor yields correct empty locals but correct global slices.
+    """
+    setup = distributed_setup
+    rank = setup["rank"]
+    world_size = setup["world_size"]
+    mesh = DeviceMesh(setup["device_type"], list(range(world_size)))
+
+    # Create a manual uneven sharding along dim 1 with possible zero-length local on some ranks
+    # Same style as the "wild random uneven" gather test above.
+ if rank == 0: + # random but deterministic lengths + lengths = torch.tensor([0] + [4] * (world_size - 1), dtype=torch.int64, device=setup["device"]) + if lengths.sum().item() == 0: + lengths[0] = 1 # fallback + else: + lengths = torch.empty(world_size, dtype=torch.int64, device=setup["device"]) + + dist.broadcast(lengths, src=0) + total = int(lengths.sum().item()) + + rows = 4 + cols = total + global_tensor = make_global_arange((rows, cols), dtype=torch.float32, device=setup["device"]) + + start = int(lengths[:rank].sum().item()) + length = int(lengths[rank].item()) + end = start + length + + if length > 0: + local = global_tensor[:, start:end].clone() + else: + local = torch.empty((rows, 0), dtype=global_tensor.dtype, device=global_tensor.device) + + dt = DTensor.from_local( + local, + mesh, + (Shard(1),), + shape=global_tensor.size(), + stride=global_tensor.stride(), + ) + + # Split dim 1 into chunks of size 5 + splits = list(split_dtensor(dt, 5, dim=1, update_uneven_dtensor_chunk_meta=True)) + ref_splits = torch.split(global_tensor, 5, dim=1) + + assert len(splits) == len(ref_splits) + for i, (chunk_dt, ref) in enumerate(zip(splits, ref_splits)): + gathered = uneven_dtensor_to_full_tensor(chunk_dt) + ref = ref.to(gathered.device) + assert gathered.shape == ref.shape + assert torch.allclose(gathered, ref, rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/checkpoint/checkpoint_inspector.py b/tools/checkpoint/checkpoint_inspector.py index 3d03f4db959..692a803e2af 100644 --- a/tools/checkpoint/checkpoint_inspector.py +++ b/tools/checkpoint/checkpoint_inspector.py @@ -31,7 +31,7 @@ from torch.distributed.checkpoint.state_dict_saver import _save_state_dict from torch.distributed.tensor import DeviceMesh, Replicate, Shard -from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import split_dtensor, gather_uneven_dtensor_to_full_tensor +from 
megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import split_dtensor, redistribute_uneven_dtensor_to_replicated from megatron.core.dist_checkpointing.serialization import ( get_default_load_sharded_strategy, @@ -399,7 +399,7 @@ def split_layers( _free_up_some_gpu_memory() layers = {} for i, v in enumerate(split_dtensor(value, 1, dim=0)): - v = gather_uneven_dtensor_to_full_tensor(v).reshape( + v = redistribute_uneven_dtensor_to_replicated(v).reshape( orig_shape[1:] if orig_shape else value.shape[1:] ).redistribute(placements=[Shard(0)]) @@ -428,7 +428,7 @@ def split_expert_weights( else: raise ValueError(f"Unexpected expert layer key: {layer_key}") - expert_weight = gather_uneven_dtensor_to_full_tensor(expert_weight) + expert_weight = redistribute_uneven_dtensor_to_replicated(expert_weight) expert_shape = orig_shape[1:] if orig_shape else value.shape[1:] # Handle optimizer states for expert linear_fc2 when ETP is enabled if ( @@ -472,7 +472,7 @@ def split_swiglu_weight(key: str, value: torch.Tensor) -> dict[str, torch.Tensor """ Split SwiGLU weights into separate tensors. """ - value = gather_uneven_dtensor_to_full_tensor(value) + value = redistribute_uneven_dtensor_to_replicated(value) swiglu_w_and_v = {} w, v = torch.chunk(value, 2, dim=0) w = w.redistribute(placements=[Shard(0)]) @@ -531,7 +531,7 @@ def has_layer_index(key: str) -> bool: split_tensors = split_expert_weights(new_key, value, orig_shape) else: if orig_shape: - value = gather_uneven_dtensor_to_full_tensor(value) + value = redistribute_uneven_dtensor_to_replicated(value) # Handle optimizer states with partition_dim=1 when TP is enabled if ( new_key.startswith("optimizer.state.")