--grad-reduce-in-bf16
1 parent 39472d8 commit d997820
megatron/training/arguments.py
@@ -872,6 +872,8 @@ def validate_args(args, defaults={}):
     args.megatron_fsdp_main_params_dtype = map_dtype(args.megatron_fsdp_main_params_dtype)
     args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
     args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
+    if args.grad_reduce_in_bf16:
+        args.megatron_fsdp_grad_comm_dtype = torch.bfloat16
 
     if args.fp8_param_gather:
         assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \
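For context, a minimal sketch of what this change does during argument validation. The Args class and the map_dtype body below are hypothetical reductions, not Megatron-LM's real implementations; only the two added lines mirror the diff. The effect is that grad_reduce_in_bf16 forces the Megatron FSDP gradient communication dtype to torch.bfloat16, overriding whatever args.megatron_fsdp_grad_comm_dtype would otherwise map to.

import torch

# Minimal sketch, not Megatron-LM's actual code: map_dtype and Args are
# simplified stand-ins illustrating the override added in this commit.
def map_dtype(value):
    # Map a dtype string such as "bf16" to a torch dtype; pass dtypes through.
    table = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    return table[value] if isinstance(value, str) else value

class Args:
    megatron_fsdp_grad_comm_dtype = "fp32"  # what the user configured
    grad_reduce_in_bf16 = True              # the new flag

args = Args()
args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
# The added behavior: grad_reduce_in_bf16 wins over the configured comm dtype,
# so gradient reductions are performed in bfloat16.
if args.grad_reduce_in_bf16:
    args.megatron_fsdp_grad_comm_dtype = torch.bfloat16

assert args.megatron_fsdp_grad_comm_dtype is torch.bfloat16

Reducing gradients in bf16 halves communication volume relative to fp32 reductions, at the cost of lower precision in the accumulated gradients.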