Skip to content

Commit 2c08ee5

Browse files
authored
Add CambriansVSR/VSC/VSCStreaming model integrations (#1268)
* feat: add cambrians_vsr model * fix: qwen kv cache management for cambrians variants * feat: add cambrians_vsc & cambrians_vsc_streaming models * chore: format cambrians model files
1 parent f8279ad commit 2c08ee5

5 files changed

Lines changed: 1583 additions & 0 deletions

File tree

lmms_eval/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
"claude": "Claude",
2929
"cogvlm2": "CogVLM2",
3030
"cambrians": "CambrianS",
31+
"cambrians_vsc": "CambriansVSC",
32+
"cambrians_vsc_streaming": "CambriansVSCStreaming",
33+
"cambrians_vsr": "CambriansVSR",
3134
"dummy": "Dummy",
3235
"egogpt": "EgoGPT",
3336
"from_log": "FromLog",
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
from typing import Optional, Tuple
2+
3+
import torch
4+
from transformers.cache_utils import Cache
5+
from transformers.modeling_attn_mask_utils import (
6+
_prepare_4d_causal_attention_mask_for_sdpa,
7+
)
8+
from transformers.modeling_outputs import BaseModelOutputWithPast
9+
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, repeat_kv
10+
11+
flash_attn_func = None
12+
try:
13+
from flash_attn import flash_attn_func
14+
except Exception:
15+
flash_attn_func = None
16+
17+
18+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the upper half.

    With the last dimension viewed as ``[lower, upper]`` (``lower`` being the
    first ``size // 2`` entries), the result is ``[-upper, lower]``.

    Args:
        x: Tensor of any rank; only the last dimension is rearranged.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    midpoint = x.shape[-1] // 2
    lower = x[..., :midpoint]
    upper = x[..., midpoint:]
    return torch.cat((upper.neg(), lower), dim=-1)
22+
23+
24+
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1) -> torch.Tensor:
    """Apply a rotary position embedding to ``x``.

    Computes ``x * cos + rotate_half(x) * sin`` after giving ``cos`` and
    ``sin`` a leading axis and one extra axis at ``unsqueeze_dim`` so that
    they broadcast against ``x``.

    Args:
        x: Tensor to rotate.
        cos: Cosine table. Presumably ``(seq_len, head_dim)`` so that it
            broadcasts against ``(batch, heads, seq_len, head_dim)`` with the
            default ``unsqueeze_dim`` -- TODO confirm against the caller.
        sin: Sine table with the same shape as ``cos``.
        unsqueeze_dim: Position of the second inserted axis.

    Returns:
        The rotated tensor, broadcast-shaped from ``x`` and the tables.
    """

    def _broadcastable(table: torch.Tensor) -> torch.Tensor:
        # Leading axis first, then the caller-selected axis.
        return table.unsqueeze(0).unsqueeze(unsqueeze_dim)

    cos_b = _broadcastable(cos)
    sin_b = _broadcastable(sin)
    rotated = rotate_half(x)
    return (x * cos_b) + (rotated * sin_b)
28+
29+
30+
class Qwen2SdpaAttention(Qwen2Attention):
    """Qwen2 attention using a plain ``(key, value)`` tuple cache and optional block retrieval.

    Differences from the parent implementation that are visible in this class:

    * The cache is a raw ``(key_states, value_states)`` tuple that is grown by
      concatenation along the sequence axis. Keys are cached *before* the
      rotary embedding, and the rotary embedding is re-applied to the whole
      key sequence on every call, so positions always run ``0..kv_len-1`` over
      whatever keys are currently kept. ``position_ids`` is not consulted on
      this path.
    * If the instance carries a ``use_retrieval`` attribute and
      ``retrieval_topk > 0``, the cached sequence is treated as consecutive
      blocks (lengths in ``self.cache_lengths``, tags in
      ``self.cache_modalities``) and only the top-k most query-similar blocks,
      every block tagged ``"T"``, and the current query block are kept. These
      attributes are expected to be set by external code before the call --
      TODO confirm against the caller.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention over ``hidden_states`` plus any cached keys/values.

        Args:
            hidden_states: Input of shape ``(bsz, q_len, hidden)``.
            attention_mask: Optional additive mask; when given it must be
                ``(bsz, 1, q_len, kv_seq_len)`` after any retrieval trimming.
            position_ids: Only forwarded to the parent implementation when
                ``output_attentions`` is true; otherwise unused.
            past_key_value: Optional ``(keys, values)`` pair, indexed with
                ``[0]`` / ``[1]`` and concatenated on dim 2.
            output_attentions: When true, delegates entirely to the parent.
            use_cache: Accepted for interface compatibility; the updated cache
                is always returned.

        Returns:
            ``(attn_output, None, (key_states, value_states))`` where the
            cached keys are un-rotated and reflect any retrieval pruning.

        Raises:
            ValueError: If ``attention_mask`` has an unexpected shape.
        """
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Grow the cache along the sequence axis; keys are stored un-rotated.
        if past_key_value is not None:
            past_key_value = (
                torch.cat([past_key_value[0], key_states], dim=2),
                torch.cat([past_key_value[1], value_states], dim=2),
            )
        else:
            past_key_value = (key_states, value_states)

        key_states, value_states = past_key_value
        kv_seq_len = value_states.size(2)

        if hasattr(self, "use_retrieval") and self.retrieval_topk > 0:
            retrieval_topk = self.retrieval_topk
            # The flags are consumed here, so retrieval runs for one call only.
            del self.use_retrieval
            del self.retrieval_topk

            img_mask = torch.tensor([1 if item == "I" else 0 for item in self.cache_modalities], dtype=torch.long, device=query_states.device)
            # Block lengths of the cached prefix, followed by the current query block.
            cache_lengths = list(self.cache_lengths) + [query_states.size(2)]
            split_key_states = torch.split(repeat_kv(key_states, self.num_key_value_groups), cache_lengths, dim=2)[:-1]
            # One mean-pooled descriptor per cached block, stacked on dim 0
            # (assumes batch size 1 so rows line up with blocks -- TODO confirm).
            sub_key_reprs = torch.cat([state.mean(dim=2).flatten(1, 2) for state in split_key_states], dim=0)
            query_repr = query_states.mean(dim=2).flatten(1, 2)

            # NOTE(review): non-"I" blocks are scored exactly 0 by the mask, so
            # they outrank any "I" block whose similarity is negative and can
            # occupy top-k slots. Left unchanged to keep results identical.
            query_subkey_sims = torch.cosine_similarity(query_repr, sub_key_reprs, dim=-1) * img_mask
            topk_indices = set(torch.topk(query_subkey_sims, min(retrieval_topk, query_subkey_sims.size(0)), dim=-1).indices.tolist())

            split_key_states = torch.split(key_states, cache_lengths, dim=2)
            split_value_states = torch.split(value_states, cache_lengths, dim=2)

            retrieved_key_states = []
            retrieved_value_states = []
            for block_idx, modality in enumerate(self.cache_modalities):
                if block_idx in topk_indices or modality == "T":
                    retrieved_key_states.append(split_key_states[block_idx])
                    retrieved_value_states.append(split_value_states[block_idx])
            # The current query block is always kept.
            retrieved_key_states.append(split_key_states[-1])
            retrieved_value_states.append(split_value_states[-1])

            key_states = torch.cat(retrieved_key_states, dim=2)
            value_states = torch.cat(retrieved_value_states, dim=2)
            kv_seq_len = key_states.size(2)
            # The returned cache only holds the retained blocks.
            past_key_value = (key_states, value_states)
            if attention_mask is not None:
                # Keep the trailing columns so the mask matches the shortened keys.
                attention_mask = attention_mask[..., -kv_seq_len:].contiguous()

        # Rotary positions are assigned over the keys that are kept; the
        # queries take the last q_len positions.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states = apply_rotary_pos_emb(query_states, cos[-q_len:], sin[-q_len:])
        key_states = apply_rotary_pos_emb(key_states, cos, sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None and attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(f"Attention mask should be {(bsz, 1, q_len, kv_seq_len)}, got {attention_mask.size()}")

        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # flash-attn kernels only accept CUDA tensors in fp16/bf16. Previously
        # an importable flash_attn was used unconditionally and failed for CPU
        # or fp32 inputs; those cases now fall back to SDPA.
        use_flash = flash_attn_func is not None and query_states.is_cuda and query_states.dtype in (torch.float16, torch.bfloat16)

        if not use_flash:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=self.is_causal and attention_mask is None and q_len > 1,
            )
            # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads, head_dim)
            attn_output = attn_output.transpose(1, 2)
        else:
            # flash_attn_func takes and returns (bsz, seq, heads, head_dim).
            # Note that attention_mask is not passed on this path; only causal
            # masking is applied.
            attn_output = flash_attn_func(
                query_states.transpose(1, 2),
                key_states.transpose(1, 2),
                value_states.transpose(1, 2),
                dropout_p=self.attention_dropout if self.training else 0.0,
                causal=True,
            )

        attn_output = attn_output.contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
141+
142+
143+
def cambrian_qwen2_forward(
    self,
    input_ids,
    attention_mask,
    position_ids,
    past_key_values,
    inputs_embeds,
    use_cache,
    output_attentions,
    output_hidden_states,
    return_dict,
):
    """Run the decoder stack with a per-layer tuple KV cache.

    Written as a free function taking ``self``; it reads ``embed_tokens``,
    ``layers``, ``norm`` and ``config.sliding_window`` from it.

    Args:
        input_ids: Token ids, embedded only when ``inputs_embeds`` is None.
        attention_mask: Mask expanded to a 4D causal mask before the layers.
        position_ids: Passed through to every decoder layer.
        past_key_values: ``None`` or a per-layer sequence whose entries are
            indexed as ``[layer][0]`` for keys; the cached length is read
            from dim 2 of the first layer's keys.
        inputs_embeds: Pre-computed embeddings of shape ``(batch, seq, hidden)``.
        use_cache: Must be ``True``.
        output_attentions: Must be ``False``.
        output_hidden_states: Must be ``True``.
        return_dict: Must be ``True``.

    Returns:
        ``BaseModelOutputWithPast`` whose ``past_key_values`` is a tuple with
        element 1 of each layer's output, and whose ``last_hidden_state`` and
        ``hidden_states`` are both the final normalized hidden tensor.
    """
    # Only this exact flag combination is supported.
    assert use_cache is True
    assert output_attentions is False
    assert output_hidden_states is True
    assert return_dict is True

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length, _ = inputs_embeds.size()
    if past_key_values is None:
        past_key_values_length = 0
    else:
        past_key_values_length = past_key_values[0][0].size(2)

    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
        sliding_window=self.config.sliding_window,
    )

    per_layer_cache = []
    hidden_states = inputs_embeds
    for idx, decoder_layer in enumerate(self.layers):
        layer_past = None if past_key_values is None else past_key_values[idx]
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=layer_past,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # Element 1 is collected as this layer's cache entry.
        per_layer_cache.append(layer_outputs[1])
        hidden_states = layer_outputs[0]

    hidden_states = self.norm(hidden_states)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=tuple(per_layer_cache),
        hidden_states=hidden_states,
    )

0 commit comments

Comments
 (0)