Skip to content

Commit

Permalink
fix(model): Add support for bias wqkv tensor in Attention
Browse files Browse the repository at this point in the history
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Oct 4, 2024
1 parent 12b7d16 commit bbea338
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
# wv = state_dict.pop(prefix + "wv.weight")
# state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
q_size = self.n_heads * self.head_dim
kv_size = self.n_local_heads * self.head_dim
wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
state_dict[prefix + "wq.weight"] = wq
state_dict[prefix + "wk.weight"] = wk
state_dict[prefix + "wv.weight"] = wv
for tensor_suffix in ["weight", "bias"]:
wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
if wqkv_key in state_dict:
wqkv = state_dict.pop(wqkv_key)
q_size = self.n_heads * self.head_dim
kv_size = self.n_local_heads * self.head_dim
wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

return

Expand Down

0 comments on commit bbea338

Please sign in to comment.