@@ -71,35 +71,18 @@ def _forward(
7171 input_ids : torch .LongTensor = None ,
7272 pixel_values : torch .FloatTensor = None ,
7373 image_sizes : Optional [torch .LongTensor ] = None ,
74- inputs_embeds : Optional [torch .FloatTensor ] = None ,
75- vision_feature_layer : Optional [int ] = None ,
76- vision_feature_select_strategy : Optional [str ] = None ,
7774 cache_position : Union [List [torch .Tensor ],
7875 torch .Tensor ] = None , # vllm keyword argument
7976 ** kwargs ,
8077 ):
81- if inputs_embeds is not None :
82- raise NotImplementedError (
83- "Specifying inputs_embeds is not supported." )
84-
8578 if is_prefill :
86- # Get text_embeds
87- inputs_embeds = self .model .text_embedding (input_ids )
88-
89- # If any images in the prompt, get image_embeds and merge with text
90- if pixel_values is not None and input_ids .shape [
91- 1 ] != 1 and pixel_values .size (0 ) > 0 :
92- image_features , _ = self .model .image_embedding (
93- image_sizes , pixel_values , vision_feature_layer ,
94- vision_feature_select_strategy )
95-
96- inputs_embeds = self .merge_multimodal_embeddings (
97- input_ids , inputs_embeds , image_features ,
98- self .model .config .image_token_index )
99- else :
100- inputs_embeds = self .model .text_embedding (input_ids = input_ids )
79+ # NOTE inputs_embeds will be generated inside _preprocess_prefill
80+ inputs_embeds = self .model ._preprocess_prefill (
81+ input_ids = input_ids ,
82+ pixel_values = pixel_values ,
83+ image_sizes = image_sizes ,
84+ )
10185
102- if is_prefill :
10386 if self .model .language_model .prefill_decoder is None :
10487 raise version_error
10588
@@ -113,7 +96,7 @@ def _forward(
11396 raise version_error
11497
11598 logits = self .model .language_model .decoder (
116- inputs_embeds = inputs_embeds ,
99+ input_ids = input_ids ,
117100 cache_position = cache_position ,
118101 block_tables = block_tables ,
119102 ).logits
@@ -160,9 +143,9 @@ def forward(self, model_input: ModelInputForRBLN,
160143 is_prefill = is_prompt ,
161144 block_tables = block_tables ,
162145 input_ids = input_ids ,
163- cache_position = cache_position ,
164146 pixel_values = pixel_values ,
165147 image_sizes = image_sizes ,
148+ cache_position = cache_position ,
166149 )
167150
168151 if not is_prompt :
0 commit comments