
Bug when using optimizer LAMB 32bits #1350

Open
@FrsECM

Description


System Info

WSL Ubuntu 22.04, Python 3.10, bitsandbytes 0.43.1

Reproduction

To reproduce the issue, run:

import bitsandbytes as bnb
import torch
import torch.nn as nn

model = nn.Linear(10, 2).cuda()
model.train()
# We create an optimizer
optimizer = bnb.optim.LAMB(model.parameters())
# We create dummy input / target
input = torch.rand(size=(10, 10)).cuda()
target = torch.zeros(10).cuda()

# We compute prediction / loss
optimizer.zero_grad()
prediction = model(input)
loss = nn.CrossEntropyLoss()(prediction, target.long())

loss.backward()
optimizer.step()

This results in a traceback like the following:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[12], line 19
     16 loss = nn.CrossEntropyLoss()(prediction, target.long())
     18 loss.backward()
---> 19 optimizer.step()

File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    386         else:
    387             raise RuntimeError(
    388                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    389             )
--> 391 out = func(*args, **kwargs)
    392 self._optimizer_step_code()
    394 # call optimizer step post hooks

File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /home/default/miniconda/envs/domf_iris2/lib/python3.10/site-packages/bitsandbytes/optim/optimizer.py:287, in Optimizer8bit.step(self, closure)
    284             self.init_state(group, p, gindex, pindex)
    286         self.prefetch_state(p)
...
-> 1584     optim_func = str2optimizer32bit[optimizer_name][0]
   1585 elif g.dtype == torch.float16:
   1586     optim_func = str2optimizer32bit[optimizer_name][1]

KeyError: 'lamb'
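The traceback shows the failure mode: the 32-bit dispatch table `str2optimizer32bit` in `bitsandbytes/functional.py` is indexed by optimizer name, and it apparently has no `'lamb'` entry. The sketch below mimics that lookup with hypothetical stand-in entries (the real table maps names to C-extension function pointers), just to illustrate why a bare `KeyError: 'lamb'` surfaces from `optimizer.step()`:

```python
# Minimal mimic of the dispatch in bitsandbytes/functional.py.
# The function names below are hypothetical stand-ins for the real
# C-extension kernels; the point is the missing "lamb" key.
str2optimizer32bit = {
    "adam": ("cadam32bit_fp32", "cadam32bit_fp16"),
    "momentum": ("cmomentum32bit_fp32", "cmomentum32bit_fp16"),
    # no "lamb" entry, even though bnb.optim.LAMB takes the 32-bit path
}

def lookup(optimizer_name: str, dtype: str) -> str:
    # Mirrors `optim_func = str2optimizer32bit[optimizer_name][0]`
    # (index 0 for float32, index 1 for float16).
    index = 0 if dtype == "float32" else 1
    return str2optimizer32bit[optimizer_name][index]

print(lookup("adam", "float32"))  # a registered name resolves fine
try:
    lookup("lamb", "float32")     # reproduces the failure mode
except KeyError as e:
    print("KeyError:", e)
```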

Expected behavior

The optimizer step should complete without raising an exception.
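Until the 32-bit LAMB kernel is registered, one possible improvement (a sketch, not the library's actual fix) would be a guarded lookup that fails with an actionable message instead of a bare `KeyError`. The helper name `get_optim_func` below is hypothetical:

```python
# Hypothetical defensive variant of the table lookup: raise a clear
# NotImplementedError listing the registered optimizers instead of
# letting a raw dict indexing error escape from optimizer.step().
def get_optim_func(table: dict, optimizer_name: str, index: int):
    funcs = table.get(optimizer_name)
    if funcs is None:
        raise NotImplementedError(
            f"No 32-bit kernel registered for optimizer '{optimizer_name}'; "
            f"available: {sorted(table)}"
        )
    return funcs[index]
```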

Metadata

Assignees: none
Labels: Optimizers (issues or feature requests relating to optimizers), bug (something isn't working), contributions-welcome (we welcome contributions to fix this issue!), low priority (will be worked on after all priority issues)
Milestone: none
Development: no branches or pull requests