73 changes: 63 additions & 10 deletions tests/modules/layers/test_multi_head_attention.py
@@ -11,6 +11,7 @@
from torchmultimodal.modules.layers.multi_head_attention import (
MultiHeadAttentionWithCache,
MultiHeadSelfAttention,
MultiHeadSelfAttentionWithCache,
)


@@ -103,6 +104,13 @@ def multi_head_self_attn_use_cache(self, dim_q):
mha.eval()
return mha

@pytest.fixture
def multi_head_self_attn_module_with_cache(self, dim_q):
mha = MultiHeadSelfAttentionWithCache(dim_q, num_heads=2)
init_weights_with_constant(mha)
mha.eval()
return mha

@pytest.fixture
def multi_head_cross_attn(self, dim_q, dim_kv):
mha = MultiHeadAttentionWithCache(dim_q, dim_kv, num_heads=2)
@@ -117,16 +125,7 @@ def multi_head_cross_attn_without_bias(self, dim_q, dim_kv):
mha.eval()
return mha

def test_multi_head_self_attention_use_cache(
self,
multi_head_self_attn_use_cache,
current_key_value,
past_key_value,
q,
):
actual = multi_head_self_attn_use_cache(
q, q, q, past_key_value=(past_key_value, past_key_value), use_cache=True
)
def _assert_mha_self_attn_equal(self, actual, past_key_value, current_key_value):
expected = torch.tensor(
[
[
Expand All @@ -138,6 +137,7 @@ def test_multi_head_self_attention_use_cache(
)
assert_expected(actual.attn_output, expected, rtol=0, atol=1e-4)
# Check that the cache is properly updated
assert_expected(
actual.past_key_value[0],
torch.cat([past_key_value, current_key_value], dim=2),
@@ -147,6 +147,59 @@ def multi_head_cross_attn_without_bias(self, dim_q, dim_kv):
torch.cat([past_key_value, current_key_value], dim=2),
)

def test_multi_head_self_attention_use_cache(
self,
multi_head_self_attn_use_cache,
current_key_value,
past_key_value,
q,
):
actual = multi_head_self_attn_use_cache(
q, q, q, past_key_value=(past_key_value, past_key_value), use_cache=True
)
self._assert_mha_self_attn_equal(actual, past_key_value, current_key_value)

def test_multi_head_self_attn_module_with_cache(
self,
multi_head_self_attn_module_with_cache,
current_key_value,
past_key_value,
q,
):
actual = multi_head_self_attn_module_with_cache(
q, past_key_value=(past_key_value, past_key_value), use_cache=True
)
self._assert_mha_self_attn_equal(actual, past_key_value, current_key_value)

def test_multi_head_attention_with_cache_modules_equal(
self,
multi_head_self_attn_use_cache,
multi_head_self_attn_module_with_cache,
current_key_value,
past_key_value,
q,
):
mha_with_cache_cls_output = multi_head_self_attn_use_cache(
q, q, q, past_key_value=(past_key_value, past_key_value), use_cache=True
)
sa_with_cache_cls_output = multi_head_self_attn_module_with_cache(
q, past_key_value=(past_key_value, past_key_value), use_cache=True
)
assert_expected(
mha_with_cache_cls_output.attn_output,
sa_with_cache_cls_output.attn_output,
rtol=0,
atol=1e-4,
)
assert_expected(
mha_with_cache_cls_output.past_key_value[0],
sa_with_cache_cls_output.past_key_value[0],
)
assert_expected(
mha_with_cache_cls_output.past_key_value[1],
sa_with_cache_cls_output.past_key_value[1],
)

def test_multi_head_cross_attention(self, multi_head_cross_attn, q, kv):
actual = multi_head_cross_attn(q, kv, kv)
expected = torch.tensor(
86 changes: 86 additions & 0 deletions torchmultimodal/modules/layers/multi_head_attention.py
@@ -77,6 +77,92 @@ def forward(
return attn_out


class MultiHeadSelfAttentionWithCache(nn.Module):
"""
MultiHeadAttention module for self-attention (SA). Similar to MultiHeadAttentionWithCache,
but only supports self-attention and uses a fast path in which the query, key, and value projections
are batched into a single matrix multiplication instead of three separate matmuls.
This class supports a cache mechanism for decoders to store previous states through
``past_key_value``.

Args:
embed_dim (int): query, key, value embedding dimension
num_heads (int): number of attention heads
dropout (float): dropout rate
add_bias (bool): if ``True``, adds a learnable bias to the query, key, value input projection.
Defaults to ``True``.
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
add_bias: bool = True,
) -> None:
super().__init__()
self.num_heads = num_heads
self.input_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=add_bias)
self.output_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = dropout

def forward(
self,
query: Tensor,
attn_mask: Optional[Tensor] = None,
past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
is_causal: bool = False,
use_cache: bool = False,
) -> Union[Tensor, MHAWithCacheOutput]:
"""
Args:
query (Tensor): input query of shape bsz x target_seq_len x embed_dim
attn_mask (optional Tensor): Attention mask of shape bsz x num_heads x target_seq_len x source_seq_len.
Note that the num_heads dimension can equal 1 and the mask will be broadcasted to all heads.
Two types of masks are supported. A boolean mask where a value of True
indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
past_key_value (optional tuple of tensors): cached key and value tensors with the same shape as the key and value inputs.
The tuple should have two entries: the cached key first, followed by the cached value.
is_causal (bool): if ``True``, applies causal attention masking; ``attn_mask`` should be ``None`` in that case.
``is_causal`` is only a hint that the mask is causal; providing an incorrect hint can result in incorrect execution.
use_cache (bool): whether to use cache for key and value tensors

Returns:
if ``use_cache`` is ``False``, returns an attn_output tensor of shape bsz x seq_len x embed_dim;
otherwise returns a namedtuple containing attn_output and the cached key and value.
"""
bsz = query.size(0)
embed_dim = query.size(-1)
head_dim = embed_dim // self.num_heads
projected_query = self.input_proj(query)
query, key, value = projected_query.chunk(3, dim=-1)

# bsz x seq_len x embed_dim => bsz x num_heads x seq_len x head_dim
query = query.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
if key.size(0) != bsz or value.size(0) != bsz:
raise ValueError("key and value should have the same bsz as query.")
key = key.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)
value = value.view(bsz, -1, self.num_heads, head_dim).transpose(1, 2)

# concat key value with cached values
if past_key_value is not None:
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)

# masking is delegated to scaled_dot_product_attention; attn_mask and is_causal are passed through as provided.
attn = F.scaled_dot_product_attention(
query, key, value, attn_mask, self.dropout, is_causal
)
attn = attn.transpose(1, 2).reshape(bsz, -1, embed_dim)

# add dense layer after attention
attn_output = self.output_proj(attn)
if use_cache:
return MHAWithCacheOutput(attn_output, (key, value))
return attn_output


class MultiHeadAttentionWithCache(nn.Module):
"""
MultiHeadAttention module for both self-attention (SA) and cross-attention (CA).
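
Not part of the diff, but a minimal usage sketch of how the new MultiHeadSelfAttentionWithCache module's cache is intended to be used during incremental decoding. The dimensions, random inputs, and variable names below are illustrative assumptions, not values taken from this PR.

# Minimal sketch (illustrative assumptions, not code from this PR):
# incremental decoding with MultiHeadSelfAttentionWithCache.
import torch

from torchmultimodal.modules.layers.multi_head_attention import (
    MultiHeadSelfAttentionWithCache,
)

bsz, embed_dim, num_heads = 2, 8, 2
mha = MultiHeadSelfAttentionWithCache(embed_dim, num_heads=num_heads)
mha.eval()

with torch.no_grad():
    # Prefill: run the prompt once and request the key/value cache back.
    # A real decoder would typically also pass is_causal=True or an attn_mask here.
    prompt = torch.randn(bsz, 4, embed_dim)
    out = mha(prompt, use_cache=True)
    past_key_value = out.past_key_value  # (key, value), each bsz x num_heads x 4 x head_dim

    # Decode step: feed only the newest token and reuse the cached keys/values.
    next_token = torch.randn(bsz, 1, embed_dim)
    out = mha(next_token, past_key_value=past_key_value, use_cache=True)
    assert out.attn_output.shape == (bsz, 1, embed_dim)
    assert out.past_key_value[0].size(2) == 5  # cache now covers all 5 positions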