MTP的输入不应该是input_ids后移head_index+1位嘛?假如main model输入的是t1 t2 t3 t4,第一个mtp头应该是 t2 t3 t4 t5,代码里没看到有这样的操作,直接输入的还是t1 t2 t3 t4