Skip to content
Discussion options

You must be logged in to vote

至于为什么没有显式的 seq-len=1

贴一下GPT5.1回答吧,不打字了。


你提到的 “推理的时候,forward 的维度是 [batch_size, 1, hidden]” 这一点完全正确,且是 KV Cache 生效的核心前提

为什么代码里没看到 [batch_size, 1] 的截断逻辑?

这其实是 transformers 库在幕后帮你做的工作。

MiniMindForCausalLM 继承了 GenerationMixin。当调用 model.generate() 时,transformers 库会接管控制流:

  1. First Pass (Prefill):

    • 输入完整的 Prompt(例如长度为 10)。
    • forward 接收 input_ids 维度 [1, 10]
    • 计算出所有 10 个 token 的 KV,存入 past_key_values
  2. Decoding Phase (Generation):

    • transformers 库的 GenerationMixin 会检查 past_key_values 是否存在。
    • 关键点:它的默认行为(在较新版本中)会自动执行 input_ids = input_ids[:, -1:],即只取最后一个生成的 token。
    • 所以,第 2 次及以后的 forward,传入 input_ids 维度真的是 [1, 1]

代码维度的逐行验证

让我们带入维度 [batch, 1, num_head, dim] 来验证你的猜想:

model/model_minimind.pyAttention.…

Replies: 1 comment 2 replies

Comment options

You must be logged in to vote
2 replies
@jingyaogong
Comment options

Answer selected by ciaoyizhen
@ciaoyizhen
Comment options

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants