Open
Labels: bug, community-request, module: megatron-fsdp, needs-follow-up
Description
Describe the bug
When training Qwen3-30B-A3B on 32 H800 GPUs, I found that Megatron-FSDP combined with Expert Parallelism (EP) uses more GPU memory than EP alone, and it also lowers throughput.
| | EP8 | EP8 + FSDP |
|---|---|---|
| Throughput (TFLOPS) | 137.8 | 111.0 |
| Max reserved (MB) | 45552.0 | 55638.0 |
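The relative cost of enabling Megatron-FSDP can be read off the table with a quick calculation. The sketch below uses only the four numbers in the table and awk's floating-point arithmetic. The variable names are illustrative.

```shell
#!/usr/bin/env bash
# Relative cost of adding Megatron-FSDP to the EP8 run, using the numbers
# from the table (throughput in TFLOPS, max reserved memory in MB).
ep8_tflops=137.8
fsdp_tflops=111.0
ep8_mem_mb=45552.0
fsdp_mem_mb=55638.0

report=$(awk -v t0="$ep8_tflops" -v t1="$fsdp_tflops" \
             -v m0="$ep8_mem_mb" -v m1="$fsdp_mem_mb" 'BEGIN {
  # Absolute and relative change in reserved memory (FSDP minus baseline).
  printf "extra reserved memory: %.0f MB (+%.1f%%)\n", m1 - m0, 100 * (m1 - m0) / m0
  # Absolute and relative change in throughput (negative means slower).
  printf "throughput change: %.1f TFLOPS (%.1f%%)\n", t1 - t0, 100 * (t1 - t0) / t0
}')
echo "$report"
```

In other words, the FSDP run reserves roughly 10 GB (about 22%) more per GPU while delivering about 19% less throughput.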
Steps/Code to reproduce bug
Megatron-LM version: 268fda0
TransformerEngine version: 2.10 release
EP8 + FSDP args:
#!/bin/bash
set -ex
unset CUDA_DEVICE_MAX_CONNECTIONS
GPUS_PER_NODE=8
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-"29345"}
NUM_NODES=${NUM_NODES:-"1"}
NODE_RANK=${NODE_RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
LOG_DIR="/mnt/"
mkdir -p ${LOG_DIR}
if [ "${NODE_RANK}" -eq "0" ]; then
cp ${0} ${LOG_DIR}
fi
CHECKPOINT_PATH="${LOG_DIR}/checkpoint"
DATA_PATH="/mnt/alpaca_text_document"
TOKENIZER_MODEL="/mnt/tokenziers/Qwen3"
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=INIT
export NCCL_NVLS_ENABLE="0"
export NCCL_CUMEM_ENABLE="0"
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
export NVTE_DEBUG=1
export NVTE_DEBUG_LEVEL=2
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--node-rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--use-mcore-models
--seq-length 4096
--num-layers 48
--hidden-size 2048
--ffn-hidden-size 6144
--num-attention-heads 32
--kv-channels 128
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--qk-layernorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--no-masked-softmax-fusion
--no-position-embedding
--disable-bias-linear
--norm-epsilon 1e-06
--seed 1234
--rotary-base 1000000
--max-position-embeddings 32768
--group-query-attention
--num-query-groups 4
--attention-backend flash
--bf16
)
MOE_ARGS=(
--num-experts 128
--moe-router-topk 8
--moe-ffn-hidden-size 768
--moe-router-load-balancing-type aux_loss
--moe-router-dtype fp32
--moe-aux-loss-coeff 0.001
--moe-router-score-function softmax
--moe-grouped-gemm
--moe-token-dispatcher-type flex
--moe-enable-deepep
--moe-deepep-num-sms 48
--moe-permute-fusion
--moe-router-fusion
--use-fused-weighted-squared-relu
--moe-router-force-load-balancing
)
DATA_ARGS=(
--tokenizer-type HuggingFaceTokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--vocab-size 151936
--padded-vocab-size 151936
--make-vocab-size-divisible-by 8
--data-path $DATA_PATH
--split 99990,8,2
--num-workers 2
)
TRAINING_ARGS=(
--micro-batch-size 2
--global-batch-size 64
--lr 1e-05
--train-iters 10000
--lr-decay-style cosine
--min-lr 5e-06
--weight-decay 0.01
--clip-grad 1.0
--lr-warmup-iters 100
--lr-decay-iters 10000
--adam-beta1 0.9
--adam-beta2 0.95
--adam-eps 1e-08
--accumulate-allreduce-grads-in-fp32
--cross-entropy-loss-fusion
--cross-entropy-fusion-impl te
--recompute-granularity full
--recompute-method uniform
--recompute-num-layers 1
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--sequence-parallel
--expert-model-parallel-size 8
--expert-tensor-parallel-size 1
--use-distributed-optimizer
--overlap-grad-reduce
--overlap-param-gather
)
FSDP_ARGS=(
--use-megatron-fsdp
--init-model-with-meta-device
--data-parallel-sharding-strategy optim_grads_params
--no-gradient-accumulation-fusion
--calculate-per-token-loss
--ckpt-format fsdp_dtensor
)
LOGGING_ARGS=(
--log-interval 1
--log-num-zeros-in-grad
--log-throughput
--log-progress
--save-interval 10000
--eval-interval 5000
--eval-iters 10
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--tensorboard-dir $CHECKPOINT_PATH
--no-load-optim
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"megatron"}
--wandb-exp-name ${WANDB_NAME:-"experiment"}
)
fi
nohup \
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]} \
${FSDP_ARGS[@]} \
> ${LOG_DIR}/log_${NODE_RANK}.txt 2>&1 &
echo "log file: ${LOG_DIR}/log_${NODE_RANK}.txt"
EP8 args:
#!/bin/bash
set -ex
unset CUDA_DEVICE_MAX_CONNECTIONS
GPUS_PER_NODE=8
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-"29345"}
NUM_NODES=${NUM_NODES:-"1"}
NODE_RANK=${NODE_RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
LOG_DIR="/mnt/"
mkdir -p ${LOG_DIR}
if [ "${NODE_RANK}" -eq "0" ]; then
cp ${0} ${LOG_DIR}
fi
CHECKPOINT_PATH="${LOG_DIR}/checkpoint"
DATA_PATH="/mnt/alpaca_text_document"
TOKENIZER_MODEL="/mnt/tokenziers/Qwen3"
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=INIT
export NCCL_NVLS_ENABLE="0"
export NCCL_CUMEM_ENABLE="0"
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
export NVTE_DEBUG=1
export NVTE_DEBUG_LEVEL=2
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--node-rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--use-mcore-models
--seq-length 4096
--num-layers 48
--hidden-size 2048
--ffn-hidden-size 6144
--num-attention-heads 32
--kv-channels 128
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--qk-layernorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--no-masked-softmax-fusion
--no-position-embedding
--disable-bias-linear
--norm-epsilon 1e-06
--seed 1234
--rotary-base 1000000
--max-position-embeddings 32768
--group-query-attention
--num-query-groups 4
--attention-backend flash
--bf16
)
MOE_ARGS=(
--num-experts 128
--moe-router-topk 8
--moe-ffn-hidden-size 768
--moe-router-load-balancing-type aux_loss
--moe-router-dtype fp32
--moe-aux-loss-coeff 0.001
--moe-router-score-function softmax
--moe-grouped-gemm
--moe-token-dispatcher-type flex
--moe-enable-deepep
--moe-deepep-num-sms 48
--moe-permute-fusion
--moe-router-fusion
--use-fused-weighted-squared-relu
--moe-router-force-load-balancing
)
DATA_ARGS=(
--tokenizer-type HuggingFaceTokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--vocab-size 151936
--padded-vocab-size 151936
--make-vocab-size-divisible-by 8
--data-path $DATA_PATH
--split 99990,8,2
--num-workers 2
)
TRAINING_ARGS=(
--micro-batch-size 2
--global-batch-size 64
--lr 1e-05
--train-iters 10000
--lr-decay-style cosine
--min-lr 5e-06
--weight-decay 0.01
--clip-grad 1.0
--lr-warmup-iters 100
--lr-decay-iters 10000
--adam-beta1 0.9
--adam-beta2 0.95
--adam-eps 1e-08
--accumulate-allreduce-grads-in-fp32
--cross-entropy-loss-fusion
--cross-entropy-fusion-impl te
--recompute-granularity full
--recompute-method uniform
--recompute-num-layers 1
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--sequence-parallel
--expert-model-parallel-size 8
--expert-tensor-parallel-size 1
--use-distributed-optimizer
--overlap-grad-reduce
--overlap-param-gather
)
LOGGING_ARGS=(
--log-interval 1
--log-num-zeros-in-grad
--log-throughput
--log-progress
--save-interval 10000
--eval-interval 5000
--eval-iters 10
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--tensorboard-dir $CHECKPOINT_PATH
--no-load-optim
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"megatron"}
--wandb-exp-name ${WANDB_NAME:-"experiment"}
)
fi
nohup \
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]} \
> ${LOG_DIR}/log_${NODE_RANK}.txt 2>&1 &
Expected behavior
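EP8 + FSDP should use less GPU memory than EP8 alone. With --data-parallel-sharding-strategy optim_grads_params, Megatron-FSDP shards parameters, gradients and optimizer states across data-parallel ranks. In the EP8 baseline, --use-distributed-optimizer shards only the optimizer states.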
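To make that expectation concrete, here is a back-of-envelope estimate of per-GPU model and optimizer-state memory derived from MODEL_ARGS and MOE_ARGS. It is a sketch under several assumptions, none of them measured.

- The run uses 32 GPUs, as stated above. The scripts themselves default to NUM_NODES=1.
- With TP=PP=1, dense weights are replicated over a data-parallel group of 32, and expert weights over a group of 32 / EP8 = 4.
- Each parameter costs 18 bytes: 2 for the bf16 parameter, 4 for the fp32 gradient (--accumulate-allreduce-grads-in-fp32), and 12 for the fp32 main parameter plus two Adam moments. This is the breakdown behind the "6 + 12/d" figure for Megatron's distributed optimizer.
- Megatron-FSDP is assumed to shard all 18 bytes evenly.
- Norm weights, activations, DeepEP and all-gather buffers, the CUDA context and allocator fragmentation are ignored. The result is therefore not a prediction of Max reserved.

```shell
#!/usr/bin/env bash
# Back-of-envelope per-GPU memory for model + optimizer state only.
# Assumptions (not measured): 32 GPUs, EP=8, 18 bytes/param
# (bf16 param 2 + fp32 grad 4 + fp32 main param and Adam moments 12).
estimate=$(awk 'BEGIN {
  layers = 48; hidden = 2048; heads = 32; kv = 128; groups = 4
  experts = 128; moe_ffn = 768; vocab = 151936
  world = 32; ep = 8
  dp_dense = world; dp_expert = world / ep

  # Attention (Q, K, V, O) and router weights per layer; norms are ignored.
  attn = hidden * heads * kv + 2 * hidden * groups * kv + heads * kv * hidden
  router = hidden * experts
  # Dense total includes the untied input embedding and output layer.
  dense = layers * (attn + router) + 2 * vocab * hidden
  # SwiGLU experts: gate, up and down projections.
  expert = layers * experts * 3 * hidden * moe_ffn
  expert_local = expert / ep

  mib = 1024 * 1024
  # EP8 + distributed optimizer: params and grads replicated, 12 bytes sharded.
  ep8 = dense * (6 + 12 / dp_dense) + expert_local * (6 + 12 / dp_expert)
  # EP8 + FSDP optim_grads_params: all 18 bytes sharded (idealized).
  fsdp = dense * 18 / dp_dense + expert_local * 18 / dp_expert
  printf "total params: %.2fB\n", (dense + expert) / 1e9
  printf "EP8 static: %.0f MiB\n", ep8 / mib
  printf "EP8 + FSDP static: %.0f MiB\n", fsdp / mib
}')
echo "$estimate"
```

Under these assumptions the model has about 30.5B parameters. The static footprint is roughly 40,472 MiB per GPU for EP8 and roughly 16,379 MiB for EP8 + FSDP. If the assumptions hold, the measured 55,638 MB for the FSDP run must come mostly from memory outside this estimate.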
Additional context
Collected environment:
Collecting environment information...
PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.2.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-153-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H800
GPU 1: NVIDIA H800
GPU 2: NVIDIA H800
GPU 3: NVIDIA H800
GPU 4: NVIDIA H800
GPU 5: NVIDIA H800
GPU 6: NVIDIA H800
GPU 7: NVIDIA H800
Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-79
Off-line CPU(s) list: 80-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480+
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 8
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flash-attn-3==3.0.0b1+20251201.cu128torch280cxx11abitrue.672381
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] onnx==1.20.0
[pip3] onnx-ir==0.1.13
[pip3] onnxruntime==1.23.2
[pip3] onnxscript==0.5.7
[pip3] stepccl==0.0.5.post5+torch2.8.0cu128
[pip3] steprpc==0.2.5.post4+torch2.8.0cu128
[pip3] torch==2.8.0+cu128
[pip3] torchaudio==2.8.0+cu128
[pip3] torchvision==0.23.0+cu128
[pip3] triton==3.4.0
[conda] Could not collect
zhujian19891203