
Commit dfc0288

remove redundant code between Qwen2_5VLModel & Qwen2_5VLForConditionalGeneration
1 parent 059d646 commit dfc0288

File tree

2 files changed: +9 -122 lines


src/optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         new_model.vlm.model.language_model.load_state_dict(model.language_model.state_dict())
         model = new_model

-        # del model.vlm.model.lm_head
+        # replace the lm_head with the custom text projection layer for optimization
         model.vlm.model.lm_head = model.embedding_proj_layer

         return model.to(torch.float32)

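The comment added above documents a head-swap pattern: instead of deleting the wrapped VLM's lm_head, it is rebound to the model's text projection layer so the compiled graph emits retrieval embeddings directly. Below is a minimal, hedged sketch of that pattern with invented stand-in modules (TinyLM, TinyColModel); it is not the optimum-rbln API, only an illustration of rebinding one nn.Module attribute to another.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the wrapped VLM and the retrieval model.
class TinyLM(nn.Module):
    def __init__(self, hidden=16, vocab=32):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

class TinyColModel(nn.Module):
    def __init__(self, hidden=16, proj_dim=8):
        super().__init__()
        self.vlm = TinyLM(hidden)
        # Custom text projection layer producing embeddings instead of logits.
        self.embedding_proj_layer = nn.Linear(hidden, proj_dim, bias=False)

model = TinyColModel()
# Reuse the lm_head slot for the projection layer, mirroring the comment above.
model.vlm.lm_head = model.embedding_proj_layer

hidden_states = torch.randn(1, 4, 16)
print(model.vlm.lm_head(hidden_states).shape)  # torch.Size([1, 4, 8])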
src/optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 8 additions & 121 deletions
@@ -361,10 +361,9 @@ def __post_init__(self, **kwargs):

         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
         self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
-        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
-        self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)

     def _create_embedding_layer(self):
         with no_init_weights():
@@ -398,7 +397,7 @@ def get_input_info(

     def _get_position_embeddings(self, hidden_states, position_ids):
         cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
         cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
         sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
         return torch.stack([cos, sin])
@@ -544,7 +543,8 @@ def forward(
         return RBLNDecoderOnlyOutput(logits=logits)


-class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
+# MRO: RBLNQwen2_5_VLForConditionalGeneration -> RBLNQwen2_5_VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2_5_VLForConditionalGeneration(RBLNQwen2_5_VLModel, RBLNDecoderOnlyModelForCausalLM):
     """
     RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -579,20 +579,16 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
     ```
     """

-    _supports_non_fp32 = False
-
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+    _supports_non_fp32 = False
+    _use_rotary_emb = False
     _rbln_submodules = [
         {"name": "visual"},
     ]
-    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
-    _use_rotary_emb = False

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
         self.rope_deltas = torch.zeros(self.rbln_config.batch_size)

     def can_generate(self):
@@ -601,31 +597,8 @@ def can_generate(self):
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.model.lm_head = model.lm_head
-        model.lm_head = None
-        del model.lm_head
         return model

-    @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        rbln_config: RBLNQwen2_5_VLForConditionalGenerationConfig,
-        model_config: PretrainedConfig,
-    ):
-        input_info = super().get_input_info(batch_size, query_length, rbln_config, model_config)
-        pos_idx = 3
-        input_info.insert(
-            pos_idx,
-            (
-                "position_emb",
-                [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-                "float32",
-            ),
-        )
-
-        return input_info
-
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
@@ -670,92 +643,6 @@ def prepare_inputs_for_generation(

         return model_inputs

-    def _get_position_embeddings(self, hidden_states, position_ids):
-        cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        return torch.stack([cos, sin])
-
-    def _preprocess_prefill(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        pixel_values: torch.Tensor = None,
-        pixel_values_videos: torch.FloatTensor = None,
-        image_grid_thw: torch.LongTensor = None,
-        video_grid_thw: torch.LongTensor = None,
-        second_per_grid_ts: torch.Tensor = None,
-    ):
-        batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
-
-        if pixel_values is not None:
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
-
-        if pixel_values_videos is not None:
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
-
-        max_inputs_len = input_ids.shape[1]
-
-        head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
-        all_rope_deltas = []
-
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        image_idx, video_idx = 0, 0
-
-        for b_idx in range(batch_size):
-            input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
-            vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
-            vision_tokens = input_id[0][vision_start_indices + 1]
-            image_nums = (vision_tokens == image_token_id).sum()
-            video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
-                self,
-                input_id,
-                image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
-                video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
-                second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None,
-            )
-            image_idx += image_nums
-            video_idx += video_nums
-
-            position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
-            mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
-            all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
-            all_rope_deltas.append(rope_deltas)
-
-        rope_deltas = torch.stack(all_rope_deltas)
-
-        return inputs_embeds, all_position_embeds, rope_deltas
-
     def _preprocess_decoder(
         self,
         input_ids: torch.LongTensor = None,

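The new MRO comment in the diff is the core of this deduplication: RBLNQwen2_5_VLForConditionalGeneration now inherits the shared vision/rotary setup from RBLNQwen2_5_VLModel and only adds generation-specific state on top, with cooperative super().__post_init__() calls walking the MRO. Below is a minimal, self-contained sketch of that inheritance shape using invented "Sketch" stand-in classes; it only illustrates the linearization and the cooperative init chain, not the real optimum-rbln classes or their arguments.

class RBLNModelSketch:
    def __post_init__(self):
        print("RBLNModel setup")

class RBLNDecoderOnlyModelSketch(RBLNModelSketch):
    def __post_init__(self):
        super().__post_init__()
        print("DecoderOnlyModel setup")

class RBLNDecoderOnlyModelForCausalLMSketch(RBLNDecoderOnlyModelSketch):
    def __post_init__(self):
        super().__post_init__()
        print("DecoderOnlyModelForCausalLM setup")

class RBLNQwen2_5_VLModelSketch(RBLNDecoderOnlyModelSketch):
    # Shared setup (visual submodule, rotary embedding, block tables) lives
    # once here and is inherited by the generation class.
    def __post_init__(self):
        super().__post_init__()
        print("Qwen2_5_VLModel setup: visual, rotary_emb")

class RBLNQwen2_5_VLForConditionalGenerationSketch(
    RBLNQwen2_5_VLModelSketch, RBLNDecoderOnlyModelForCausalLMSketch
):
    # Only generation-specific state (rope_deltas) is added on top.
    def __post_init__(self):
        super().__post_init__()
        print("ForConditionalGeneration setup: rope_deltas")

# Prints the same order as the comment in the diff (with the Sketch suffix):
# ForConditionalGeneration -> Qwen2_5_VLModel -> DecoderOnlyModelForCausalLM
# -> DecoderOnlyModel -> RBLNModel -> object
print([c.__name__ for c in RBLNQwen2_5_VLForConditionalGenerationSketch.__mro__])
RBLNQwen2_5_VLForConditionalGenerationSketch().__post_init__()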