diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 2b279f3fd53..870e0b2719c 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -872,6 +872,8 @@ def validate_args(args, defaults={}):
         args.megatron_fsdp_main_params_dtype = map_dtype(args.megatron_fsdp_main_params_dtype)
         args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
         args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
+        if args.grad_reduce_in_bf16:
+            args.megatron_fsdp_grad_comm_dtype = torch.bfloat16
     if args.fp8_param_gather:
         assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \