|
| 1 | +from typing import Optional, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | +from transformers.cache_utils import Cache |
| 5 | +from transformers.modeling_attn_mask_utils import ( |
| 6 | + _prepare_4d_causal_attention_mask_for_sdpa, |
| 7 | +) |
| 8 | +from transformers.modeling_outputs import BaseModelOutputWithPast |
| 9 | +from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, repeat_kv |
| 10 | + |
# Optional FlashAttention fast path: `flash_attn_func` is the kernel entry
# point when the package imports cleanly, and None otherwise.
try:
    from flash_attn import flash_attn_func
except Exception:  # deliberately broad: any import-time failure disables the fast path
    flash_attn_func = None
| 16 | + |
| 17 | + |
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the trailing half.

    The last axis is split at ``x.shape[-1] // 2``; the result is the
    concatenation ``(-trailing, leading)`` along that same axis.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat([trailing.neg(), leading], dim=-1)
| 22 | + |
| 23 | + |
def apply_rotary_pos_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    unsqueeze_dim: int = 1,
) -> torch.Tensor:
    """Apply rotary position embedding to ``x``.

    ``cos`` and ``sin`` first receive a new leading axis and then another axis
    at ``unsqueeze_dim`` so they broadcast against ``x``. The result is
    ``x * cos + rotate_half(x) * sin``.
    """
    cos_b = cos[None].unsqueeze(unsqueeze_dim)
    sin_b = sin[None].unsqueeze(unsqueeze_dim)
    rotated = rotate_half(x)
    return (x * cos_b) + (rotated * sin_b)
| 28 | + |
| 29 | + |
class Qwen2SdpaAttention(Qwen2Attention):
    """Qwen2 attention computed with SDPA or FlashAttention, with optional
    one-shot retrieval over blocks of the key/value cache.

    The cache handled here is a plain ``(key, value)`` tuple of tensors holding
    keys *before* rotary embedding; rotary positions are recomputed over the
    full (possibly pruned) key sequence on every call.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, None, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention over ``hidden_states`` plus any cached keys/values.

        Args:
            hidden_states: Tensor unpacked as ``(bsz, q_len, hidden)``.
            attention_mask: Optional mask; when given it must have shape
                ``(bsz, 1, q_len, kv_seq_len)`` after any retrieval pruning.
            position_ids: Only forwarded to the parent implementation when
                ``output_attentions`` is set; otherwise unused, since rotary
                positions are derived from the key/value length.
            past_key_value: Optional ``(key, value)`` tuple, concatenated with
                the new keys/values along dim 2.
            output_attentions: When true, the call is delegated to the parent
                class.
            use_cache: Not consulted on the main path; the updated
                ``(key, value)`` tuple is always returned.

        Returns:
            ``(attn_output, None, (key, value))`` where the cached key is the
            un-rotated key tensor.

        Raises:
            ValueError: If ``attention_mask`` has an unexpected shape.
        """
        if output_attentions:
            # This implementation returns None for attention weights, so the
            # parent class handles requests that need them.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        # Project and reshape to (bsz, heads, q_len, head_dim).
        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        # The cache keeps keys without rotary embedding applied.
        past_key_value = (key_states, value_states)
        kv_seq_len = value_states.size(2)

        # Retrieval is armed externally by setting `use_retrieval` and
        # `retrieval_topk` on the module; both are removed once consumed.
        # getattr avoids an AttributeError when only `use_retrieval` is set.
        retrieval_topk = getattr(self, "retrieval_topk", 0)
        if hasattr(self, "use_retrieval") and retrieval_topk > 0:
            del self.use_retrieval
            del self.retrieval_topk

            key_states, value_states = self._retrieve_blocks(
                query_states, key_states, value_states, retrieval_topk
            )
            kv_seq_len = key_states.size(2)
            # The pruned cache replaces the full one from here on.
            past_key_value = (key_states, value_states)
            if attention_mask is not None:
                attention_mask = attention_mask[..., -kv_seq_len:].contiguous()

        # Rotary positions run over the current key length; the queries take
        # the last q_len positions.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states = apply_rotary_pos_emb(query_states, cos[-q_len:], sin[-q_len:])
        key_states = apply_rotary_pos_emb(key_states, cos, sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None and attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(f"Attention mask should be {(bsz, 1, q_len, kv_seq_len)}, got {attention_mask.size()}")

        on_cuda = query_states.device.type == "cuda"
        if on_cuda and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        dropout_p = self.attention_dropout if self.training else 0.0

        # FlashAttention kernels only accept fp16/bf16 CUDA tensors; for any
        # other input fall back to SDPA instead of failing inside the kernel.
        use_flash = (
            flash_attn_func is not None
            and on_cuda
            and query_states.dtype in (torch.float16, torch.bfloat16)
        )

        if use_flash:
            # NOTE(review): this path does not pass attention_mask, so padding
            # is not masked here - confirm callers never rely on it.
            # flash_attn_func takes and returns (bsz, seq, heads, head_dim).
            attn_output = flash_attn_func(
                query_states.transpose(1, 2),
                key_states.transpose(1, 2),
                value_states.transpose(1, 2),
                dropout_p=dropout_p,
                causal=True,
            )
        else:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=self.is_causal and attention_mask is None and q_len > 1,
            ).transpose(1, 2)

        attn_output = attn_output.contiguous().view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

    def _retrieve_blocks(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        retrieval_topk: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep "T" blocks, the top-k scored blocks and the current query span.

        The key/value sequence is split according to ``self.cache_lengths``
        followed by the current query length. Each cached block is scored by
        cosine similarity between its mean key and the mean query, multiplied
        by a 0/1 mask that is 1 only for blocks tagged "I" in
        ``self.cache_modalities``.
        """
        modalities = self.cache_modalities
        img_mask = torch.tensor(
            [1 if item == "I" else 0 for item in modalities],
            dtype=torch.long,
            device=query_states.device,
        )
        block_lengths = list(self.cache_lengths) + [query_states.size(2)]

        # Score on keys expanded to the query head count; the trailing chunk is
        # the current query span and is not scored.
        expanded_keys = repeat_kv(key_states, self.num_key_value_groups)
        cached_blocks = torch.split(expanded_keys, block_lengths, dim=2)[:-1]
        sub_key_reprs = torch.cat(
            [block.mean(dim=2).flatten(1, 2) for block in cached_blocks], dim=0
        )
        query_repr = query_states.mean(dim=2).flatten(1, 2)

        # NOTE(review): block representations are stacked along dim 0, so this
        # broadcast presumably assumes batch size 1 - confirm against callers.
        # NOTE(review): masked blocks score exactly 0, which outranks any
        # negative image score - confirm this is intended.
        sims = torch.cosine_similarity(query_repr, sub_key_reprs, dim=-1) * img_mask
        top_k = min(retrieval_topk, sims.size(0))
        selected = set(torch.topk(sims, top_k, dim=-1).indices.tolist())

        key_blocks = torch.split(key_states, block_lengths, dim=2)
        value_blocks = torch.split(value_states, block_lengths, dim=2)

        keep = [
            block_idx
            for block_idx, modality in enumerate(modalities)
            if block_idx in selected or modality == "T"
        ]
        keep.append(-1)  # the current query span is always retained

        pruned_keys = torch.cat([key_blocks[i] for i in keep], dim=2)
        pruned_values = torch.cat([value_blocks[i] for i in keep], dim=2)
        return pruned_keys, pruned_values
| 141 | + |
| 142 | + |
def cambrian_qwen2_forward(
    self,
    input_ids,
    attention_mask,
    position_ids,
    past_key_values,
    inputs_embeds,
    use_cache,
    output_attentions,
    output_hidden_states,
    return_dict,
):
    """Run the decoder stack and return the final hidden state plus KV cache.

    Only one flag combination is supported: ``use_cache=True``,
    ``output_attentions=False``, ``output_hidden_states=True`` and
    ``return_dict=True`` (each checked by identity).

    Args:
        self: Model object providing ``embed_tokens``, ``layers``, ``norm``
            and ``config.sliding_window``.
        input_ids: Token ids, embedded only when ``inputs_embeds`` is None.
        attention_mask: Mask handed to
            ``_prepare_4d_causal_attention_mask_for_sdpa``.
        position_ids: Forwarded unchanged to every decoder layer.
        past_key_values: Per-layer cache entries, or None / empty for no
            cache. The cached length is read from ``past_key_values[0][0]``
            along dim 2.
        inputs_embeds: Optional precomputed embeddings, unpacked as
            ``(batch, seq, hidden)``.

    Returns:
        ``BaseModelOutputWithPast`` whose ``last_hidden_state`` and
        ``hidden_states`` are both the final normalised tensor (not a
        per-layer tuple), and whose ``past_key_values`` is a tuple with one
        entry per layer.

    Raises:
        AssertionError: If any flag differs from its required value.
    """
    # Raised explicitly rather than via `assert` so the checks survive
    # `python -O` and report which flag was wrong.
    required_flags = (
        ("use_cache", use_cache, True),
        ("output_attentions", output_attentions, False),
        ("output_hidden_states", output_hidden_states, True),
        ("return_dict", return_dict, True),
    )
    for flag_name, actual, expected in required_flags:
        if actual is not expected:
            raise AssertionError(
                f"cambrian_qwen2_forward requires {flag_name}={expected}, got {actual!r}"
            )

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # An empty cache container is treated the same as no cache at all.
    if past_key_values is not None and len(past_key_values) == 0:
        past_key_values = None

    batch_size, seq_length, _ = inputs_embeds.size()
    past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].size(2)

    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
        sliding_window=self.config.sliding_window,
    )

    hidden_states = inputs_embeds
    next_cache = []
    for layer_idx, decoder_layer in enumerate(self.layers):
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values[layer_idx] if past_key_values is not None else None,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = layer_outputs[0]
        # Index 1 is taken as the layer's cache entry; presumably attention
        # weights are omitted from the layer output because
        # output_attentions is False - confirm against the decoder layer.
        next_cache.append(layer_outputs[1])

    hidden_states = self.norm(hidden_states)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=tuple(next_cache),
        hidden_states=hidden_states,
    )
0 commit comments