@@ -20,6 +20,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from transformers.modeling_flash_attention_utils import _flash_attention_forward, fa_peft_integration_check
 from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
 
@@ -43,19 +44,14 @@
 def prepare_fa2_from_position_ids(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
 ):
-    query = query.view(-1, query.size(-2), query.size(-1))
+    assert position_ids.ndim == 2  # (batch_size, seq_length)
+    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
-    cu_seqlens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
-        )
-    )
+    position_ids = position_ids.view(-1)
+    cu_seqlens = F.pad((position_ids == 0).nonzero().view(-1), (0, 1), value=position_ids.size(0))
     max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
-    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
+    return (query, key, value, (cu_seqlens, cu_seqlens), (max_length, max_length))
 
 
 def _custom_flash_attention_forward(
@@ -102,7 +98,7 @@ def _custom_flash_attention_forward(
 
     if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
-        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+        query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
         )
         cu_seqlens_q, cu_seqlens_k = cu_seq_lens
@@ -162,16 +158,18 @@ def flash_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
 
-    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
-    kwargs.pop("is_causal", None)
+    # FA2 uses the kwargs value if explicitly passed, otherwise it uses the module attribute
+    is_causal = kwargs.pop("is_causal", None)
+    if is_causal is None:
+        is_causal = getattr(module, "is_causal", True)
 
     attn_output = _custom_flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
         query_length=q_len,
-        is_causal=module.is_causal,
+        is_causal=is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
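For reference, a minimal sketch (not part of the patch) of the packed-sequence metadata the rewritten prepare_fa2_from_position_ids computes; the position_ids values below are made up for illustration:

    import torch
    import torch.nn.functional as F

    # Two packed sequences of lengths 3 and 2, flattened into one row of position ids.
    position_ids = torch.tensor([[0, 1, 2, 0, 1]]).view(-1)

    # Offsets where a new sequence starts (position resets to 0), padded with the total token count.
    cu_seqlens = F.pad((position_ids == 0).nonzero().view(-1), (0, 1), value=position_ids.size(0))
    print(cu_seqlens)               # tensor([0, 3, 5])
    print(cu_seqlens.diff().max())  # tensor(3), the max_length handed to the varlen FlashAttention kernel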