Skip to content

Conversation

@adi776borate
Copy link

What does this PR do?

Fixes #2178

The root cause of the bug:

In litgpt/generate/sequentially.py line 65, the code computes rope_cache_length as:
model.cos.size(-1)
For models with rope_local_base_freq, the RoPE cache has shape (seq_len, n_elem, 2) instead of (seq_len, n_elem).
Using .size(-1) returns 2 instead of the correct n_elem (e.g., 128), causing the KV cache to be initialized with incorrect head dimensions.

Solution:

In litgpt/generate/sequentially.py:

# Before (buggy):
submodule.attn.kv_cache = submodule.attn.build_kv_cache(
    1, max_seq_length, model.cos.size(-1), target_device
)

# After (fixed):
if len(model.cos.shape) == 2:
    rope_cache_length = model.cos.size(-1)
elif len(model.cos.shape) == 3:
    rope_cache_length = model.cos.size(1)  # Get n_elem dimension
else:
    rope_cache_length = model.cos.size(-1)

submodule.attn.kv_cache = submodule.attn.build_kv_cache(
    1, max_seq_length, rope_cache_length, target_device
)

Testing:

from pathlib import Path
from litgpt.api import LLM

# Path to downloaded checkpoint
checkpoint_dir = Path("checkpoints/google/gemma-3-1b-it")

llm = LLM.load(str(checkpoint_dir), distribute=None)

llm.distribute(
    devices=1,
    accelerator="cuda",
    generate_strategy="sequential",
    fixed_kv_cache_size=2048,
)

output = llm.generate("What do llamas eat?", max_new_tokens=2000)
print(f"\nOutput: {output}")

Above script produces the expected output when the fix is applied.

Who can review?

Anyone in the community is free to review the PR once the tests have passed.

Copy link
Collaborator

@bhimrazy bhimrazy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @adi776borate, nice catch.

I also refactored the logic abit to a rope_cache_length method and also added a test.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

KV cache dimension error in sequential generation for models with rope_local_base_freq

3 participants