Skip to content

Commit ec56682

Browse files
authored
Merge pull request #2 from shenwanxiang/patch-1
2 parents 89aa879 + e3eb3c5 commit ec56682

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

FlashMHA/FlashMHA.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,6 @@ def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, dropout=0.
4040

4141
def forward(self, query, key, value):
4242
qkv = self.Wqkv(query)
43-
q, k, v = rearrange(qkv, 'b s (three h d) -> three b s h d', three=3, h=self.num_heads, d=self.head_dim).unbind(dim=0)
43+
q, k, v = rearrange(qkv, 'b s (three h d) -> three b h s d', three=3, h=self.num_heads, d=self.head_dim).unbind(dim=0)
4444
context = self.inner_attn(q, k, v)
45-
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
45+
return self.out_proj(rearrange(context, 'b h s d -> b s (h d)'))

0 commit comments

Comments
 (0)