
Commit a2e0a70

rohan-varma authored and facebook-github-bot committed
Batch matmul fast path in MHAWithCache (facebookresearch#449)
Summary:
Pull Request resolved: facebookresearch#449

When doing self attention, an optimization is to combine the Q, K, and V input projection matrices and do a single matmul instead of three. This diff adds that optimization to MHAWithCache.

Differential Revision: D48418780

fbshipit-source-id: 0501341832910bf90a7ea1cc902b98f0760548ab
1 parent 951a452 commit a2e0a70
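
For context, here is a minimal, self-contained sketch of the fast path described in the summary: the Q, K, and V projection weights are packed into one linear layer, so a single matmul followed by a chunk replaces three separate matmuls. The names and shapes below (fused, q_proj, ...) are illustrative only, not the TorchMultimodal API.

# Illustrative sketch of the fused QKV projection; not the MHAWithCache code itself.
import torch
from torch import nn

torch.manual_seed(0)
dim = 16
x = torch.randn(2, 5, dim)  # (batch, seq_len, dim)

# Baseline: three separate projections -> three matmuls.
q_proj = nn.Linear(dim, dim)
k_proj = nn.Linear(dim, dim)
v_proj = nn.Linear(dim, dim)
q, k, v = q_proj(x), k_proj(x), v_proj(x)

# Fast path: pack the three weight matrices into one layer -> a single matmul,
# then split the output back into query, key, value.
fused = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
q2, k2, v2 = fused(x).chunk(3, dim=-1)

# Both paths produce the same projections (up to floating-point tolerance).
assert torch.allclose(q, q2, atol=1e-6)
assert torch.allclose(k, k2, atol=1e-6)
assert torch.allclose(v, v2, atol=1e-6)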


torchmultimodal/modules/layers/multi_head_attention.py

Lines changed: 22 additions & 4 deletions

@@ -7,16 +7,24 @@
 from typing import NamedTuple, Optional, Tuple, Union
 
 import torch
-
 import torch.nn.functional as F
 from torch import nn, Tensor
+from torch.nn import Module
 
 
 class MHAWithCacheOutput(NamedTuple):
     attn_output: Tensor
     past_key_value: Tuple[Tensor, Tensor]
 
 
+# def _batched_input_proj(
+#     query: Tensor, input_proj: Module
+# ) -> Tuple[Tensor, Tensor, Tensor]:
+#     projected_query = input_proj(query)
+#     query, key, value = projected_query.chunk(3, dim=-1)
+#     return query, key, value
+
+
 class MultiHeadSelfAttention(nn.Module):
     """
     Multihead self attention.

@@ -93,6 +101,7 @@ class MultiHeadAttentionWithCache(nn.Module):
         dropout (float): dropout rate
         add_bias (bool): if true, adds a learnable bias to query, key, value.
            Defaults to True.
+        is_self_attention
     """
 
     def __init__(

@@ -102,12 +111,17 @@ def __init__(
         num_heads: int,
         dropout: float = 0.0,
         add_bias: bool = True,
+        is_self_attention: bool = False,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
+        self.is_self_attention = is_self_attention
+        # Note: defining qkv and input_proj regardless of is_self_attention
+        # due to TorchScript compatibility.
        self.q_proj = nn.Linear(dim_q, dim_q, bias=add_bias)
        self.k_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
        self.v_proj = nn.Linear(dim_kv, dim_q, bias=add_bias)
+        self.input_proj_self_attn = nn.Linear(dim_q, 3 * dim_q, bias=add_bias)
        self.output_proj = nn.Linear(dim_q, dim_q)
        self.dropout = dropout
 

@@ -144,9 +158,13 @@ def forward(
        bsz = query.size(0)
        embed_dim = query.size(-1)
        head_dim = embed_dim // self.num_heads
-        query = self.q_proj(query)
-        key = self.k_proj(key)
-        value = self.v_proj(value)
+        if self.is_self_attention:
+            projected_query = self.input_proj_self_attn(query)
+            query, key, value = projected_query.chunk(3, dim=-1)
+        else:
+            query = self.q_proj(query)
+            key = self.k_proj(key)
+            value = self.v_proj(value)
 
        # bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
        query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
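
Assuming the rest of the MultiHeadAttentionWithCache interface is unchanged by this diff, usage of the new fast path would look roughly like the sketch below. Only constructor arguments visible in the hunks above are used; the full forward signature and return type are not shown here, so calling the module with just (query, key, value) is an assumption.

# Hypothetical usage sketch; dimensions are arbitrary and the forward call is
# assumed to accept (query, key, value) positionally with defaults for other args.
import torch
from torchmultimodal.modules.layers.multi_head_attention import MultiHeadAttentionWithCache

mha = MultiHeadAttentionWithCache(
    dim_q=512,
    dim_kv=512,  # must equal dim_q for the fused self-attention projection
    num_heads=8,
    is_self_attention=True,  # take the single fused input-projection matmul path
)
x = torch.randn(2, 10, 512)  # (batch, seq_len, dim_q)
out = mha(x, x, x)  # self attention: query, key, and value are the same tensor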
