Skip to content

Commit b7de6dc

Browse files
model: add colqwen2 (#326)
Co-authored-by: thkim <[email protected]>
1 parent 96b2fc2 commit b7de6dc

File tree

9 files changed

+732
-1
lines changed

9 files changed

+732
-1
lines changed

src/optimum/rbln/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
"RBLNCLIPVisionModelWithProjectionConfig",
7575
"RBLNColPaliForRetrieval",
7676
"RBLNColPaliForRetrievalConfig",
77+
"RBLNColQwen2ForRetrieval",
78+
"RBLNColQwen2ForRetrievalConfig",
7779
"RBLNDecoderOnlyModelConfig",
7880
"RBLNDecoderOnlyModel",
7981
"RBLNDecoderOnlyModelForCausalLM",
@@ -366,6 +368,8 @@
366368
RBLNCLIPVisionModelWithProjectionConfig,
367369
RBLNColPaliForRetrieval,
368370
RBLNColPaliForRetrievalConfig,
371+
RBLNColQwen2ForRetrieval,
372+
RBLNColQwen2ForRetrievalConfig,
369373
RBLNDecoderOnlyModel,
370374
RBLNDecoderOnlyModelConfig,
371375
RBLNDecoderOnlyModelForCausalLM,

src/optimum/rbln/modeling.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConf
113113
)
114114
return compiled_model
115115

