KV_cache的源码解释 #547
Answered
by
jingyaogong
ciaoyizhen
asked this question in
Q&A
-
|
kv_cache的原理我是很清楚的,但是我很难将代码和这个联系起来。 我就在想,在推理的时候,是不是forward的时候 他的维度不是[batch_size, seq_len, hidden] 而是[batch_size, 1, hidden] 1表示新进来的 xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
...
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1) # time 维度拼接
xv = torch.cat([past_key_value[1], xv], dim=1)不然我很难理解 xk的维度[batch_size, seq_len, num_head, dim] 怎么做的合并 不知道我的理解对不对 |
Beta Was this translation helpful? Give feedback.
Answered by
jingyaogong
Nov 24, 2025
Replies: 1 comment 2 replies
-
|
推理时候是 [batch, 1, num_head, dim] |
Beta Was this translation helpful? Give feedback.
2 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
至于为什么没有显式的 seq-len=1
贴一下GPT5.1回答吧,不打字了。
你提到的 “推理的时候,forward 的维度是 [batch_size, 1, hidden]” 这一点完全正确,且是 KV Cache 生效的核心前提。
为什么代码里没看到
[batch_size, 1]的截断逻辑?这其实是
transformers库在幕后帮你做的工作。MiniMindForCausalLM继承了GenerationMixin。当调用model.generate()时,transformers库会接管控制流:First Pass (Prefill):
forward接收input_ids维度[1, 10]。past_key_values。Decoding Phase (Generation):
transformers库的GenerationMixin会检查past_key_values是否存在。input_ids = input_ids[:, -1:],即只取最后一个生成的 token。forward,传入input_ids维度真的是[1, 1]。代码维度的逐行验证
让我们带入维度
[batch, 1, num_head, dim]来验证你的猜想:在
model/model_minimind.py的Attention.…