Hi~
I am currently following the HF version for exploration, but I found something odd when the KV cache is updated in Llama (NousResearch/Yarn-Llama-2-7b-128k):
the empty cache appended during the update always has length 256 (line 528):
past_kv = torch.cat([past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
I think it should instead grow by the length of the incoming kv:
past_kv = torch.cat([past_kv, torch.empty(bsz, kv.size(1), 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
Is that right? Or am I misunderstanding this procedure?
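For concreteness, here is a minimal sketch of the shape difference between the two variants. It assumes the cache layout `(bsz, seq_len, 2, n_heads, head_dim)` implied by the `torch.cat(..., 1)` call above; the sizes used (10 cached positions, 3 incoming) are made up for illustration:

```python
import torch

bsz, n_heads, head_dim = 1, 4, 8
past_kv = torch.zeros(bsz, 10, 2, n_heads, head_dim)  # 10 already-cached positions
kv = torch.zeros(bsz, 3, 2, n_heads, head_dim)        # 3 new positions this step

# Variant from line 528: always appends 256 empty slots, regardless of kv's length
fixed = torch.cat(
    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4),
                          dtype=kv.dtype, device=kv.device)], 1)

# Proposed variant: appends exactly as many slots as there are incoming positions
dynamic = torch.cat(
    [past_kv, torch.empty(bsz, kv.size(1), 2, kv.size(3), kv.size(4),
                          dtype=kv.dtype, device=kv.device)], 1)

print(fixed.shape[1])    # 10 + 256 = 266
print(dynamic.shape[1])  # 10 + 3 = 13
```

One possibility worth ruling out: the fixed 256 could be a deliberate pre-allocation block size (growing the buffer in chunks to amortize `torch.cat` cost), in which case the actual valid length would be tracked separately from the tensor's physical size.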