21 | 21 | logger = logging.get_logger(__name__) |
22 | 22 |
23 | 23 |
| 24 | +@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) |
| 25 | +@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
| 26 | +def lce_forward_deprecated( |
| 27 | + self, |
| 28 | + input_ids: torch.LongTensor = None, |
| 29 | + pixel_values: torch.FloatTensor = None, |
| 30 | + attention_mask: Optional[torch.Tensor] = None, |
| 31 | + position_ids: Optional[torch.LongTensor] = None, |
| 32 | + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, |
| 33 | + token_type_ids: Optional[torch.LongTensor] = None, |
| 34 | + cache_position: Optional[torch.LongTensor] = None, |
| 35 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 36 | + labels: Optional[torch.LongTensor] = None, |
| 37 | + use_cache: Optional[bool] = None, |
| 38 | + output_attentions: Optional[bool] = None, |
| 39 | + output_hidden_states: Optional[bool] = None, |
| 40 | + return_dict: Optional[bool] = None, |
| 41 | +) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: |
| 42 | + r""" |
| 43 | + Args: |
| 44 | + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 45 | + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| 46 | + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| 47 | + (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| 48 | +
| 49 | + Returns: |
| 50 | +
| 51 | + Example: |
| 52 | +
| 53 | + ```python |
| 54 | + >>> from PIL import Image |
| 55 | + >>> import requests |
| 56 | + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration |
| 57 | +
| 58 | + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf") |
| 59 | + >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf") |
| 60 | +
| 61 | + >>> prompt = "answer en Where is the cow standing?" |
| 62 | + >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png" |
| 63 | + >>> image = Image.open(requests.get(url, stream=True).raw) |
| 64 | +
| 65 | + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") |
| 66 | +
| 67 | + >>> # Generate |
| 68 | + >>> generate_ids = model.generate(**inputs, max_length=30) |
| 69 | + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| 70 | + "answer en Where is the cow standing?\nbeach" |
| 71 | + ```""" |
| 72 | + |
| 73 | + if (input_ids is None) ^ (inputs_embeds is not None): |
| 74 | + raise ValueError( |
| 75 | + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" |
| 76 | + ) |
| 77 | + |
| 78 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 79 | + output_hidden_states = ( |
| 80 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 81 | + ) |
| 82 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 83 | + |
| 84 | + # the attention mask is turned 4d afterwards; we keep track of the original one |
| 85 | + input_attention_mask = attention_mask |
| 86 | + |
| 87 | + if inputs_embeds is None: |
| 88 | + # 1. Extract the input embeddings |
| 89 | + inputs_embeds = self.get_input_embeddings()(input_ids) |
| 90 | + |
| 91 | + # 2. Merge text and images |
| 92 | + if pixel_values is not None and input_ids.shape[1] != 1: |
| 93 | + image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype)) |
| 94 | + selected_image_feature = image_outputs.last_hidden_state |
| 95 | + image_features = self.multi_modal_projector(selected_image_feature) |
| 96 | + |
| 97 | + if cache_position is None: |
| 98 | + cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) |
| 99 | + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( |
| 100 | + image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position |
| 101 | + ) |
| 102 | + |
| 103 | + else: |
| 104 | + # In case input_ids.shape[1] == 1 & pixel_values != None & past_key_values != None, we are in the case of |
| 105 | + # generation with cache |
| 106 | + if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: |
| 107 | + # Retrieve the first layer to inspect the logits and mask out the hidden states |
| 108 | + # that are set to 0 |
| 109 | + # TODO @molbap this will only work for dynamic cache. |
| 110 | + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] |
| 111 | + |
| 112 | + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 |
| 113 | + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) |
| 114 | + |
| 115 | + # Get the target length |
| 116 | + target_seqlen = cache_position[-1] + 1 |
| 117 | + extended_attention_mask = torch.ones( |
| 118 | + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1), |
| 119 | + dtype=attention_mask.dtype, |
| 120 | + device=attention_mask.device, |
| 121 | + ) |
| 122 | + # Filter out only the tokens that can be un-attended; this can happen |
| 123 | + # if one uses PaliGemma + Fused modules where the cache on the |
| 124 | + # first iteration is already big enough, or if one passes a custom cache |
| 125 | + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) |
| 126 | + new_batch_index = batch_index[valid_indices] |
| 127 | + new_non_attended_tokens = non_attended_tokens[valid_indices] |
| 128 | + |
| 129 | + # Zero-out the places where we don't need to attend |
| 130 | + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 |
| 131 | + |
| 132 | + attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) |
| 133 | + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 |
| 134 | + |
| 135 | + attention_mask = attention_mask.to(inputs_embeds.dtype) |
| 136 | + outputs = self.language_model.model( |
| 137 | + attention_mask=attention_mask, |
| 138 | + position_ids=position_ids, |
| 139 | + past_key_values=past_key_values, |
| 140 | + inputs_embeds=inputs_embeds, |
| 141 | + use_cache=use_cache, |
| 142 | + output_attentions=output_attentions, |
| 143 | + output_hidden_states=output_hidden_states, |
| 144 | + return_dict=return_dict, |
| 145 | + cache_position=cache_position, |
| 146 | + ) |
| 147 | + |
| 148 | + hidden_states = outputs[0] |
| 149 | + |
| 150 | + loss = None |
| 151 | + logits = None |
| 152 | + |
| 153 | + if self.training and (labels is not None): |
| 154 | + shift_hidden_states = hidden_states[..., :-1, :] |
| 155 | + shift_labels = labels[..., 1:] |
| 156 | + |
| 157 | + hidden_device = shift_hidden_states.device |
| 158 | + |
| 159 | + if attention_mask is not None: |
| 160 | + # we use the input attention mask to shift the hidden_states and labels, because it is 2D. |
| 161 | + # we also crop the attn mask in case it is longer, which happens in PrefixTuning with PEFT |
| 162 | + shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device) |
| 163 | + shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous() |
| 164 | + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() |
| 165 | + else: |
| 166 | + shift_hidden_states = shift_hidden_states.contiguous() |
| 167 | + shift_labels = shift_labels.contiguous() |
| 168 | + |
| 169 | + # Flatten hidden state |
| 170 | + shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size) |
| 171 | + shift_labels = shift_labels.view(-1).to(hidden_device) |
| 172 | + |
| 173 | + lce = LigerFusedLinearCrossEntropyLoss() |
| 174 | + loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) |
| 175 | + |
| 176 | + else: |
| 177 | + logits = self.language_model.lm_head(hidden_states) |
| 178 | + if labels is not None: |
| 179 | + shift_logits = logits[..., :-1, :] |
| 180 | + shift_labels = labels[..., 1:] |
| 181 | + if input_attention_mask is not None: |
| 182 | + # we use the input attention mask to shift the logits and labels, because it is 2D. |
| 183 | + shift_attention_mask = input_attention_mask[..., 1:] |
| 184 | + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() |
| 185 | + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() |
| 186 | + else: |
| 187 | + shift_logits = shift_logits.contiguous() |
| 188 | + shift_labels = shift_labels.contiguous() |
| 189 | + # Flatten the tokens |
| 190 | + loss_fct = CrossEntropyLoss() |
| 191 | + |
| 192 | + flat_logits = shift_logits.view(-1, self.config.vocab_size) |
| 193 | + flat_labels = shift_labels.view(-1).to(shift_logits.device) |
| 194 | + loss = loss_fct(flat_logits, flat_labels) |
| 195 | + if not return_dict: |
| 196 | + output = (logits,) + outputs[1:] |
| 197 | + return (loss,) + output if loss is not None else output |
| 198 | + |
| 199 | + return PaliGemmaCausalLMOutputWithPast( |
| 200 | + loss=loss, |
| 201 | + logits=logits, |
| 202 | + past_key_values=outputs.past_key_values, |
| 203 | + hidden_states=outputs.hidden_states, |
| 204 | + attentions=outputs.attentions, |
| 205 | + ) |
| 206 | + |
| 207 | + |
24 | 208 | @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") |
25 | 209 | @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) |
26 | 210 | @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
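For readers skimming the hunk: the core of this deprecated forward is the `self.training and (labels is not None)` branch, where `LigerFusedLinearCrossEntropyLoss` consumes `lm_head.weight` plus the shifted hidden states and computes the loss without ever materializing the `(num_tokens, vocab_size)` logits tensor, which is why `logits` stays `None` on that path. The sketch below contrasts the two paths on dummy tensors; the import path, tensor sizes, and device are illustrative assumptions rather than details taken from this commit, and the fused kernel needs a CUDA device.

```python
# Minimal sketch, not part of the patch: fused vs. unfused loss on dummy tensors.
# Sizes, device, and the import path below are assumptions for illustration only.
import torch
from torch.nn import CrossEntropyLoss

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss  # assumed import path

hidden_size, vocab_size, num_tokens = 2048, 32000, 512  # placeholder sizes
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16)
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# Unfused reference path (the `else` branch above): materializes the full logits tensor first.
ref_loss = CrossEntropyLoss()(hidden_states @ lm_head_weight.T, labels)

# Fused path (the training branch above): weight and hidden states go straight into the loss.
fused_loss = LigerFusedLinearCrossEntropyLoss()(lm_head_weight, hidden_states, labels)
```

Because the fused branch never builds logits, both the `return_dict=False` tuple and the returned `PaliGemmaCausalLMOutputWithPast` carry `logits=None` when training with labels.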