Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recipes/gpt_cp_recipe/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pt
22 changes: 22 additions & 0 deletions recipes/gpt_cp_recipe/3rdparty_nemo_evo2_tmp.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
index 67500615e0..147e78cfa1 100644
--- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
+++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
@@ -647,7 +647,7 @@ class ExplicitSingleDecayFilter(nn.Module):
"""
return self.filter(L, *args, **kwargs)

- @torch.compile(mode="max-autotune")
+ #@torch.compile(mode="max-autotune")
def filter(self, L, *args, **kwargs):
"""Compute the filter as a function of h and decay for the requested sequence length."""
h = self.h[:, :L]
@@ -834,7 +834,7 @@ class ParallelHyenaOperator(nn.Module):
self.conv_bias.data = conv_init_method(self.conv_bias.data)
self.conv_bias.model_parallel = True
self.conv_bias.partition_dim = 0
- self.conv_bias.stride = 1
+ #self.conv_bias.stride = 1

def forward_long(self, *, x1, x2, v, h, bias, inference_context):
"""Forward pass long."""
4 changes: 4 additions & 0 deletions recipes/gpt_cp_recipe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
## Steps to run
1. Apply the `.patch` file to NeMo so that nvFSDP can run.
2. Run `run.sh` to generate the `.pt` gradient files.
3. Run `python compare_grads.py` to combine the distributed gradients and compare them against the single-process gradients.
25 changes: 25 additions & 0 deletions recipes/gpt_cp_recipe/compare_grads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python3

import torch
import torch.distributed.tensor
def combine(layer_name, cp_0, cp_1):
    """Reassemble the full gradient for `layer_name` from two CP-rank dumps.

    Each dump maps layer names to a DTensor shard, or None when that rank
    holds no slice of the layer. Present shards are converted to local CPU
    tensors; when both ranks hold a slice, the locals are concatenated
    along dim 0.
    """
    shard_a = cp_0[layer_name]
    shard_b = cp_1[layer_name]
    # Only one rank holds this layer: return its local tensor as-is.
    if shard_a is None:
        return shard_b.to_local().cpu()
    if shard_b is None:
        return shard_a.to_local().cpu()
    # Both ranks hold a slice: stitch them back together.
    return torch.cat([shard_a.to_local().cpu(), shard_b.to_local().cpu()])


# Load the per-rank gradient dumps from the CP=2 run (one file per CP rank).
# weights_only=False is required because the dumps contain DTensor objects,
# not plain tensors — presumably produced by run.sh; TODO confirm filenames.
cp_0 = torch.load("cp_2_grads_rnk_0_0.pt", weights_only=False)
cp_1 = torch.load("cp_2_grads_rnk_1_0.pt", weights_only=False)
# Stitch each layer's shards from the two CP ranks back into full tensors.
combined_cp = {key: combine(key, cp_0, cp_1) for key in cp_0.keys()}

# Baseline: gradients from the CP=1 (single context-parallel rank) run,
# converted to local CPU tensors for comparison.
base = {k: v.to_local().cpu() for k, v in torch.load("cp_1_grads_rnk_0_0.pt", weights_only=False).items()}

# Both runs must have produced gradients for exactly the same layers.
assert set(combined_cp.keys()) == set(base.keys())

# Compare each layer element-wise (default rtol/atol of assert_close).
for k in combined_cp.keys():
    print(f"Testing {k}")
    torch.testing.assert_close(combined_cp[k], base[k])
22 changes: 22 additions & 0 deletions recipes/gpt_cp_recipe/megatron_fsdp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .distributed_data_parallel_config import DistributedDataParallelConfig
from .megatron_fsdp import MegatronFSDP
from .utils import FSDPDistributedIndex

# Guarded import: fully_shard presumably depends on optional packages, so a
# missing dependency degrades to a console message instead of breaking the
# whole package import — TODO confirm which dependency this guards against.
# NOTE(review): `print` goes to stdout; consider warnings.warn or logging.
try:
    from .fully_shard import fully_shard
except ImportError as e:
    print(f"Failed to import fully_shard: {e}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DistributedDataParallelConfig:
    """Configuration for DistributedDataParallel."""

    grad_reduce_in_fp32: bool = False
    """If true, reduce grads in fp32."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad all-reduce / reduce-scatter with backward compute."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute."""

    align_param_gather: bool = False
    """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
    PP stage will independently launch as needed.
    """

    use_distributed_optimizer: bool = False
    """If true, issue reduce-scatter collectives to aggregate gradients and clean up
    originally allocated model parameters, otherwise issue all-reduce collectives.
    """

    num_distributed_optimizer_instances: int = 1
    """Sets the factor by which the DP domain is sharded to have the partial DistOpt
    enabled. Defaults to 1, which means DistOpt is across entire DP domain.
    """

    check_for_nan_in_grad: bool = False
    """If true, check for NaNs and Infs in gradients _before_ communication collective."""

    check_for_large_grads: bool = False
    """If true, check for unexpectedly large gradients _before_ communication collective."""

    bucket_size: Optional[int] = None
    """Maximum number of parameters in each bucket. If unspecified, MCore uses a default
    value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
    buckets to ensure collectives do not become latency-bound)."""

    pad_buckets_for_high_nccl_busbw: bool = False
    """If true, make sure the bucket size is divisible by a large power of 2 (2^16) to
    ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL
    message size (which for ring algorithms is bucket_size / dp_size) apparently needs
    to be divisible by a power of 2 for high busbw."""

    average_in_collective: bool = False
    """If true, compute average in collective directly, as opposed to dividing by the
    dp_size first and then computing sum in the collective."""

    fp8_param_gather: bool = False
    """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
    perform the param all-gather in fp8."""

    reuse_grad_buf_for_mxfp8_param_ag: bool = False
    """If true, reuse the grad buffer for param AG when using mxfp8 recipe. Should be
    set to True only when fp8_recipe is mxfp8 and fp8_param_gather is True."""

    use_megatron_fsdp: bool = False
    """If true, use the FSDP code path for DDP."""

    use_custom_fsdp: bool = False
    """
    NOTE: The flag `use_custom_fsdp` is deprecated and will be removed in future versions.
    Please use `use_megatron_fsdp` instead, as all functionality will be migrated there.
    Future updates will drop support for `use_custom_fsdp` to avoid confusion.
    """

    data_parallel_sharding_strategy: str = 'no_shard'
    """Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
    'optim_grads', 'optim_grads_params'."""

    gradient_reduce_div_fusion: bool = True
    """If true, perform gradient reduce and division fusion."""

    # NOTE: annotation fixed from `int` to `Optional[int]` — the default is None.
    suggested_communication_unit_size: Optional[int] = None
    """Specifies the number of elements to communicate at once during
    FSDP (Fully Sharded Data Parallel) operations.
    This flag also affects FSDP all-gather prefetch behavior. Setting a larger
    value increases the communication buffer size, while a smaller value
    disables prefetching and may degrade performance. Adjust this value
    based on your system's memory and performance requirements."""

    preserve_fp32_weights: bool = True
    """If true, preserve fp32 weights in the Megatron FSDP ParamAndGradBuffer."""

    keep_fp8_transpose_cache: bool = False
    """If true, keep the fp8 transpose cache when using Megatron FSDP."""

    nccl_ub: bool = False
    """If true, allocate and register NCCL userbuffer for param and grad buffer.
    This flag enables SM efficient nccl algorithm that could improve the performance
    of FSDP and DP with comm_overlap. This flag will be much more effective when used
    together with sharp.
    The following will be the expected number of SM usage for various cases.
    (Note that this is just a reference number and the number of SM usage could vary
    on message size, communication domain size and nccl version.)
    ----------------------------------------------------------
    | Communication domain | use_sharp | SM usage of "AG/RS" |
    |----------------------|-----------|---------------------|
    | NVL                  | N/A       | 4 / 5               |
    | NVL+IB               | False     | 16 / 16             |
    | NVL+IB               | True      | 6 / 6               |
    | IB                   | False     | 1 / 4               |
    | IB                   | True      | 1 / 1               |
    ----------------------------------------------------------
    """

    fsdp_double_buffer: bool = False
    """If true, use persistently allocated double buffers for the
    temporary memory needed in the Megatron FSDP communications.
    This option will cause additional memory overhead, however, it is necessary
    to register user buffer (nccl_ub=True) for the Megatron FSDP.
    This option will be automatically set to True when nccl_ub=True.
    """

    outer_dp_sharding_strategy: str = 'no_shard'
    """
    Sharding strategy for outer data parallel group in Hybrid Sharded Data Parallel (HSDP) mode.
    Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'.
    This option is only effective when Hybrid FSDP is enabled.
    """

    def __post_init__(self):
        """Check the validity of the config."""
        # Bug fix: the docstring used to appear *after* `import os`, which made
        # it a plain (discarded) string expression rather than a docstring.
        import os

        if self.reuse_grad_buf_for_mxfp8_param_ag:
            assert self.fp8_param_gather, "Reuse grad buffer only when keeping params in MXFP8."

        if self.nccl_ub:
            # torch.cuda.MemPool (used for NCCL userbuffer registration) is
            # incompatible with the expandable-segments allocator mode.
            if 'expandable_segments:True' in os.getenv('PYTORCH_CUDA_ALLOC_CONF', '').split(','):
                raise ValueError(
                    "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True is currently not supported "
                    "with nccl_ub due to compatibility issue with torch.cuda.MemPool API."
                )
Loading