From 15c23693f6b4d6b4312c9d506cd653cb7b326c00 Mon Sep 17 00:00:00 2001
From: jianbinc
Date: Wed, 11 Mar 2026 17:54:40 +0800
Subject: [PATCH 1/2] Fix forward compatibility issue with MFSDP --grad-reduce-in-bf16

---
 megatron/training/arguments.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 2b279f3fd53..c89c2fef48e 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -872,6 +872,9 @@ def validate_args(args, defaults={}):
         args.megatron_fsdp_main_params_dtype = map_dtype(args.megatron_fsdp_main_params_dtype)
         args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
         args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
+        if args.grad_reduce_in_bf16:
+            assert args.megatron_fsdp_grad_comm_dtype == torch.bfloat16, \
+                "When --grad-reduce-in-bf16 is set, --megatron-fsdp-grad-comm-dtype must be bfloat16"
 
     if args.fp8_param_gather:
         assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \

From a89014bdb3185e53f321e86fc1d620599954b368 Mon Sep 17 00:00:00 2001
From: Jianbin Chang
Date: Wed, 11 Mar 2026 18:05:31 +0800
Subject: [PATCH 2/2] Update megatron/training/arguments.py

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
---
 megatron/training/arguments.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index c89c2fef48e..870e0b2719c 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -873,8 +873,7 @@ def validate_args(args, defaults={}):
         args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
         args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
         if args.grad_reduce_in_bf16:
-            assert args.megatron_fsdp_grad_comm_dtype == torch.bfloat16, \
-                "When --grad-reduce-in-bf16 is set, --megatron-fsdp-grad-comm-dtype must be bfloat16"
+            args.megatron_fsdp_grad_comm_dtype = torch.bfloat16
 
     if args.fp8_param_gather:
         assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \
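
Note (not part of the patch): the two commits together change the handling of --grad-reduce-in-bf16 from a hard assertion (PATCH 1/2) to a silent override of the MFSDP gradient-communication dtype (PATCH 2/2). A minimal sketch of the net behavior inside validate_args after both commits, assuming args carries the flags shown in the diff context; the helper name _resolve_mfsdp_grad_comm_dtype is hypothetical and used only for illustration:

    import torch

    def _resolve_mfsdp_grad_comm_dtype(args):
        # After PATCH 2/2: --grad-reduce-in-bf16 forces the MFSDP gradient
        # communication dtype to bfloat16 instead of asserting on a mismatch.
        if args.grad_reduce_in_bf16:
            args.megatron_fsdp_grad_comm_dtype = torch.bfloat16
        return args.megatron_fsdp_grad_comm_dtype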