Commit 264caa1

support generate dense attn mask from sparse one

1 parent: 2ce669c

4 files changed: +101 / -22 lines
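
For orientation before the per-file diffs: a minimal sketch (not part of the commit) of what the column-wise sparse "startend row indices" format encodes in its simplest, causal case. Per the docstring added in masking_utils.py below, with a trailing dimension of 1, startend_row_indices[..., j, 0] is the row from which column j of the attention score matrix is masked downwards, which lets packed sequences share one batch row.

import paddle

# Two sequences of length 2 packed into seq_len 4; batch 1, one head.
indices = paddle.to_tensor([[[[2], [2], [4], [4]]]], dtype="int32")  # [1, 1, 4, 1]

# Equivalent dense keep-mask under causal attention:
# keep[i][j] = (j <= i) and (i < indices[..., j, 0])
row = paddle.arange(4, dtype="int32").unsqueeze(1)  # query row index, shape [4, 1]
keep = paddle.logical_and(
    paddle.tril(paddle.ones([4, 4], dtype="bool")),  # causal structure
    row < indices[0, 0, :, 0],  # column-wise start rows, broadcast to [4, 4]
)
# keep:
# [[True,  False, False, False],
#  [True,  True,  False, False],
#  [False, False, True,  False],
#  [False, False, True,  True ]]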

paddleformers/nn/attention/eager_attention.py

Lines changed: 16 additions & 4 deletions
@@ -17,6 +17,7 @@
 import paddle
 import paddle.nn as nn
 
+from ...utils.masking_utils import _gen_from_sparse_attn_mask_indices
 from .utils import repeat_kv
 
 
@@ -37,15 +38,26 @@ def eager_attention_forward(
     key = repeat_kv(key, num_key_value_groups)
     value = repeat_kv(value, num_key_value_groups)
 
+    if attention_mask is None and kwargs.get("attn_mask_startend_row_indices", None) is not None:
+        attn_mask_startend_row_indices = kwargs["attn_mask_startend_row_indices"]
+        if attn_mask_startend_row_indices.ndim == 3:
+            attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1)
+        if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 1:
+            is_causal = True
+        if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 4:
+            is_causal = False
+
+        attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_startend_row_indices, query.dtype, is_causal)
+
     perm = [0, 2, 1, 3]  # b l h d -> b h l d
     query = paddle.transpose(x=query, perm=perm)
     key = paddle.transpose(x=key, perm=perm)
     value = paddle.transpose(x=value, perm=perm)
 
-    attn_weights = paddle.matmul(query, key.transpose([0, 1, 3, 2])) * scaling
+    attn_weights = paddle.matmul(x=query * scaling, y=key, transpose_y=True)
     if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
 
     if sink is not None:
         sink = sink.reshape([1, -1, 1, 1]).expand([query.shape[0], -1, query.shape[-2], -1])
@@ -54,7 +66,7 @@ def eager_attention_forward(
         scores = probs[..., :-1]  # we drop the sink here
         attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
     else:
-        attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype(query.dtype)
+        attn_weights = nn.functional.softmax(attn_weights, axis=-1, dtype=query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
 
     attn_output = paddle.matmul(attn_weights, value)  # b h l l @ b h l d -> b h l d
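
Two behavioral notes on this file beyond the dense-mask fallback added at the top of eager_attention_forward: the score computation now pre-scales the query instead of scaling after the matmul, and the non-sink softmax now runs in query.dtype rather than float32. A quick equivalence check for the first change (illustrative, not part of the commit):

import paddle

q = paddle.randn([2, 4, 8, 16])  # b h l d
k = paddle.randn([2, 4, 8, 16])
scaling = 16 ** -0.5

old = paddle.matmul(q, k.transpose([0, 1, 3, 2])) * scaling  # scale after matmul
new = paddle.matmul(x=q * scaling, y=k, transpose_y=True)    # pre-scale the query
assert paddle.allclose(old, new, atol=1e-6).item()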

paddleformers/nn/attention/flashmask_attention.py

Lines changed: 2 additions & 4 deletions
@@ -40,11 +40,9 @@ def flashmask_attention_forward(
         is_causal = True
     if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 4:
         is_causal = False
-
+
     if is_causal is None:
-        raise ValueError(
-            f"The `is_causal` argument must be specified when using the Flash Mask Attention kernel."
-        )
+        raise ValueError("The `is_causal` argument must be specified when using the Flash Mask Attention.")
 
     if sink is None:
         out = flashmask_attention(
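
The trailing-dimension rule above now appears in the eager, SDPA, and flashmask paths alike. Distilled into one hypothetical helper (not in the commit; shown only to make the shared rule explicit):

def infer_is_causal(startend_row_indices, is_causal=None):
    # A trailing dim of 1 implies causal, 4 implies bidirectional; 2 is
    # ambiguous, so the caller must supply is_causal or the downstream
    # check raises ValueError.
    if startend_row_indices is not None and startend_row_indices.ndim == 3:
        startend_row_indices = startend_row_indices.unsqueeze(-1)
    if startend_row_indices is not None:
        if startend_row_indices.shape[-1] == 1:
            is_causal = True
        elif startend_row_indices.shape[-1] == 4:
            is_causal = False
    return startend_row_indices, is_causal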

paddleformers/nn/attention/sdpa_attention.py

Lines changed: 6 additions & 2 deletions
@@ -38,10 +38,14 @@ def sdpa_attention_forward(
     if is_causal is None and attn_mask_startend_row_indices is None:
         is_causal = query.shape[1] > 1 and attention_mask is None and getattr(module, "is_causal", True)
     elif attn_mask_startend_row_indices is not None:
-        is_causal = False
         if attn_mask_startend_row_indices.ndim == 3:
             attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1)
-        attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_startend_row_indices, query.dtype)
+        if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 1:
+            is_causal = True
+        if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 4:
+            is_causal = False
+
+        attention_mask = _gen_from_sparse_attn_mask_indices(attn_mask_startend_row_indices, query.dtype, is_causal)
 
     if sink is None:
         attn_output = nn.functional.scaled_dot_product_attention(
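
With the dense mask generated up front, the SDPA call receives an additive attn_mask whose masking structure is already baked in. A hedged usage sketch of paddle.nn.functional.scaled_dot_product_attention with such a mask (tensor layout assumed to be [batch, seq, heads, head_dim], per Paddle's SDPA convention; the zero mask here is a placeholder, not output from this commit):

import paddle
import paddle.nn.functional as F

b, s, h, d = 1, 4, 2, 8
q = paddle.randn([b, s, h, d])
k = paddle.randn([b, s, h, d])
v = paddle.randn([b, s, h, d])

# Additive mask: 0.0 keeps a position, a large negative value drops it.
dense_mask = paddle.zeros([b, 1, s, s])
out = F.scaled_dot_product_attention(q, k, v, attn_mask=dense_mask)
print(out.shape)  # [1, 4, 2, 8]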

paddleformers/utils/masking_utils.py

Lines changed: 77 additions & 12 deletions
@@ -13,32 +13,97 @@
 # limitations under the License.
 
 import os
+from typing import Optional
 
 import numpy as np
 import paddle
 
 from .tools import get_env_device
 
 
-def _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, dtype):
+def _gen_from_sparse_attn_mask_indices(
+    attn_mask_startend_row_indices: paddle.Tensor,
+    dtype: Optional[paddle.dtype] = paddle.bfloat16,
+    is_causal: Optional[bool] = None,
+):
     """
-    Recover 4-D attention_mask from attn_mask_start_row_indices.
+    Recover a 4-D attention_mask from attn_mask_startend_row_indices.
 
     Args:
-        attn_mask_start_row_indices (paddle.Tensor): The start row indices for the attention mask.
-        dtype (str): The data type of the tensor.
+        attn_mask_startend_row_indices (paddle.Tensor):
+            A column-wise sparse attention mask row indices tensor: a 4-D tensor with shape
+            [batch_size, k_num_heads, k_seq_len, {1, 2, 4}] and dtype int32. k_num_heads can be 1
+            or equal to key's num_heads; when it is 1, it is broadcast to match key's num_heads.
+            Depending on the value of `is_causal`, the last dimension takes different sizes and
+            meanings:
+
+            - When `is_causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 1],
+              attention is unidirectional. startend_row_indices[..., 0] is the starting row of the
+              lower-left triangular mask in the dense mask: elements of the attention score matrix
+              from that row downwards (inclusive) are masked.
+            - When `is_causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 2],
+              attention is unidirectional. startend_row_indices[..., 0:2] are the starting and
+              ending rows of the lower-left triangular mask: elements from the
+              startend_row_indices[..., 0]-th row downwards (inclusive) but above the
+              startend_row_indices[..., 1]-th row (exclusive) are masked.
+            - When `is_causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 2],
+              attention is bidirectional. startend_row_indices[..., 0] is the starting row of the
+              lower-left triangular mask and startend_row_indices[..., 1] is the ending row of the
+              upper-right triangular mask: elements of the lower-left triangle from the
+              startend_row_indices[..., 0]-th row downwards (inclusive) are masked, and elements
+              of the upper-right triangle from the startend_row_indices[..., 1]-th row upwards
+              (exclusive) are masked.
+            - When `is_causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 4],
+              attention is bidirectional. startend_row_indices[..., 0:4] are the start and end
+              rows of the lower-left triangular mask followed by the start and end rows of the
+              upper-right triangular mask: elements of the lower-left triangle from the
+              startend_row_indices[..., 0]-th row downwards (inclusive) but above the
+              startend_row_indices[..., 1]-th row (exclusive) are masked, and elements of the
+              upper-right triangle from the startend_row_indices[..., 2]-th row downwards
+              (inclusive) but above the startend_row_indices[..., 3]-th row (exclusive) are masked.
+        dtype (paddle.dtype): The data type of the returned mask.
+        is_causal (bool): Whether to enable causal mode.
 
     Returns:
-        paddle.Tensor: The dense attention mask recovered from attn_mask_start_row_indices.
+        paddle.Tensor: The dense attention mask recovered from attn_mask_startend_row_indices.
     """
-    batch_size, _, max_seq_len, _ = attn_mask_start_row_indices.shape
-    base = paddle.arange(max_seq_len, dtype="int32").unsqueeze(1).expand([batch_size, -1, max_seq_len]).unsqueeze(1)
-    mask_indices = attn_mask_start_row_indices
 
-    tril = paddle.tril(
-        paddle.ones([max_seq_len, max_seq_len], dtype="bool").expand([batch_size, 1, max_seq_len, max_seq_len])
-    )
-    attention_mask = paddle.logical_and(base < mask_indices, tril)
+    if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.ndim == 3:
+        attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1)
+    if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 1:
+        is_causal = True
+    if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 4:
+        is_causal = False
+
+    if is_causal is None:
+        raise ValueError(
+            "The `is_causal` argument must be specified when recovering the dense attention mask from the column-wise sparse attention mask row indices."
+        )
+
+    batch_size, num_head, seq_len, bound_num = attn_mask_startend_row_indices.shape
+    has_end = (is_causal and bound_num == 2) or ((not is_causal) and bound_num == 4)
+
+    attention_mask = paddle.ones([seq_len, seq_len], dtype="bool").expand([batch_size, num_head, seq_len, seq_len])
+    if is_causal:
+        attention_mask = paddle.tril(attention_mask)
+
+    base = paddle.arange(seq_len, dtype="int32").unsqueeze(1).expand([batch_size, num_head, -1, seq_len])
+
+    # [batch_size, k_num_heads, k_seq_len, {1, 2, 4}] -> [batch_size, k_num_heads, {1, 2, 4}, k_seq_len]
+    mask_indices = attn_mask_startend_row_indices.transpose([0, 1, 3, 2])
+
+    downstart_mask_indices = mask_indices[:, :, 0, :]
+    downstart_mask_indices = downstart_mask_indices.expand([batch_size, num_head, seq_len, -1])
+    lower_tri = base < downstart_mask_indices
+    if has_end:
+        downend_mask_indices = mask_indices[:, :, 1, :]
+        downend_mask_indices = downend_mask_indices.expand([batch_size, num_head, seq_len, -1])
+        lower_tri = paddle.logical_or(lower_tri, base >= downend_mask_indices)
+
+    attention_mask = paddle.logical_and(attention_mask, lower_tri)
+
+    if not is_causal:
+        if has_end:
+            upstart_mask_indices = mask_indices[:, :, 2, :]
+            upstart_mask_indices = upstart_mask_indices.expand([batch_size, num_head, seq_len, -1])
+            upend_mask_indices = mask_indices[:, :, 3, :]
+            upend_mask_indices = upend_mask_indices.expand([batch_size, num_head, seq_len, -1])
+            upper_tri = base >= upend_mask_indices
+            upper_tri = paddle.logical_or(upper_tri, base < upstart_mask_indices)
+        else:
+            upend_mask_indices = mask_indices[:, :, 1, :]
+            upend_mask_indices = upend_mask_indices.expand([batch_size, num_head, seq_len, -1])
+            upper_tri = base >= upend_mask_indices
+
+        attention_mask = paddle.logical_and(attention_mask, upper_tri)
+
     attention_mask = paddle.scale(
         x=attention_mask.astype(dtype),
         scale=1000000.0,
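
A usage sketch for the recovered function (illustrative; the hunk is truncated above before paddle.scale's remaining arguments, so only the sign convention is assumed here: zeros where attention is allowed, large negative values where masked):

import paddle
from paddleformers.utils.masking_utils import _gen_from_sparse_attn_mask_indices

# Two packed sequences of length 2; batch 1, one broadcastable head.
indices = paddle.to_tensor([[[[2], [2], [4], [4]]]], dtype="int32")  # [1, 1, 4, 1]

# Trailing dim of 1, so is_causal is inferred as True inside the function.
mask = _gen_from_sparse_attn_mask_indices(indices, paddle.float32)
print(mask.shape)  # [1, 1, 4, 4], an additive mask in float32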
