Very slow convergence with bf16 #2254

Open
@EugenHotaj

Description

We've noticed much worse convergence when training in bf16 than in fp32.

As a comparison, here are the loss curves for bf16:

[Screenshot: bf16 loss curve, 2025-01-11 2:07 PM]

and fp32:

[Screenshot: fp32 loss curve, 2025-01-11 2:08 PM]

This is a full finetune of Llama 8B running on 8 nodes (64 GPUs), but the issue exists even on 1 node (8 GPUs). The runs are identical aside from the dtype. Notice that even after 250 steps the bf16 run does not go below 0.7 loss. In theory, it should be possible to get similar convergence rates with either dtype (at least I think there are multiple existence proofs inside Meta 😛).

One thing I tried was setting FSDP's reduce_dtype=fp32 (I had to hardcode it because torchtune doesn't expose this option AFAICT), but it did not seem to help much; see the sketch below for roughly what I mean. Any other options we should be looking into?
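For reference, what I hardcoded is roughly the mixed-precision policy below. This is a minimal sketch against the plain FSDP1 wrapper with a toy module standing in for the 8B model; torchtune builds its sharding internally (and may be on a different FSDP path), so the exact place to hook this in there is an assumption on my part.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(4096, 4096).cuda()  # toy stand-in for the real 8B model

# Keep bf16 params/compute but do the gradient reduction in fp32.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # params and activations in bf16
    reduce_dtype=torch.float32,   # all-reduce / reduce-scatter of grads in fp32
    buffer_dtype=torch.bfloat16,
)
model = FSDP(model, mixed_precision=mp_policy)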

I still need to confirm this, but I think one thing that would greatly help is keeping the optimizer states in fp32. It would use more memory than end-to-end bf16, but at least it would not slow down training as much as doing everything in fp32. Is there an easy way to do this in torchtune/pytorch? Would something like the snippet below work?

import torch

model = create_model(dtype=torch.float32)          # build the model in fp32
optimizer = torch.optim.AdamW(model.parameters())  # optimizer holds references to these params
model = model.to(torch.bfloat16)                   # then cast to bf16 for training
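
For concreteness, this is roughly the fp32-master-weights pattern I'm imagining, written out as a single-GPU toy example. The Linear layer, shapes, dummy loss, and lr are made up, and it ignores FSDP/torchtune entirely; the point is just that the optimizer (and its exp_avg/exp_avg_sq states) only ever sees fp32 tensors while forward/backward stay in bf16.

import torch

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)  # toy stand-in

# fp32 master copies: the optimizer and its states live entirely in fp32.
master_params = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params, lr=2e-5)

for _ in range(10):
    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    loss = model(x).float().pow(2).mean()   # dummy loss for illustration
    loss.backward()
    with torch.no_grad():
        # Upcast bf16 grads into the fp32 masters, step in fp32, copy back down to bf16.
        for p, mp in zip(model.parameters(), master_params):
            mp.grad = p.grad.float()
        optimizer.step()
        optimizer.zero_grad()
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp.to(torch.bfloat16))
            p.grad = None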
