
Expose FSDP2 MixedPrecisionPolicy params #2267

Open
@EugenHotaj

Description

It would be a good user-experience improvement to expose FSDP2's MixedPrecisionPolicy through the config, at least param_dtype and reduce_dtype. These are important parameters when training in low precision (e.g. bf16), and right now they can only be changed by hardcoding them in training.fully_shard. See #2254 for why these parameters matter. A sketch of what this could look like is below.
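As a minimal sketch (not the actual recipe plumbing): the config keys and dtype mapping below are hypothetical, and the imports assume the public FSDP2 API (fully_shard / MixedPrecisionPolicy under torch.distributed.fsdp in recent PyTorch; older releases expose them under torch.distributed._composable.fsdp).

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Hypothetical config keys; in practice these would come from the recipe YAML.
cfg = {"param_dtype": "bf16", "reduce_dtype": "fp32"}
_DTYPES = {"bf16": torch.bfloat16, "fp32": torch.float32}

mp_policy = MixedPrecisionPolicy(
    param_dtype=_DTYPES[cfg["param_dtype"]],    # dtype of unsharded params in forward/backward
    reduce_dtype=_DTYPES[cfg["reduce_dtype"]],  # dtype of the gradient reduce-scatter
)

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# Requires an initialized process group (e.g. via torchrun). Shard each
# submodule, then the root, passing the same policy throughout.
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```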

As a suggestion, we may want to default to reduce_dtype=torch.float32. I don't think it reduces training speed at all (though we should verify), and it helps with convergence/stability.
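As a toy illustration of the stability point (not FSDP code): accumulating many small bf16 values step by step drifts from the true sum, while upcasting to fp32 before accumulating stays accurate, which is what reduce_dtype=torch.float32 buys for the gradient reduce-scatter.

```python
import torch

# Step-by-step accumulation in bf16 vs. fp32. Each bf16 add rounds to
# 8 mantissa bits, so the running sum drifts; upcasting first does not.
acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
g = torch.tensor(1e-3, dtype=torch.bfloat16)  # a small "gradient"

for _ in range(1000):
    acc_bf16 = acc_bf16 + g          # rounds on every step
    acc_fp32 = acc_fp32 + g.float()  # upcast, then accumulate

print(acc_bf16.item())  # noticeably off from ~1.0
print(acc_fp32.item())  # ~1.0 (up to bf16 rounding of g itself)
```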

Metadata

Labels

enhancement (New feature or request), triaged (This issue has been assigned an owner and appropriate label)
