
AdEMAMix NaN when loading from state_dict #1382

Open
@darius-lam

Description

System Info

I'm running a standard training loop and saving the optimizer state_dict with opt.state_dict().
Upon loading it with opt.load_state_dict() to resume training, the model NaNs immediately after the first backprop step.

This only occurs with the AdEMAMix optimizer:

```python
bnb.optim.AdEMAMix8bit(model.parameters(), lr=lr, t_alpha=T, t_beta3=T)
```

AdamW and the other optimizers load their state dicts perfectly fine. Any ideas?
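
One angle worth checking (an assumption on my part, not a confirmed cause): AdEMAMix schedules alpha and beta3 from the step counter via t_alpha/t_beta3, so a step count that doesn't survive the save/load round trip could plausibly produce bad update scales. A minimal self-contained sketch to verify the round trip (the `"step"` key is an assumption about the bitsandbytes state layout):

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Take one real step so the optimizer state (including "step") exists.
model = nn.Linear(16, 16).cuda()
opt = bnb.optim.AdEMAMix8bit(model.parameters(), lr=1e-3, t_alpha=1000, t_beta3=1000)
model(torch.randn(4, 16).cuda()).sum().backward()
opt.step()

# Round-trip the state_dict and compare the per-param step counters.
saved = opt.state_dict()
opt.load_state_dict(saved)

steps_before = [s.get("step") for s in saved["state"].values()]
steps_after = [s.get("step") for s in opt.state_dict()["state"].values()]
print(steps_before, steps_after)  # these should match
```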

Reproduction

```python
import torch
import bitsandbytes as bnb

opt = bnb.optim.AdEMAMix8bit(model.parameters())
# run training loop
torch.save(opt.state_dict(), "dt.pt")

# try resuming opt from the saved state_dict later
opt.load_state_dict(torch.load("dt.pt"))
# run training loop again -> loss is NaN after the first step
```
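
If it helps narrow things down, here is a small debugging sketch (my own aid, not part of the original report) that scans the freshly loaded optimizer state for non-finite values before the first step, to distinguish a checkpoint that is already corrupted at load time from a NaN produced by the first update. It assumes `opt` is the AdEMAMix8bit instance right after `load_state_dict` above:

```python
import torch

# If something is already non-finite here, the state is corrupted at load
# time; if everything is finite, the NaN comes from the first update itself.
for i, state in enumerate(opt.state_dict()["state"].values()):
    for key, value in state.items():
        if torch.is_tensor(value) and value.is_floating_point():
            if not torch.isfinite(value).all():
                print(f"param {i}: non-finite values in state['{key}']")
```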

Expected behavior

The optimizer should resume training without producing NaNs.

Metadata

Labels

Optimizers: Issues or feature requests relating to optimizers
bug: Something isn't working
