Description
We've noticed significantly worse convergence when training in `bf16` than in `fp32`.
As a comparison, here are the loss curves for `bf16` and `fp32`:
This is a full finetune of an 8B Llama model running on 8 nodes (64 GPUs), but the issue exists even on 1 node (8 GPUs). The runs are identical apart from the dtype. Notice that even after 250 steps the `bf16` run does not drop below a loss of 0.7. In theory, it should be possible to get similar convergence rates with either dtype (at least I think there are multiple existence proofs inside Meta 😛).
One thing I tried was setting FSDP's `reduce_dtype=fp32` (I had to hardcode it because torchtune doesn't expose this option, AFAICT), but it did not seem to help much. Are there any other options we should be looking into?
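For reference, this is roughly what I hardcoded, shown here against the raw PyTorch FSDP `MixedPrecision` API rather than torchtune's config (so treat the exact wiring as an assumption): keep params and compute in `bf16`, but run the gradient reductions in `fp32`.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# bf16 params and compute, but gradient reduce-scatter/all-reduce in fp32.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,   # the bit torchtune doesn't seem to expose
    buffer_dtype=torch.bfloat16,
)

# `model` is the already-constructed Llama module; process group setup omitted.
sharded_model = FSDP(model, mixed_precision=mp_policy)
```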
I still need to confirm this, but I think one thing that would help a lot is keeping the optimizer states in `fp32`. It would use a lot more memory than end-to-end `bf16`, but at least it would not slow down training as much as doing everything in `fp32`. Is there an easy way to do this in torchtune/PyTorch? Would doing something like the snippet below work?
```python
model = create_model(dtype=torch.float32)           # build the model in fp32
optimizer = torch.optim.AdamW(model.parameters())   # e.g. AdamW over the (still fp32) params
model = model.to(torch.bfloat16)                    # then cast the model to bf16
```
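In case the in-place cast ends up confusing the optimizer, the other option I have in mind is the classic fp32 master-weights pattern. Here is a rough single-device sketch of what I mean, with `create_model` and the data loop as placeholders (an assumption about the shape of the solution, not an existing torchtune API):

```python
import torch

model = create_model(dtype=torch.bfloat16)                  # bf16 compute weights
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params)                # fp32 params -> fp32 optimizer states

for batch in dataloader:                                    # placeholder training loop
    loss = model(batch).mean()
    loss.backward()
    with torch.no_grad():
        # hand the bf16 grads to the fp32 master copies and step in fp32
        for p, mp in zip(model.parameters(), master_params):
            mp.grad = p.grad.float()
        optimizer.step()
        optimizer.zero_grad()
        # copy the updated fp32 master weights back into the bf16 model
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp)
    model.zero_grad(set_to_none=True)
```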