@@ -361,10 +361,9 @@ def __post_init__(self, **kwargs):

         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
         self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
-        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
-        self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)

     def _create_embedding_layer(self):
         with no_init_weights():
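Note on the hunk above: `block_tables` is now materialized only when the model is not the generation head (presumably `can_generate()` returns `True` only on `RBLNQwen2_5_VLForConditionalGeneration`, which manages block tables per batch instead). A minimal sketch of what the pre-allocated table is, with a toy block count:

```python
import torch

# Toy block count (assumed); the real value comes from
# rbln_config.kvcache_num_blocks.
kvcache_num_blocks = 8

# Identity block table: logical block i of the paged KV cache maps to
# physical block i, so a non-generating model can use one static table.
block_tables = torch.arange(kvcache_num_blocks, dtype=torch.int16)
print(block_tables)  # tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int16)
```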
@@ -398,7 +397,7 @@ def get_input_info(

     def _get_position_embeddings(self, hidden_states, position_ids):
         cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
         cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
         sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
         return torch.stack([cos, sin])
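Note: the split/interleave above is the M-RoPE trick. The rotary table carries separate temporal/height/width planes on dim 0, and `mrope_section` says how many channels each plane owns. A minimal sketch with toy sizes (`head_dim=8` and `mrope_section=[2, 1, 1]` are made up; real Qwen2.5-VL uses values like `[16, 24, 24]`):

```python
import torch

batch, seq, head_dim = 1, 4, 8
mrope_section = [2, 1, 1]  # toy temporal/height/width sections; sums to head_dim // 2

# cos as Qwen2_5_VLRotaryEmbedding produces it: one table per rope axis,
# stacked on dim 0 -> [3, batch, seq, head_dim].
cos = torch.randn(3, batch, seq, head_dim)

# Doubling covers the full head_dim because rotary tables repeat each half:
# [2, 1, 1] -> [2, 1, 1, 2, 1, 1], sum == 8.
sections = mrope_section * 2

# Split along the channel dim, take the temporal (i % 3 == 0), height (1),
# or width (2) plane per chunk, and stitch the planes back together.
merged = torch.cat(
    [chunk[i % 3] for i, chunk in enumerate(cos.split(sections, dim=-1))],
    dim=-1,
).unsqueeze(1)

print(merged.shape)  # torch.Size([1, 1, 4, 8]) -- broadcasts over attention heads
```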
@@ -544,7 +543,8 @@ def forward(
         return RBLNDecoderOnlyOutput(logits=logits)


-class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
+# MRO: RBLNQwen2_5_VLForConditionalGeneration -> RBLNQwen2_5_VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2_5_VLForConditionalGeneration(RBLNQwen2_5_VLModel, RBLNDecoderOnlyModelForCausalLM):
     """
     RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
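Note: the MRO comment added above can be checked with toy stand-ins for the RBLN classes. Assuming `RBLNQwen2_5_VLModel` derives from `RBLNDecoderOnlyModel` (as the comment implies), C3 linearization puts the VL model ahead of the causal-LM mixin:

```python
# Toy stand-ins, not the real RBLN classes; only the inheritance shape matters.
class RBLNModel: ...
class RBLNDecoderOnlyModel(RBLNModel): ...
class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel): ...
class RBLNQwen2_5_VLModel(RBLNDecoderOnlyModel): ...
class RBLNQwen2_5_VLForConditionalGeneration(RBLNQwen2_5_VLModel, RBLNDecoderOnlyModelForCausalLM): ...

print([c.__name__ for c in RBLNQwen2_5_VLForConditionalGeneration.__mro__])
# ['RBLNQwen2_5_VLForConditionalGeneration', 'RBLNQwen2_5_VLModel',
#  'RBLNDecoderOnlyModelForCausalLM', 'RBLNDecoderOnlyModel', 'RBLNModel', 'object']
```

This is why the subclass's `super().__post_init__(**kwargs)` now reaches the VL model's init (visual submodule, rotary embedding) before the causal-LM machinery.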
@@ -579,20 +579,16 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
     ```
     """

-    _supports_non_fp32 = False
-
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+    _supports_non_fp32 = False
+    _use_rotary_emb = False
     _rbln_submodules = [
         {"name": "visual"},
     ]
-    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
-    _use_rotary_emb = False

     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
         self.rope_deltas = torch.zeros(self.rbln_config.batch_size)

     def can_generate(self):
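Note: `rope_deltas` (one entry per batch element) stays on the generation head because decode-phase M-RoPE positions are derived from it. A sketch of the idea with made-up numbers, following the upstream Qwen2.5-VL convention of `position_ids = cache_position + rope_deltas`:

```python
import torch

rope_deltas = torch.tensor([3.0, -1.0])       # per-sample deltas from prefill (made up)
cache_position = torch.tensor([10])           # index of the next token in the KV cache
position_ids = cache_position + rope_deltas   # decode-phase rope positions per sample
print(position_ids)  # tensor([13., 9.])
```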
@@ -601,31 +597,8 @@ def can_generate(self):
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         model.model.lm_head = model.lm_head
-        model.lm_head = None
-        del model.lm_head
         return model

-    @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        rbln_config: RBLNQwen2_5_VLForConditionalGenerationConfig,
-        model_config: PretrainedConfig,
-    ):
-        input_info = super().get_input_info(batch_size, query_length, rbln_config, model_config)
-        pos_idx = 3
-        input_info.insert(
-            pos_idx,
-            (
-                "position_emb",
-                [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-                "float32",
-            ),
-        )
-
-        return input_info
-
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
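Note: the `get_input_info` override deleted above (it now presumably lives on `RBLNQwen2_5_VLModel`, next to the `get_input_info` context visible in the earlier hunk) registered a `position_emb` input whose shape is exactly what `torch.stack([cos, sin])` in `_get_position_embeddings` produces. A quick shape check with Qwen2.5-VL-7B-sized values (assumed):

```python
hidden_size, num_attention_heads = 3584, 28    # Qwen2.5-VL-7B values (assumed)
batch_size, query_length = 1, 128
head_dim = hidden_size // num_attention_heads  # 128
position_emb_shape = [2, batch_size, 1, query_length, head_dim]
print(position_emb_shape)  # [2, 1, 1, 128, 128] == torch.stack([cos, sin]).shape
```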
@@ -670,92 +643,6 @@ def prepare_inputs_for_generation(

         return model_inputs

-    def _get_position_embeddings(self, hidden_states, position_ids):
-        cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        return torch.stack([cos, sin])
-
-    def _preprocess_prefill(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.Tensor = None,
-        pixel_values: torch.Tensor = None,
-        pixel_values_videos: torch.FloatTensor = None,
-        image_grid_thw: torch.LongTensor = None,
-        video_grid_thw: torch.LongTensor = None,
-        second_per_grid_ts: torch.Tensor = None,
-    ):
-        batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
-
-        if pixel_values is not None:
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
-
-        if pixel_values_videos is not None:
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
-
-        max_inputs_len = input_ids.shape[1]
-
-        head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
-        all_rope_deltas = []
-
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        image_idx, video_idx = 0, 0
-
-        for b_idx in range(batch_size):
-            input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
-            vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
-            vision_tokens = input_id[0][vision_start_indices + 1]
-            image_nums = (vision_tokens == image_token_id).sum()
-            video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
-                self,
-                input_id,
-                image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
-                video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
-                second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None,
-            )
-            image_idx += image_nums
-            video_idx += video_nums
-
-            position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
-            mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
-            all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
-            all_rope_deltas.append(rope_deltas)
-
-        rope_deltas = torch.stack(all_rope_deltas)
-
-        return inputs_embeds, all_position_embeds, rope_deltas
-
     def _preprocess_decoder(
         self,
         input_ids: torch.LongTensor = None,
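Note: `_preprocess_prefill` is removed here along with the duplicated `_get_position_embeddings` (both now presumably live once on the shared `RBLNQwen2_5_VLModel`). The core trick it used for injecting vision features into the token stream is `masked_scatter`; a self-contained sketch with toy sizes:

```python
import torch

IMAGE_TOKEN_ID = 99                       # placeholder token id (made up)
input_ids = torch.tensor([[1, 99, 99, 2]])
inputs_embeds = torch.zeros(1, 4, 3)      # [batch, seq, hidden], toy hidden size
image_embeds = torch.ones(2, 3)           # one row per image placeholder token

# Broadcast the token mask over the hidden dim, then scatter the vision
# features into the placeholder slots -- same pattern as _preprocess_prefill.
mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds)

print(inputs_embeds[0])  # rows 1 and 2 now hold the image features
```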