
Commit b60bb61

Refactor lce_forward_deprecated to streamline input processing and remove deprecated parameters
1 parent 56bced1 commit b60bb61

1 file changed

Lines changed: 20 additions & 33 deletions

File tree

  • src/liger_kernel/transformers/model

src/liger_kernel/transformers/model/llava.py

@@ -34,8 +34,6 @@ def lce_forward_deprecated(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
     Args:
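
Note: the two dropped parameters follow upstream transformers conventions: `cache_position` tracks absolute positions for the KV cache, and `num_logits_to_keep` trims the logits computation to the last N positions. A minimal sketch of the slicing behavior `num_logits_to_keep` conventionally controls (illustrative only, not code from this file):

import torch

# Illustrative sketch of the `num_logits_to_keep` convention in upstream
# transformers: 0 keeps every position, N > 0 keeps only the last N
# (enough for next-token decoding). Shapes below are made up.
def slice_for_logits(hidden_states: torch.Tensor, num_logits_to_keep: int) -> torch.Tensor:
    if num_logits_to_keep == 0:
        return hidden_states  # logits over the full sequence
    return hidden_states[:, -num_logits_to_keep:, :]

hidden = torch.randn(2, 16, 4096)  # (batch, seq_len, hidden_dim)
print(slice_for_logits(hidden, 1).shape)  # torch.Size([2, 1, 4096])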
@@ -96,39 +94,32 @@ def lce_forward_deprecated(
             "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
         )
 
-    legacy_processing = False
     if inputs_embeds is None:
+        # 1. Extract the input embeddings
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
-        # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-        # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-        # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-        legacy_processing = (
-            (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-    image_features = None
-    if pixel_values is not None:
-        image_features = self.get_image_features(
-            pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
-        )
-
-    if legacy_processing and image_features is not None:
-        logger.warning_once(
-            "Expanding inputs for image tokens in LLaVa should be done in processing. "
-            "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-            "with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-            "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
-        )
-        # prefill stage vs decoding stage (legacy behavior copied)
-        if input_ids.shape[1] != 1:
+        # 2. Merge text and images
+        if pixel_values is not None and input_ids.shape[1] != 1:
+            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+            # this is not memory efficient at all: output_hidden_states=True will save all the hidden states
+            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+            if vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+            image_features = self.multi_modal_projector(selected_image_feature)
+            inputs_embeds = inputs_embeds.to(image_features.dtype)
             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                 image_features, inputs_embeds, input_ids, attention_mask, labels
             )
-            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-        else:
+
+        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+        # generation with cache
+        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
             # Retrieve the first layer to inspect the logits and mask out the hidden states
             # that are set to 0
             first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
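
Note: the inlined branch above replaces the `get_image_features` helper with its body: pick one hidden layer from the vision tower, drop the first (CLS) position under the "default" strategy, then project into the language model's embedding space. A standalone sketch of the selection step, with LLaVA-1.5-style shapes assumed for illustration:

import torch

# Sketch of the feature-selection step above, assuming a CLIP-style vision
# tower whose hidden states have shape (batch, 1 + num_patches, dim) with
# the CLS token at position 0.
def select_image_features(hidden_states: list, layer: int, strategy: str) -> torch.Tensor:
    selected = hidden_states[layer]
    if strategy == "default":
        selected = selected[:, 1:]  # drop the CLS token, keep patch tokens
    elif strategy != "full":
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    return selected

# 1 CLS token + 576 patch tokens, as in LLaVA-1.5 (336px image, 14px patches)
layers = [torch.randn(1, 577, 1024) for _ in range(25)]
print(select_image_features(layers, -2, "default").shape)  # torch.Size([1, 576, 1024])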
@@ -158,7 +149,6 @@ def lce_forward_deprecated(
 
             attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
             position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-            cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
 
     # TODO: @raushan retain only the new behavior after v4.47
     elif image_features is not None:
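
Note: in the retained cached-generation branch, `position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1` recovers the current token's position as the count of attended positions minus one. A small illustration with an invented left-padded mask:

import torch

# Illustration of deriving position_ids from a left-padded attention mask,
# as in the cached-generation branch above (example mask is made up).
attention_mask = torch.tensor([[0, 1, 1, 1],    # 3 real tokens -> position 2
                               [1, 1, 1, 1]])   # 4 real tokens -> position 3
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
print(position_ids)  # tensor([[2], [3]])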
@@ -184,8 +174,6 @@ def lce_forward_deprecated(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
-        cache_position=cache_position,
-        num_logits_to_keep=num_logits_to_keep,
     )
     hidden_states = outputs[0]
 
@@ -220,7 +208,6 @@ def lce_forward_deprecated(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
     )

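Note: for context on how a function like this gets used, Liger-Kernel typically monkey-patches these forwards over the upstream implementation depending on the installed transformers version. A hedged sketch of that pattern; the version cutoff and the exact patch site are assumptions, not taken from this commit (see liger_kernel's monkey_patch module for the real entry point):

# Hedged sketch of the usual Liger-Kernel patching pattern: swap the
# upstream forward for the fused-linear-cross-entropy (LCE) variant when
# the installed transformers still uses the deprecated signature.
# `lce_forward_deprecated` is real (this file); the version cutoff and
# the patch site below are illustrative assumptions.
import transformers
from packaging import version
from transformers.models.llava import modeling_llava

from liger_kernel.transformers.model.llava import lce_forward_deprecated

if version.parse(transformers.__version__) < version.parse("4.49.0"):  # hypothetical cutoff
    modeling_llava.LlavaForConditionalGeneration.forward = lce_forward_deprecated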