Commit 9e75434

Merge branch 'main' into enhance-TrainerArgs-push_to_hub-functionality

2 parents dc03b66 + 2638d54

5 files changed: +275 −33 lines changed

.github/workflows/self-comment-ci.yml

+1 −1

@@ -29,7 +29,7 @@ jobs:
     runs-on: ubuntu-22.04
     name: Get PR number
     # For security: only allow team members to run
-    if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
+    if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
     outputs:
       PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
     steps:
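The `if:` guard above packs three checks into a single expression. A minimal sketch of the same gating logic in Python, purely illustrative (GitHub Actions evaluates the expression itself; `should_run` is a hypothetical helper, not part of the repo):

```python
# Hedged sketch of the workflow's gating condition, written out for clarity.
ALLOWED_ACTORS = {
    "ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante",
    "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb",
}
TRIGGER_PREFIXES = ("run-slow", "run slow", "run_slow")

def should_run(issue_state: str, actor: str, comment_body: str) -> bool:
    # Run only on open issues/PRs, for allowlisted team members, when the
    # comment starts with one of the accepted trigger phrases.
    return (
        issue_state == "open"
        and actor in ALLOWED_ACTORS
        and comment_body.startswith(TRIGGER_PREFIXES)
    )

assert should_run("open", "eustlb", "run-slow gemma3")    # newly allowlisted actor
assert not should_run("open", "someone-else", "run-slow")  # not on the allowlist
```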

src/transformers/generation/utils.py

+42 −18

@@ -23,6 +23,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+from packaging import version
 from torch import nn
 from torch.nn import functional as F

@@ -1552,7 +1553,7 @@ def _prepare_generated_length(
         return generation_config

     def _prepare_generation_config(
-        self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
     ) -> Tuple[GenerationConfig, Dict]:
         """
         Prepares the base generation config, then applies any generation configuration options from kwargs. This

@@ -1591,23 +1592,38 @@ def _prepare_generation_config(

         generation_config = copy.deepcopy(generation_config)

-        # If `generation_config` is provided, let's fallback ALL default values to the model's generation config
         if not using_model_generation_config:
-            modified_values = {}
-            default_generation_config = GenerationConfig()
-            for key, default_value in default_generation_config.__dict__.items():
-                if key.startswith("_"):  # metadata
-                    continue
-                custom_gen_config_value = getattr(generation_config, key)
-                model_gen_config_value = getattr(self.generation_config, key)
-                if custom_gen_config_value == default_value and model_gen_config_value != default_value:
-                    modified_values[key] = model_gen_config_value
-                    setattr(generation_config, key, model_gen_config_value)
-            if len(modified_values) > 0:
-                logger.warning_once(
-                    f"`generation_config` default values have been modified to match model-specific defaults: "
-                    f"{modified_values}. If this is not desired, please set these values explicitly."
-                )
+            # If `generation_config` is provided:
+            # - `use_model_defaults`: let's fallback ALL default values to the model's generation config
+            # - otherwise: legacy behavior, let's just make sure we have the tokens defined
+            model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
+            if use_model_defaults is True or (
+                use_model_defaults is None and model_base_version >= version.parse("4.50.0")
+            ):
+                modified_values = {}
+                default_generation_config = GenerationConfig()
+                for key, default_value in default_generation_config.__dict__.items():
+                    if key.startswith("_") or key == "transformers_version":  # metadata
+                        continue
+                    custom_gen_config_value = getattr(generation_config, key)
+                    model_gen_config_value = getattr(self.generation_config, key)
+                    if custom_gen_config_value == default_value and model_gen_config_value != default_value:
+                        modified_values[key] = model_gen_config_value
+                        setattr(generation_config, key, model_gen_config_value)
+                if len(modified_values) > 0:
+                    logger.warning_once(
+                        f"`generation_config` default values have been modified to match model-specific defaults: "
+                        f"{modified_values}. If this is not desired, please set these values explicitly."
+                    )
+            else:
+                if generation_config.bos_token_id is None:
+                    generation_config.bos_token_id = self.generation_config.bos_token_id
+                if generation_config.eos_token_id is None:
+                    generation_config.eos_token_id = self.generation_config.eos_token_id
+                if generation_config.pad_token_id is None:
+                    generation_config.pad_token_id = self.generation_config.pad_token_id
+                if generation_config.decoder_start_token_id is None:
+                    generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id

         # Finally, apply any passed kwargs
         model_kwargs = generation_config.update(**kwargs)

@@ -1967,6 +1983,7 @@ def generate(
         streamer: Optional["BaseStreamer"] = None,
         negative_prompt_ids: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        use_model_defaults: Optional[bool] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""

@@ -2031,6 +2048,11 @@
                 size. This is an experimental feature, subject to breaking API changes in future versions.
             negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Attention_mask for `negative_prompt_ids`.
+            use_model_defaults (`bool`, *optional*):
+                When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
+                generation configuration (`model.generation_config`), as opposed to the global defaults
+                (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
+                `True`.
             kwargs (`Dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder

@@ -2058,7 +2080,9 @@
         tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
         assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

-        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+        generation_config, model_kwargs = self._prepare_generation_config(
+            generation_config, use_model_defaults, **kwargs
+        )
         self._validate_model_kwargs(model_kwargs.copy())
         self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
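A hedged sketch of how the new `use_model_defaults` flag is meant to be used from the caller's side; the checkpoint name is illustrative, and which fields actually fall back depends on the model's saved `generation_config`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Fields left unset here stay at the library-wide GenerationConfig() defaults.
custom_config = GenerationConfig(max_new_tokens=8)

# use_model_defaults=True: unset fields fall back to model.generation_config,
# with a one-time warning listing the modified values.
out = model.generate(**inputs, generation_config=custom_config, use_model_defaults=True)

# use_model_defaults=False: legacy behavior, only the bos/eos/pad/decoder_start
# token ids are inherited from the model's generation config.
out = model.generate(**inputs, generation_config=custom_config, use_model_defaults=False)

# Left as None, the flag is resolved from the checkpoint's transformers_version:
# per the diff, models saved from v4.50 on behave as if it were True.
```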

src/transformers/models/gemma3/modeling_gemma3.py

+1 −7

@@ -1313,9 +1313,6 @@ def forward(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )

-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
         # Merge text and images
         if pixel_values is not None:
             image_features = self.get_image_features(pixel_values)

@@ -1363,7 +1360,7 @@ def forward(
             **lm_kwargs,
         )

-        logits = outputs[0]
+        logits = outputs.logits
         loss = None
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues

@@ -1427,9 +1424,6 @@ def prepare_inputs_for_generation(
             **kwargs,
         )

-        # position_ids in Gemma3 are 1-indexed
-        if model_inputs.get("position_ids") is not None:
-            model_inputs["position_ids"] += 1
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
         if cache_position[0] == 0:
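The `outputs[0]` → `outputs.logits` change matters because integer indexing on a `ModelOutput` walks over the fields that are actually set, so index 0 is not always the logits. A small sketch:

```python
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

out = CausalLMOutputWithPast(loss=torch.tensor(0.5), logits=torch.randn(1, 3, 10))

# Tuple-style indexing skips unset (None) fields; with a loss present,
# index 0 is the loss, not the logits.
assert out[0] is out.loss
# Attribute access names the field explicitly and is robust to which fields are set.
assert out.logits.shape == (1, 3, 10)
```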

src/transformers/models/gemma3/modular_gemma3.py

+215

@@ -33,8 +33,12 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import (
+    add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
     logging,
+    replace_return_docstrings,
 )
+from ...utils.deprecation import deprecate_kwarg
 from ..bart.modeling_bart import BartScaledWordEmbedding
 from ..gemma2.configuration_gemma2 import Gemma2Config
 from ..gemma2.modeling_gemma2 import (
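Among the new imports, `deprecate_kwarg` is applied to `forward` below to rename `num_logits_to_keep` to `logits_to_keep`. A minimal sketch of the caller-visible effect, assuming the decorator's default non-raising behavior:

```python
from transformers.utils.deprecation import deprecate_kwarg

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(logits_to_keep: int = 0):
    return logits_to_keep

# The old keyword still works: it is remapped to the new name, and a
# deprecation warning is emitted instead of a TypeError.
assert forward(num_logits_to_keep=5) == 5
```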
@@ -837,6 +841,217 @@ def _update_causal_mask(

         return causal_mask

+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **lm_kwargs,
+    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
+        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
+
+        >>> prompt = "answer en Where is the cow standing?"
+        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "answer en Where is the cow standing?\nbeach"
+        ```"""
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        is_training = token_type_ids is not None and labels is not None
+
+        # Replace image id with PAD if the image token is OOV, to avoid index-errors
+        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_index
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        # Merge text and images
+        if pixel_values is not None:
+            image_features = self.get_image_features(pixel_values)
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+                )
+            else:
+                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                raise ValueError(
+                    f"Number of images does not match number of special image tokens in the input text. "
+                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+                    "tokens from image embeddings."
+                )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        # mask out pad-token-ids in labels for BC
+        if labels is not None and self.pad_token_id in labels:
+            logger.warning_once(
+                "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
+                "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
+            )
+            labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+        )
+        outputs = self.language_model(
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **lm_kwargs,
+        )
+
+        logits = outputs.logits
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return Gemma3CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_features if pixel_values is not None else None,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        pixel_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        use_cache=True,
+        logits_to_keep=None,
+        labels=None,
+        **kwargs,
+    ):
+        # Overwritten -- custom `position_ids` and `pixel_values` handling
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            cache_position=cache_position,
+            use_cache=use_cache,
+            logits_to_keep=logits_to_keep,
+            token_type_ids=token_type_ids,
+            **kwargs,
+        )
+
+        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
+        if cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+        is_training = token_type_ids is not None and labels is not None
+        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+            causal_mask = self._update_causal_mask(
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+            )
+            model_inputs["attention_mask"] = causal_mask
+
+        return model_inputs
+

 __all__ = [
     "Gemma3Config",
