33 | 33 | from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
34 | 34 | from ...processing_utils import Unpack
35 | 35 | from ...utils import (
| 36 | + add_start_docstrings_to_model_forward, |
| 37 | + is_torchdynamo_compiling, |
36 | 38 | logging,
| 39 | + replace_return_docstrings, |
37 | 40 | )
| 41 | +from ...utils.deprecation import deprecate_kwarg |
38 | 42 | from ..bart.modeling_bart import BartScaledWordEmbedding
39 | 43 | from ..gemma2.configuration_gemma2 import Gemma2Config
40 | 44 | from ..gemma2.modeling_gemma2 import (
@@ -837,6 +841,217 @@ def _update_causal_mask(
837 | 841 |
838 | 842 | return causal_mask
839 | 843 |
| 844 | + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") |
| 845 | + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) |
| 846 | + @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
| 847 | + def forward( |
| 848 | + self, |
| 849 | + input_ids: torch.LongTensor = None, |
| 850 | + pixel_values: torch.FloatTensor = None, |
| 851 | + attention_mask: Optional[torch.Tensor] = None, |
| 852 | + position_ids: Optional[torch.LongTensor] = None, |
| 853 | + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, |
| 854 | + token_type_ids: Optional[torch.LongTensor] = None, |
| 855 | + cache_position: Optional[torch.LongTensor] = None, |
| 856 | + inputs_embeds: Optional[torch.FloatTensor] = None, |
| 857 | + labels: Optional[torch.LongTensor] = None, |
| 858 | + use_cache: Optional[bool] = None, |
| 859 | + output_attentions: Optional[bool] = None, |
| 860 | + output_hidden_states: Optional[bool] = None, |
| 861 | + return_dict: Optional[bool] = None, |
| 862 | + logits_to_keep: Union[int, torch.Tensor] = 0, |
| 863 | + **lm_kwargs, |
| 864 | + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: |
| 865 | + r""" |
| 866 | + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| 867 | + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| 868 | + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| 869 | + (masked); the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. |
| 870 | + |
| 871 | + logits_to_keep (`int` or `torch.Tensor`, *optional*): |
| 872 | + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all |
| 873 | + `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only for |
| 874 | + that token can save memory, which becomes quite significant for long sequences or large vocabulary sizes. |
| 875 | + If a `torch.Tensor`, must be 1D, corresponding to the indices to keep in the sequence length dimension. |
| 876 | + This is useful when using packed tensor format (single dimension for batch and sequence length). |
| 877 | + |
| 878 | + Returns: |
| 879 | + |
| 880 | + Example: |
| 881 | + |
| 882 | + ```python |
| 883 | + >>> from PIL import Image |
| 884 | + >>> import requests |
| 885 | + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration |
| 886 | + |
| 887 | + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") |
| 888 | + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") |
| 889 | + |
| 890 | + >>> prompt = "answer en Where is the cow standing?" |
| 891 | + >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" |
| 892 | + >>> image = Image.open(requests.get(url, stream=True).raw) |
| 893 | + |
| 894 | + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") |
| 895 | + |
| 896 | + >>> # Generate |
| 897 | + >>> generate_ids = model.generate(**inputs, max_length=30) |
| 898 | + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| 899 | + "answer en Where is the cow standing?\nbeach" |
| 900 | + ```""" |
| 901 | + |
| 902 | + if (input_ids is None) ^ (inputs_embeds is not None): |
| 903 | + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| 904 | + |
| 905 | + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| 906 | + output_hidden_states = ( |
| 907 | + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| 908 | + ) |
| 909 | + return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| 910 | + |
| 911 | + is_training = token_type_ids is not None and labels is not None |
| 912 | + |
| 913 | + # Replace image id with PAD if the image token is OOV, to avoid index errors |
| 914 | + if input_ids is not None and self.config.image_token_index >= self.vocab_size: |
| 915 | + special_image_mask = input_ids == self.config.image_token_index |
| 916 | + llm_input_ids = input_ids.clone() |
| 917 | + llm_input_ids[special_image_mask] = 0 |
| 918 | + else: |
| 919 | + llm_input_ids = input_ids |
| 920 | + |
| 921 | + if inputs_embeds is None: |
| 922 | + inputs_embeds = self.get_input_embeddings()(llm_input_ids) |
| 923 | + |
| 924 | + if cache_position is None: |
| 925 | + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| 926 | + cache_position = torch.arange( |
| 927 | + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| 928 | + ) |
| 929 | + |
| 930 | + # Merge text and images |
| 931 | + if pixel_values is not None: |
| 932 | + image_features = self.get_image_features(pixel_values) |
| 933 | + |
| 934 | + if input_ids is None: |
| 935 | + special_image_mask = inputs_embeds == self.get_input_embeddings()( |
| 936 | + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) |
| 937 | + ) |
| 938 | + else: |
| 939 | + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| 940 | + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| 941 | + |
| 942 | + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| 943 | + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] |
| 944 | + raise ValueError( |
| 945 | + f"Number of images does not match number of special image tokens in the input text. " |
| 946 | + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " |
| 947 | + "tokens from image embeddings." |
| 948 | + ) |
| 949 | + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| 950 | + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| 951 | + |
| 952 | + # mask out pad token ids in labels for backward compatibility |
| 953 | + if labels is not None and self.pad_token_id in labels: |
| 954 | + logger.warning_once( |
| 955 | + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " |
| 956 | + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", |
| 957 | + ) |
| 958 | + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) |
| 959 | + |
| 960 | + causal_mask = self._update_causal_mask( |
| 961 | + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training |
| 962 | + ) |
| 963 | + outputs = self.language_model( |
| 964 | + attention_mask=causal_mask, |
| 965 | + position_ids=position_ids, |
| 966 | + past_key_values=past_key_values, |
| 967 | + inputs_embeds=inputs_embeds, |
| 968 | + use_cache=use_cache, |
| 969 | + output_attentions=output_attentions, |
| 970 | + output_hidden_states=output_hidden_states, |
| 971 | + return_dict=return_dict, |
| 972 | + cache_position=cache_position, |
| 973 | + logits_to_keep=logits_to_keep, |
| 974 | + **lm_kwargs, |
| 975 | + ) |
| 976 | + |
| 977 | + logits = outputs.logits |
| 978 | + loss = None |
| 979 | + if labels is not None: |
| 980 | + # Upcast to float if we need to compute the loss to avoid potential precision issues |
| 981 | + logits = logits.float() |
| 982 | + shift_logits = logits[..., :-1, :] |
| 983 | + shift_labels = labels[..., 1:] |
| 984 | + if attention_mask is not None: |
| 985 | + # we use the input attention mask to shift the logits and labels, because it is 2D. |
| 986 | + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft |
| 987 | + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) |
| 988 | + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() |
| 989 | + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() |
| 990 | + else: |
| 991 | + shift_logits = shift_logits.contiguous() |
| 992 | + shift_labels = shift_labels.contiguous() |
| 993 | + # Flatten the tokens |
| 994 | + loss_fct = nn.CrossEntropyLoss() |
| 995 | + |
| 996 | + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) |
| 997 | + flat_labels = shift_labels.view(-1).to(shift_logits.device) |
| 998 | + loss = loss_fct(flat_logits, flat_labels) |
| 999 | + if not return_dict: |
| 1000 | + output = (logits,) + outputs[1:] |
| 1001 | + return (loss,) + output if loss is not None else output |
| 1002 | + |
| 1003 | + return Gemma3CausalLMOutputWithPast( |
| 1004 | + loss=loss, |
| 1005 | + logits=logits, |
| 1006 | + past_key_values=outputs.past_key_values, |
| 1007 | + hidden_states=outputs.hidden_states, |
| 1008 | + attentions=outputs.attentions, |
| 1009 | + image_hidden_states=image_features if pixel_values is not None else None, |
| 1010 | + ) |
| 1011 | + |
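The merge step in `forward` above works by broadcasting the boolean image-token mask over the hidden dimension and letting `masked_scatter` overwrite the placeholder embeddings, in order, with the flattened image embeddings. Below is a minimal, self-contained sketch of that mechanic; the shapes and tensors are made up for illustration and are not the model's real dimensions.

```python
import torch

# Toy setup: one sequence of 6 tokens with hidden size 4, where positions 1 and 2
# are image placeholder tokens.
inputs_embeds = torch.zeros(1, 6, 4)
image_token_positions = torch.tensor([[False, True, True, False, False, False]])

# Pretend these came from get_image_features(pixel_values): two image "patch" embeddings.
image_features = torch.ones(2, 4)

# Same expansion as in forward(): broadcast the mask over the hidden dimension.
special_image_mask = image_token_positions.unsqueeze(-1).expand_as(inputs_embeds)

# Placeholder embeddings at masked positions are replaced, in order, by the image features.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # tensor([1., 1., 1., 1.])
```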
| 1012 | + def prepare_inputs_for_generation( |
| 1013 | + self, |
| 1014 | + input_ids, |
| 1015 | + past_key_values=None, |
| 1016 | + inputs_embeds=None, |
| 1017 | + cache_position=None, |
| 1018 | + position_ids=None, |
| 1019 | + pixel_values=None, |
| 1020 | + attention_mask=None, |
| 1021 | + token_type_ids=None, |
| 1022 | + use_cache=True, |
| 1023 | + logits_to_keep=None, |
| 1024 | + labels=None, |
| 1025 | + **kwargs, |
| 1026 | + ): |
| 1027 | + # Overwritten -- custom `position_ids` and `pixel_values` handling |
| 1028 | + model_inputs = self.language_model.prepare_inputs_for_generation( |
| 1029 | + input_ids, |
| 1030 | + past_key_values=past_key_values, |
| 1031 | + inputs_embeds=inputs_embeds, |
| 1032 | + attention_mask=attention_mask, |
| 1033 | + position_ids=position_ids, |
| 1034 | + cache_position=cache_position, |
| 1035 | + use_cache=use_cache, |
| 1036 | + logits_to_keep=logits_to_keep, |
| 1037 | + token_type_ids=token_type_ids, |
| 1038 | + **kwargs, |
| 1039 | + ) |
| 1040 | + |
| 1041 | + # If we're in the cached decoding stage, pixel values should be None because the input ids no longer contain the special image token |
| 1042 | + # Otherwise pixel values need to be passed to the model. NOTE: use_cache=False always needs pixel_values |
| 1043 | + if cache_position[0] == 0: |
| 1044 | + model_inputs["pixel_values"] = pixel_values |
| 1045 | + is_training = token_type_ids is not None and labels is not None |
| 1046 | + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): |
| 1047 | + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids |
| 1048 | + causal_mask = self._update_causal_mask( |
| 1049 | + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training |
| 1050 | + ) |
| 1051 | + model_inputs["attention_mask"] = causal_mask |
| 1052 | + |
| 1053 | + return model_inputs |
| 1054 | + |
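As a quick sanity check of the `logits_to_keep` behavior documented above, the sketch below calls `forward` directly and compares logit shapes. It reuses the test checkpoint and image from the docstring example and assumes they are reachable; nothing beyond the output shapes is checked.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="answer en Where is the cow standing?", return_tensors="pt")

with torch.no_grad():
    # logits_to_keep=0 is the special case: logits are computed for every input position.
    full = model(**inputs, logits_to_keep=0)
    # logits_to_keep=1 keeps only the last position, which is all that generation needs.
    last = model(**inputs, logits_to_keep=1)

print(full.logits.shape[1])  # full sequence length
print(last.logits.shape[1])  # 1
```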
840 | 1055 |
841 | 1056 | __all__ = [
842 | 1057 | "Gemma3Config",