Skip to content

Commit

Permalink
Merge pull request #2 from shenwanxiang/patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez authored Feb 5, 2024
2 parents 89aa879 + e3eb3c5 commit ec56682
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions FlashMHA/FlashMHA.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, dropout=0.

def forward(self, query, key, value):
qkv = self.Wqkv(query)
q, k, v = rearrange(qkv, 'b s (three h d) -> three b s h d', three=3, h=self.num_heads, d=self.head_dim).unbind(dim=0)
q, k, v = rearrange(qkv, 'b s (three h d) -> three b h s d', three=3, h=self.num_heads, d=self.head_dim).unbind(dim=0)
context = self.inner_attn(q, k, v)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
return self.out_proj(rearrange(context, 'b h s d -> b s (h d)'))

0 comments on commit ec56682

Please sign in to comment.