@@ -858,7 +858,7 @@ def validate_args(args, defaults={}):
858858
859859 # Map string data-type to torch.dtype.
860860 dtype_map = {
861- 'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
861+ 'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8, 'auto': None,
862862 }
863863 map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]
864864
@@ -3132,15 +3132,19 @@ def _add_experimental_args(parser):
31323132 'the precision in the kernel computation.' )
31333133
31343134 # Megatron-FSDP Arguments
3135- group.add_argument('--megatron-fsdp-main-params-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16'],
3135+ group.add_argument('--megatron-fsdp-main-params-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16', 'auto'],
31363136 help="Data type for the main weight buffer utilized for distributed optimization "
3137- "and quantization with Megatron-FSDP.")
3138- group.add_argument('--megatron-fsdp-main-grads-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16'],
3137+ "and quantization with Megatron-FSDP. If 'auto', then the native model parameter "
3138+ "data-type will be used for the main weight data-type."
3139+ group.add_argument('--megatron-fsdp-main-grads-dtype', default='fp32', choices=['fp32', 'bf16', 'fp16', 'auto'],
31393140 help="Data type for the main gradient buffer utilized for distributed optimization "
3140- "with Megatron-FSDP.")
3141- group.add_argument("--megatron-fsdp-grad-comm-dtype", default='fp32', choices=['fp32', 'fp16', 'bf16'],
3141+ "with Megatron-FSDP. If 'auto', then the native model gradient data-type will "
3142+ "be used for the main gradient / accumulation data-type."
3143+ group.add_argument("--megatron-fsdp-grad-comm-dtype", default='fp32', choices=['fp32', 'fp16', 'bf16', 'auto'],
31423144 help="When using Megatron-FSDP, this controls the data-type used when communicating "
3143- "model gradients during FSDP.")
3145+ "model gradients during FSDP. If 'auto', then the main gradient data-type will "
3146+ "be used for the gradient communication / reduction data-type. When using NCCL "
3147+ "v2.27+, reduction is always computed in FP32 if using NCCL Symmetric kernels.")
31443148
31453149 return parser
31463150
0 commit comments