Skip to content

Commit f82a451

Browse files
committed
Support an 'auto' argument, which defaults to the pre-MixedPrecisionPolicy behavior, to support per-parameter grad dtypes.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent e20b89a commit f82a451

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

megatron/training/arguments.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -858,7 +858,7 @@ def validate_args(args, defaults={}):
858858

859859
# Map string data-type to torch.dtype.
860860
dtype_map = {
861-
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
861+
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8, 'auto': None,
862862
}
863863
map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]
864864

@@ -3132,15 +3132,19 @@ def _add_experimental_args(parser):
31323132
'the precision in the kernel computation.')
31333133

31343134
# Megatron-FSDP Arguments
3135-
group.add_argument('--megatron-fsdp-main-params-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16'],
3135+
group.add_argument('--megatron-fsdp-main-params-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16', 'auto'],
31363136
help="Data type for the main weight buffer utilized for distributed optimization "
3137-
"and quantization with Megatron-FSDP.")
3138-
group.add_argument('--megatron-fsdp-main-grads-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16'],
3137+
"and quantization with Megatron-FSDP. If 'auto', then the native model parameter "
3138+
"data-type will be used for the main weight data-type.")
3139+
group.add_argument('--megatron-fsdp-main-grads-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16', 'auto'],
31393140
help="Data type for the main gradient buffer utilized for distributed optimization "
3140-
"with Megatron-FSDP.")
3141-
group.add_argument("--megatron-fsdp-grad-comm-dtype", default='fp32', choices=['fp32', 'fp16', 'bf16'],
3141+
"with Megatron-FSDP. If 'auto', then the native model gradient data-type will "
3142+
"be used for the main gradient / accumulation data-type.")
3143+
group.add_argument("--megatron-fsdp-grad-comm-dtype", default='fp32', choices=['fp32', 'fp16', 'bf16', 'auto'],
31423144
help="When using Megatron-FSDP, this controls the data-type used when communicating "
3143-
"model gradients during FSDP.")
3145+
"model gradients during FSDP. If 'auto', then the main gradient data-type will "
3146+
"be used for the gradient communication / reduction data-type. When using NCCL "
3147+
"v2.27+, reduction is always computed in FP32 if using NCCL Symmetric kernels.")
31443148

31453149
return parser
31463150

0 commit comments

Comments (0)