116+
@classmethod
117+
def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
118+
return model
119+
116120
@classmethod
117121
def from_model(
118122
cls,
@@ -146,6 +150,8 @@ def from_model(
146150
Returns:
147151
(RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
148152
"""
153+
154+
model = cls._reconstruct_model_if_needed(model)
149155
preprocessors = kwargs.pop("preprocessors", [])
150156
rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
151157

src/optimum/rbln/transformers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
"RBLNBlip2VisionModelConfig",
5555
"RBLNColPaliForRetrieval",
5656
"RBLNColPaliForRetrievalConfig",
57+
"RBLNColQwen2ForRetrieval",
58+
"RBLNColQwen2ForRetrievalConfig",
5759
"RBLNCLIPTextModel",
5860
"RBLNCLIPTextModelConfig",
5961
"RBLNCLIPTextModelWithProjection",
@@ -218,6 +220,8 @@
218220
RBLNCLIPVisionModelWithProjectionConfig,
219221
RBLNColPaliForRetrieval,
220222
RBLNColPaliForRetrievalConfig,
223+
RBLNColQwen2ForRetrieval,
224+
RBLNColQwen2ForRetrievalConfig,
221225
RBLNDecoderOnlyModel,
222226
RBLNDecoderOnlyModelConfig,
223227
RBLNDecoderOnlyModelForCausalLM,

src/optimum/rbln/transformers/models/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@
7575
"RBLNColPaliForRetrieval",
7676
"RBLNColPaliForRetrievalConfig",
7777
],
78+
"colqwen2": [
79+
"RBLNColQwen2ForRetrieval",
80+
"RBLNColQwen2ForRetrievalConfig",
81+
],
7882
"distilbert": [
7983
"RBLNDistilBertForQuestionAnswering",
8084
"RBLNDistilBertForQuestionAnsweringConfig",
@@ -236,6 +240,7 @@
236240
RBLNCLIPVisionModelWithProjectionConfig,
237241
)
238242
from .colpali import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
243+
from .colqwen2 import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig
239244
from .decoderonly import (
240245
RBLNDecoderOnlyModel,
241246
RBLNDecoderOnlyModelConfig,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .configuration_colqwen2 import RBLNColQwen2ForRetrievalConfig
2+
from .modeling_colqwen2 import RBLNColQwen2ForRetrieval
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional, Tuple, Union
16+
17+
import torch
18+
import torch.nn as nn
19+
from transformers import PreTrainedModel
20+
21+
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
22+
DecoderOnlyLayer,
23+
DecoderOnlyModel,
24+
DecoderOnlyWrapper,
25+
)
26+
27+
from .configuration_colqwen2 import (
28+
RBLNColQwen2ForRetrievalConfig,
29+
)
30+
31+
32+
def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
33+
"""Slice cos[cache_position], sin[cache_position] vector for the query."""
34+
cos = cos[position_ids[0]][None, None, None, :, :]
35+
sin = sin[position_ids[0]][None, None, None, :, :]
36+
37+
return cos, sin
38+
39+
40+
class ColQwen2LanguageModelWrapper(DecoderOnlyWrapper):
41+
def __init__(
42+
self, model: PreTrainedModel, rbln_config: "RBLNColQwen2ForRetrievalConfig", use_rotary_emb: bool = True
43+
):
44+
model.config = (
45+
model.config.vlm_config.text_config if hasattr(model.config, "vlm_config") else model.config.text_config
46+
)
47+
super().__init__(model, rbln_config, use_rotary_emb)
48+
49+
def get_decoder_layers(self, model: PreTrainedModel):
50+
return model.language_model.layers
51+
52+
def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
53+
new_layers = []
54+
for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
55+
is_sliding = layer_idx in self.rbln_config.sliding_window_layers
56+
new_self_attn = self.get_rbln_attn_class()(
57+
self.get_attn_layer(layer),
58+
self.rbln_config,
59+
is_sliding=is_sliding,
60+
)
61+
new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
62+
new_layers.append(new_layer)
63+
64+
new_model = self.get_rbln_model_class()(
65+
model.language_model,
66+
new_layers,
67+
self.rbln_config,
68+
use_learned_pos_emb=self.__class__._use_learned_pos_emb,
69+
)
70+
71+
# text_projection layer from model
72+
self.embedding_proj_layer = (
73+
model.embedding_proj_layer if hasattr(model, "embedding_proj_layer") else model.custom_text_proj
74+
)
75+
return new_model
76+
77+
def get_rbln_model_class(self):
78+
return RBLNColQwen2LanguageModel
79+
80+
def prepare_forward_args(self, *args):
81+
args = list(args)
82+
input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
83+
inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
84+
cache_position = args.pop(0)
85+
global_block_tables = args.pop(0)
86+
local_block_tables = None
87+
position_embeds = args.pop(0)
88+
position_ids = None
89+
attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
90+
past_key_values = args
91+
92+
if len(past_key_values) != 2 * self.num_hidden_layers:
93+
raise ValueError(
94+
f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
95+
)
96+
97+
_past_key_values = []
98+
for i in range(self.config.num_hidden_layers):
99+
key_states = past_key_values[i * 2]
100+
value_states = past_key_values[i * 2 + 1]
101+
past_key_value = [key_states, value_states]
102+
_past_key_values.append(past_key_value)
103+
past_key_values = _past_key_values
104+
105+
return (
106+
input_ids,
107+
inputs_embeds,
108+
cache_position,
109+
global_block_tables,
110+
local_block_tables,
111+
attention_mask,
112+
position_ids,
113+
past_key_values,
114+
position_embeds,
115+
)
116+
117+
def forward(self, *args):
118+
(
119+
input_ids,
120+
inputs_embeds,
121+
cache_position,
122+
global_block_tables,
123+
local_block_tables,
124+
attention_mask,
125+
position_ids,
126+
past_key_values,
127+
rotary_emb,
128+
) = self.prepare_forward_args(*args)
129+
130+
last_hidden_states = self.model(
131+
input_ids=input_ids,
132+
inputs_embeds=inputs_embeds,
133+
attention_mask=attention_mask,
134+
cache_position=cache_position,
135+
position_ids=position_ids,
136+
past_key_values=past_key_values,
137+
rotary_emb=rotary_emb,
138+
global_block_tables=global_block_tables,
139+
local_block_tables=local_block_tables,
140+
)
141+
142+
proj = self.embedding_proj_layer(last_hidden_states[0])
143+
all_hidden_states = last_hidden_states[1] if self.rbln_config.output_hidden_states else None
144+
145+
if self.rbln_config.output_hidden_states:
146+
return proj, all_hidden_states
147+
else:
148+
return proj
149+
150+
151+
class RBLNColQwen2LanguageModel(DecoderOnlyModel):
152+
def __init__(
153+
self,
154+
model,
155+
layers: List["DecoderOnlyLayer"],
156+
rbln_config: "RBLNColQwen2ForRetrievalConfig",
157+
use_learned_pos_emb=None,
158+
):
159+
super().__init__(model, layers, rbln_config, use_learned_pos_emb)
160+
161+
self.output_hidden_states = rbln_config.output_hidden_states
162+
163+
def forward(
164+
self,
165+
input_ids: torch.Tensor = None,
166+
inputs_embeds: Optional[torch.Tensor] = None,
167+
attention_mask: torch.Tensor = None,
168+
cache_position: torch.Tensor = None,
169+
position_ids: torch.Tensor = None,
170+
query_position: torch.Tensor = None,
171+
past_key_values: Tuple[Tuple[torch.Tensor]] = None,
172+
rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
173+
global_block_tables: Optional[torch.Tensor] = None,
174+
local_block_tables: Optional[torch.Tensor] = None,
175+
lora_int_id: Optional[torch.Tensor] = None,
176+
):
177+
# retrieve input_ids and inputs_embeds
178+
if (input_ids is None) ^ (inputs_embeds is not None):
179+
raise ValueError(
180+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
181+
)
182+
183+
# embed positions
184+
if inputs_embeds is None:
185+
inputs_embeds = self.get_embedding()(input_ids)
186+
187+
hidden_states = inputs_embeds * self.hidden_multiplier
188+
189+
# get cos,sin vector if needed
190+
position_ids = position_ids if position_ids is not None else cache_position
191+
if rotary_emb is not None:
192+
if isinstance(rotary_emb, torch.Tensor):
193+
cos = rotary_emb[0]
194+
sin = rotary_emb[1]
195+
else:
196+
cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
197+
cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
198+
199+
# Get sequence positions for flash attention
200+
if self.attn_impl == "flash_attn":
201+
seq_positions = cache_position[:, 0]
202+
seq_positions = self.convert_sequence_positions_for_flash_attn(
203+
seq_positions=seq_positions, max_seq_len=self.max_seq_len
204+
)
205+
else:
206+
seq_positions = cache_position[:, :1]
207+
208+
# Get local cache positions for sliding window layers
209+
if len(self.sliding_window_layers) > 0:
210+
sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
211+
212+
all_hidden_states = () if self.output_hidden_states else None
213+
for layer_idx, layer in enumerate(self.layers):
214+
if self.output_hidden_states:
215+
all_hidden_states += (hidden_states,)
216+
217+
is_sliding = True if layer_idx in self.sliding_window_layers else False
218+
hidden_states = layer(
219+
hidden_states=hidden_states,
220+
attention_mask=attention_mask,
221+
seq_positions=sliding_cache_pos if is_sliding else seq_positions,
222+
past_key_values=past_key_values,
223+
cos=cos,
224+
sin=sin,
225+
block_tables=local_block_tables if is_sliding else global_block_tables,
226+
lora_int_id=lora_int_id,
227+
)
228+
229+
hidden_states = self.get_last_layernorm()(hidden_states)
230+
if self.output_hidden_states:
231+
all_hidden_states += (hidden_states,)
232+
233+
return hidden_states, all_hidden_states
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
17+
from optimum.rbln.configuration_utils import RBLNModelConfig
18+
19+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig
20+
21+
22+
class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
23+
submodules = ["visual"]
24+
25+
def __init__(
26+
self,
27+
visual: Optional[RBLNModelConfig] = None,
28+
batch_size: Optional[int] = None,
29+
use_inputs_embeds: bool = True,
30+
output_hidden_states: Optional[bool] = False,
31+
**kwargs,
32+
):
33+
super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
34+
if not self.use_inputs_embeds:
35+
raise ValueError(
36+
"RBLNColQwen2ForRetrievalConfig does not allow `use_inputs_embeds` to be set to False, "
37+
"as RBLNColQwen2ForRetrieval accepts only `inputs_embeds` as input."
38+
)
39+
if batch_size is not None and batch_size != 1:
40+
raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")
41+
42+
self.visual = visual
43+
self.output_hidden_states = output_hidden_states

0 commit comments

Comments
 (0)