megatron/training/arguments.py (3 additions, 0 deletions)

```diff
@@ -872,6 +872,9 @@ def validate_args(args, defaults={}):
     args.megatron_fsdp_main_params_dtype = map_dtype(args.megatron_fsdp_main_params_dtype)
     args.megatron_fsdp_main_grads_dtype = map_dtype(args.megatron_fsdp_main_grads_dtype)
     args.megatron_fsdp_grad_comm_dtype = map_dtype(args.megatron_fsdp_grad_comm_dtype)
+    if args.grad_reduce_in_bf16:
+        assert args.megatron_fsdp_grad_comm_dtype == torch.bfloat16, \
+            "When --grad-reduce-in-bf16 is set, --megatron-fsdp-grad-comm-dtype must be bfloat16"

     if args.fp8_param_gather:
         assert args.use_distributed_optimizer or args.use_torch_fsdp2 or args.use_megatron_fsdp or not torch.is_grad_enabled(), \
```
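To illustrate what the added validation enforces, here is a minimal standalone sketch, not the Megatron-LM code itself: `validate_grad_comm_dtype` is a hypothetical helper name, and the `argparse.Namespace` objects stand in for Megatron's parsed `args` after `map_dtype` has already converted the dtype flags to `torch.dtype` values.

```python
import argparse

import torch


def validate_grad_comm_dtype(args):
    # Mirrors the check added in validate_args(): when gradient reduction
    # is requested in bf16, the Megatron-FSDP gradient-communication dtype
    # must also be bfloat16, otherwise the two flags contradict each other.
    if args.grad_reduce_in_bf16:
        assert args.megatron_fsdp_grad_comm_dtype == torch.bfloat16, \
            "When --grad-reduce-in-bf16 is set, --megatron-fsdp-grad-comm-dtype must be bfloat16"


# Passes: both flags agree on bfloat16.
ok = argparse.Namespace(
    grad_reduce_in_bf16=True,
    megatron_fsdp_grad_comm_dtype=torch.bfloat16,
)
validate_grad_comm_dtype(ok)

# Fails fast at argument-validation time: the comm dtype was left at fp32
# while a bf16 reduce was requested, so the assert raises before training.
bad = argparse.Namespace(
    grad_reduce_in_bf16=True,
    megatron_fsdp_grad_comm_dtype=torch.float32,
)
try:
    validate_grad_comm_dtype(bad)
except AssertionError as e:
    print(e)
```

Catching the mismatch in `validate_args` surfaces the configuration error immediately at startup rather than letting gradients be silently communicated in a dtype other than the one the user asked for.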