Skip to content

Commit d80f53f

Browse files
author
Jonas
authored
[generate] Fix vocab_size access for multimodal models (#37937)
Implements the last migrations for generation from `config.vocab_size` to `config.get_text_config().vocab_size`. In doing so, we enable multimodal models to fully leverage all existing generation features.
1 parent 7819911 commit d80f53f

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/transformers/generation/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ def _get_candidate_generator(
968968
atm_translator = AssistantVocabTranslatorCache.get_translator(
969969
target_tokenizer,
970970
assistant_tokenizer,
971-
self.config.vocab_size,
971+
self.config.get_text_config().vocab_size,
972972
assistant_model=assistant_model,
973973
assistant_prune_lm_head=True, # prune LM head of assistant model
974974
)
@@ -1234,7 +1234,9 @@ def _get_logits_processor(
12341234
# Watermarking should be after all logits processing is finished (see #34630)
12351235
if generation_config.watermarking_config is not None:
12361236
processors.append(
1237-
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
1237+
generation_config.watermarking_config.construct_processor(
1238+
self.config.get_text_config().vocab_size, device
1239+
)
12381240
)
12391241

12401242
# `LogitNormalization` should always be the last logit processor, when present
@@ -1412,7 +1414,7 @@ def compute_transition_scores(
14121414

14131415
# 3. Optionally normalize the logits (across the vocab dimension)
14141416
if normalize_logits:
1415-
scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1])
1417+
scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1])
14161418
scores = torch.nn.functional.log_softmax(scores, dim=1)
14171419
scores = scores.reshape(-1, scores.shape[-1])
14181420

@@ -1426,7 +1428,7 @@ def compute_transition_scores(
14261428
beam_indices[beam_indices_mask] = 0
14271429

14281430
# 6. multiply beam_indices with vocab size to gather correctly from scores
1429-
beam_sequence_indices = beam_indices * self.config.vocab_size
1431+
beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size
14301432

14311433
# 7. Define which indices contributed to scores
14321434
cut_idx = sequences.shape[-1] - max_beam_length

0 commit comments

Comments
 (0)