Labels: bug (Something isn't working)
Description
Bug description
When using the sequential generation strategy with models that have `rope_local_base_freq` configured (e.g., Gemma-3 models), the KV cache is initialized with incorrect dimensions, causing a runtime error.
Reproduction script:
```python
from pathlib import Path
from litgpt.api import LLM

# Path to downloaded checkpoint
checkpoint_dir = Path("checkpoints/google/gemma-3-1b-it")

llm = LLM.load(str(checkpoint_dir), distribute=None)
llm.distribute(
    devices=1,
    accelerator="cuda",
    generate_strategy="sequential",
    fixed_kv_cache_size=2048,
)
output = llm.generate("What do llamas eat?", max_new_tokens=50)
print(f"\nOutput: {output}")
```

Log message:
⚡ ~ python main.py
Using 1 device(s)
Precision set
Fabric launched
Moving '_forward_module.transformer.h.25' to cuda:0: 100%|███████████████████████████████████████████| 26/26 [00:00<00:00, 91.66it/s]
Traceback (most recent call last):
File "/teamspace/studios/this_studio/main.py", line 16, in <module>
output = llm.generate("What do llamas eat?", max_new_tokens=50)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/api.py", line 552, in generate
outputs = generate_fn(
^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/generate/base.py", line 413, in generate
token_list = list(
^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 38, in generator_context
response = gen.send(None)
^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/generate/base.py", line 186, in generate_fn
token = next_token(
^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/generate/base.py", line 83, in next_token
logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/lightning/fabric/wrappers.py", line 136, in forward
output = self._forward_module(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/model.py", line 161, in forward
x = block(
^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/model.py", line 330, in forward
attention_output = self.attn(x_normed, cos, sin, mask, input_pos, input_pos_maxp1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/model.py", line 459, in forward
k, v = self.kv_cache(input_pos, k, v)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/model.py", line 1081, in forward
k = batched_index_copy_(self.k[:bs, ...], -2, cache_positions, k)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/teamspace/studios/this_studio/litgpt/litgpt/model.py", line 980, in batched_index_copy_
return t.index_copy_(dim, idx, val)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: index_copy_(): Source/destination tensor must have same slice shapes. Destination slice shape: 1 1 2 at dimension 2 and source slice shape: 1 1 256 at dimension 0.
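The failing call can be illustrated in isolation. `index_copy_` requires the source and destination to agree in every dimension except the indexed one, so if the KV cache is allocated with the wrong last dimension (2 here, standing in for the mis-sized cache, versus the model's actual head size of 256), writing a new key into it fails exactly as in the traceback. The shapes below are hypothetical stand-ins chosen to mirror the error, not the actual cache layout in litgpt:

```python
import torch

# Wrongly sized cache: (batch, heads, cache_len, head_size=2)
cache = torch.zeros(1, 1, 2048, 2)
# New key with the model's real head size of 256
k = torch.randn(1, 1, 1, 256)
pos = torch.tensor([0])

try:
    # index_copy_ along the cache-position dim; the trailing head_size
    # dims (2 vs 256) do not match, so this raises RuntimeError
    cache.index_copy_(-2, pos, k)
except RuntimeError as e:
    print(e)

# With a correctly sized cache the same copy succeeds
cache_ok = torch.zeros(1, 1, 2048, 256)
cache_ok.index_copy_(-2, pos, k)
```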
Reproduced in studio: https://lightning.ai/23110065/personal/studios/litgpt-bug/code?source=copylink
What operating system are you using?
Linux
LitGPT Version
Version: 0.5.12