@@ -34,8 +34,6 @@ def lce_forward_deprecated(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
     Args:
@@ -96,39 +94,32 @@ def lce_forward_deprecated(
             "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
         )

-    legacy_processing = False
     if inputs_embeds is None:
+        # 1. Extract the input embeddings
         inputs_embeds = self.get_input_embeddings()(input_ids)

-        # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-        # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-        # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-        legacy_processing = (
-            (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-    image_features = None
-    if pixel_values is not None:
-        image_features = self.get_image_features(
-            pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
-        )
-
-    if legacy_processing and image_features is not None:
-        logger.warning_once(
-            "Expanding inputs for image tokens in LLaVa should be done in processing. "
-            "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-            "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-            "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
-        )
-        # prefill stage vs decoding stage (legacy behavior copied)
-        if input_ids.shape[1] != 1:
+        # 2. Merge text and images
+        if pixel_values is not None and input_ids.shape[1] != 1:
+            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+            # this is not memory efficient at all; (output_hidden_states=True) will save all the hidden states.
+            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+            if vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+            image_features = self.multi_modal_projector(selected_image_feature)
+            inputs_embeds = inputs_embeds.to(image_features.dtype)
             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids, attention_mask, labels
             )
-            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-        else:
+
+        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+        # generation with cache
+        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
             # Retrieve the first layer to inspect the logits and mask out the hidden states
             # that are set to 0
             first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
@@ -158,7 +149,6 @@ def lce_forward_deprecated(

             attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
             position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

     # TODO: @raushan retain only the new behavior after v4.47
     elif image_features is not None:
@@ -184,8 +174,6 @@ def lce_forward_deprecated(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
-        cache_position=cache_position,
-        num_logits_to_keep=num_logits_to_keep,
     )
     hidden_states = outputs[0]

@@ -220,7 +208,6 @@ def lce_forward_deprecated(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
     )

