|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
| 16 | +from typing import Optional |
16 | 17 |
|
17 | 18 | import numpy as np |
18 | 19 | import paddle |
19 | 20 |
|
20 | 21 | from .tools import get_env_device |
21 | 22 |
|
22 | 23 |
|
23 | | -def _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, dtype): |
| 24 | +def _gen_from_sparse_attn_mask_indices( |
| 25 | + attn_mask_startend_row_indices: paddle.Tensor, |
| 26 | + dtype: Optional[paddle.dtype] = paddle.bfloat16, |
| 27 | + is_causal: Optional[bool] = None, |
| 28 | +): |
24 | 29 | """ |
25 | | - Recover 4-D attention_mask from attn_mask_start_row_indices. |
| 30 | + Recover 4-D attention_mask from attn_mask_startend_row_indices. |
26 | 31 |
|
27 | 32 | Args: |
28 | | - attn_mask_start_row_indices (paddle.Tensor): The start row indices for the attention mask. |
29 | | - dtype (str): The data type of the tensor. |
| 33 | + attn_mask_startend_row_indices (paddle.Tensor): |
| 34 | + A column-wise sparse attention mask row indices tensor. |
| 35 | + A 4-D tensor with shape [batch_size, k_num_heads, k_seq_len, {1, 2, 4}]. |
| 36 | + The dtype must be int32. k_num_heads can be 1 or the same as key's num_heads. When num_heads is 1, it will be broadcast to match key's num_heads. |
| 37 | + Depending on the value of the causal parameter, startend_row_indices can take different shapes and meanings. |
| 38 | +
|
| 39 | + - When `causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 1], |
| 40 | + indicating unidirectional attention. The value represents the starting row index of the left |
| 41 | + lower triangular mask in the dense mask. The value startend_row_indices[..., 0] indicates that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) will be masked. |
| 42 | + - When `causal=True` and the shape is [batch_size, k_num_heads, k_seq_len, 2], |
| 43 | + indicating unidirectional attention. The values represent the starting and ending row indices of |
| 44 | + the left lower triangular mask in the dense mask. The values startend_row_indices[..., 0:2] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) but above the startend_row_indices[..., 1]-th row (exclusive) will be masked. |
| 45 | + - When `causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 2], |
| 46 | + indicating bidirectional attention. The values represent the starting row index of the left |
| 47 | + lower triangular mask and the ending row index of the right upper triangular mask in the dense mask. The values startend_row_indices[..., 0:2] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) will be masked, and elements in the upper right triangle starting from the startend_row_indices[..., 1]-th row upwards (exclusive) will be masked. |
| 48 | + - When `causal=False` and the shape is [batch_size, k_num_heads, k_seq_len, 4] , |
| 49 | + indicating bidirectional attention. The values represent the start and end row indices of the |
| 50 | + left lower triangular mask and the start and end row indices of the right upper triangular mask in the dense mask. The values startend_row_indices[..., 0:4] in startend_row_indices indicate that elements in the lower left triangle of the attention score matrix starting from the startend_row_indices[..., 0]-th row downwards (inclusive) but above the startend_row_indices[..., 1] row (exclusive) will be masked, and elements in the upper right triangle starting from the startend_row_indices[..., 2]-th row downwards (inclusive) but above the startend_row_indices[..., 3] row (exclusive) will be masked. |
| 51 | + dtype (paddle.dtype): The data type of the tensor. |
| 52 | + causal (bool): Whether to enable causal mode. |
30 | 53 |
|
31 | 54 | Returns: |
32 | | - paddle.Tensor: The dense attention mask recovered from attn_mask_start_row_indices. |
| 55 | + paddle.Tensor: The dense attention mask recovered from attn_mask_startend_row_indices. |
33 | 56 | """ |
34 | | - batch_size, _, max_seq_len, _ = attn_mask_start_row_indices.shape |
35 | | - base = paddle.arange(max_seq_len, dtype="int32").unsqueeze(1).expand([batch_size, -1, max_seq_len]).unsqueeze(1) |
36 | | - mask_indices = attn_mask_start_row_indices |
37 | 57 |
|
38 | | - tril = paddle.tril( |
39 | | - paddle.ones([max_seq_len, max_seq_len], dtype="bool").expand([batch_size, 1, max_seq_len, max_seq_len]) |
40 | | - ) |
41 | | - attention_mask = paddle.logical_and(base < mask_indices, tril) |
| 58 | + if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.ndim == 3: |
| 59 | + attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(-1) |
| 60 | + if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 1: |
| 61 | + is_causal = True |
| 62 | + if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.shape[-1] == 4: |
| 63 | + is_causal = False |
| 64 | + |
| 65 | + if is_causal is None: |
| 66 | + raise ValueError( |
| 67 | + "The `is_causal` argument must be specified when recovering the dense attention mask from the column-wise sparse attention mask row indices." |
| 68 | + ) |
| 69 | + |
| 70 | + batch_size, num_head, seq_len, bound_num = attn_mask_startend_row_indices.shape |
| 71 | + has_end = (is_causal and bound_num == 2) or ((not is_causal) and bound_num == 4) |
| 72 | + |
| 73 | + attention_mask = paddle.ones([seq_len, seq_len], dtype="bool").expand([batch_size, num_head, seq_len, seq_len]) |
| 74 | + if is_causal: |
| 75 | + attention_mask = paddle.tril(attention_mask) |
| 76 | + |
| 77 | + base = paddle.arange(seq_len, dtype="int32").unsqueeze(1).expand([batch_size, num_head, -1, seq_len]) |
| 78 | + |
| 79 | + # [batch_size, k_num_heads, k_seq_len, {1, 2, 4}] -> [batch_size, k_num_heads, {1, 2, 4}, k_seq_len] |
| 80 | + mask_indices = attn_mask_startend_row_indices.transpose([0, 1, 3, 2]) |
| 81 | + |
| 82 | + downstart_mask_indices = mask_indices[:, :, 0, :] |
| 83 | + downstart_mask_indices = downstart_mask_indices.expand([batch_size, num_head, seq_len, -1]) |
| 84 | + lower_tri = base < downstart_mask_indices |
| 85 | + if has_end: |
| 86 | + downend_mask_indices = mask_indices[:, :, 1, :] |
| 87 | + downend_mask_indices = downend_mask_indices.expand([batch_size, num_head, seq_len, -1]) |
| 88 | + lower_tri = paddle.logical_or(lower_tri, base >= downend_mask_indices) |
| 89 | + |
| 90 | + attention_mask = paddle.logical_and(attention_mask, lower_tri) |
| 91 | + |
| 92 | + if not is_causal: |
| 93 | + if has_end: |
| 94 | + upstart_mask_indices = mask_indices[:, :, 2, :] |
| 95 | + upstart_mask_indices = upstart_mask_indices.expand([batch_size, num_head, seq_len, -1]) |
| 96 | + upend_mask_indices = mask_indices[:, :, 3, :] |
| 97 | + upend_mask_indices = upend_mask_indices.expand([batch_size, num_head, seq_len, -1]) |
| 98 | + upper_tri = base >= upend_mask_indices |
| 99 | + upper_tri = paddle.logical_or(upper_tri, base < upstart_mask_indices) |
| 100 | + else: |
| 101 | + upend_mask_indices = mask_indices[:, :, 1, :] |
| 102 | + upend_mask_indices = upend_mask_indices.expand([batch_size, num_head, seq_len, -1]) |
| 103 | + upper_tri = base >= upend_mask_indices |
| 104 | + |
| 105 | + attention_mask = paddle.logical_and(attention_mask, upper_tri) |
| 106 | + |
42 | 107 | attention_mask = paddle.scale( |
43 | 108 | x=attention_mask.astype(dtype), |
44 | 109 | scale=1000000.0, |
|
0 commit comments