Description
It would be a good user experience improvement to expose FSDP2 `MixedPrecisionPolicy` so it can be set through the config, at least for `param_dtype` and `reduce_dtype`. These are important parameters when training in low precision (e.g. bf16), and right now they can only be changed by hardcoding them in `training.fully_shard`. See #2254 for why these are important parameters.
As a suggestion, we may want to hardcode `reduce_dtype=torch.float32` by default. I don't think it reduces training speed at all (though we should check), and it helps with convergence / stability.
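For illustration, here is a minimal sketch of what the plumbing could look like. The config keys (`mixed_precision.param_dtype`, `mixed_precision.reduce_dtype`) and the `build_mp_policy` helper are hypothetical names, not existing APIs; `MixedPrecisionPolicy` and `fully_shard` are the FSDP2 APIs exported from `torch.distributed.fsdp` in recent PyTorch releases:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Map config strings to torch dtypes (names here are an assumption).
DTYPES = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}

def build_mp_policy(cfg: dict) -> MixedPrecisionPolicy:
    # param_dtype: dtype parameters are cast to for forward/backward compute.
    # reduce_dtype: dtype used for gradient reduce-scatter/all-reduce.
    # Defaulting reduce_dtype to fp32 here, per the suggestion above.
    return MixedPrecisionPolicy(
        param_dtype=DTYPES[cfg.get("param_dtype", "bf16")],
        reduce_dtype=DTYPES[cfg.get("reduce_dtype", "fp32")],
    )

# Hypothetical usage inside the sharding setup (e.g. where training.fully_shard
# currently hardcodes the policy):
#
# mp_policy = build_mp_policy(cfg.get("mixed_precision", {}))
# for layer in model.layers:
#     fully_shard(layer, mp_policy=mp_policy)
# fully_shard(model, mp_policy=mp_policy)
```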