
Commit d997820

shjwudp and claude[bot] authored

Fix backward compatibility issue with MFSDP --grad-reduce-in-bf16 (#3799)

Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>

1 parent 39472d8 commit d997820

1 file changed: +2 −0 lines changed


megatron/training/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -872,6 +872,8 @@ def validate_args(args, defaults={}):
         args.megatron_fsdp_main_params_dtype = map_dtype(args.megatron_fsdp_main_params_dtype)
         args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
         args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
+        if args.grad_reduce_in_bf16:
+            args.megatron_fsdp_grad_comm_dtype = torch.bfloat16

     if args.fp8_param_gather:
         assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \
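
The two added lines preserve the behavior of the legacy --grad-reduce-in-bf16 flag under Megatron-FSDP: if the flag is set, it overrides whatever --megatron-fsdp-grad-comm-dtype resolved to, forcing gradient communication to run in bfloat16. Below is a minimal, self-contained sketch of this flag-override validation pattern; the map_dtype helper here is a simplified stand-in, not the actual Megatron-LM implementation.

import argparse
import torch

def map_dtype(name):
    # Simplified stand-in for Megatron's map_dtype: string name -> torch.dtype.
    return {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}[name]

def validate_args(args):
    # Resolve the string-valued dtype argument to a torch.dtype.
    args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
    # Backward compatibility: the legacy boolean flag takes precedence,
    # forcing gradient reduce communication in bfloat16.
    if args.grad_reduce_in_bf16:
        args.megatron_fsdp_grad_comm_dtype = torch.bfloat16
    return args

parser = argparse.ArgumentParser()
parser.add_argument("--grad-reduce-in-bf16", action="store_true")
parser.add_argument("--megatron-fsdp-grad-comm-dtype", default="fp32")

# The legacy flag wins over the dtype argument's default.
args = validate_args(parser.parse_args(["--grad-reduce-in-bf16"]))
assert args.megatron_fsdp_grad_comm_dtype is torch.bfloat16

Validating and overriding in validate_args, rather than at every use site, keeps older launch scripts that only pass --grad-reduce-in-bf16 working while new scripts can set the dtype argument directly.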
