Description
System Info
I'm running a standard training loop and saving the optimizer state with opt.state_dict().
When I restore it with opt.load_state_dict() to resume, the loss goes NaN immediately after the first optimizer step.
This only happens with the AdEMAMix optimizer:
`bnb.optim.AdEMAMix8bit(model.parameters(), lr=lr, t_alpha=T, t_beta3=T)`
AdamW and the other optimizers load their state dicts without issue. Any ideas?
Reproduction
```python
import torch
import bitsandbytes as bnb

opt = bnb.optim.AdEMAMix8bit(model.parameters())
# run training loop
torch.save(opt.state_dict(), "dt.pt")
# try resuming opt from the state dict later
opt.load_state_dict(torch.load("dt.pt"))
# run training loop again -> loss goes NaN
```
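In case it helps, here is a fuller self-contained sketch of the pattern that fails for me; the model, data, and hyperparameters below are placeholder stand-ins, not my actual setup:

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Placeholder model/data; any small module shows the same behavior for me.
model = nn.Linear(128, 128).cuda()
opt = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-3)

def train_steps(opt, n):
    for _ in range(n):
        x = torch.randn(64, 128, device="cuda")
        loss = model(x).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
        print(loss.item())  # after resuming, this prints NaN on the first step

train_steps(opt, 10)
torch.save(opt.state_dict(), "dt.pt")

# Resume: fresh optimizer instance, restore the saved state, continue training
opt = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-3)
opt.load_state_dict(torch.load("dt.pt"))
train_steps(opt, 10)  # loss goes NaN here with AdEMAMix8bit, but not with AdamW
```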
Expected behavior
The optimizer should resume from the loaded state dict without the loss going NaN, matching the behavior of AdamW.