Commit afc4ece
[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head
* fix gpt2 double head model
* fix bart / t5
* also add for no beam search
* fix no beam search
* fix encoder decoder
* simplify t5
* simplify t5
* fix t5 tests
* fix BART
* fix transfo-xl
* fix conflict
* integrating sylvains and sams comments
* fix tf past_decoder_key_values
* fix enc dec test
1 parent 397f819 commit afc4ece

20 files changed (+394 −260 lines)

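The core of this change is that `generate()` now reads model outputs by name (via `ModelOutput`) instead of by tuple position. A minimal sketch of the new access pattern, assuming a transformers version with `return_dict=True` support; the GPT-2 checkpoint and prompt are purely illustrative:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# checkpoint and prompt are illustrative, not part of this commit
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model(**inputs, use_cache=True, return_dict=True)

# ModelOutput behaves like a mapping *and* exposes attributes,
# so positional indexing such as outputs[0] / outputs[1] is no longer needed.
assert "logits" in outputs                    # membership test, as used inside generate()
next_token_logits = outputs.logits[:, -1, :]  # replaces outputs[0][:, -1, :]
past = outputs.past_key_values                # replaces outputs[1]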
docs/source/model_doc/encoderdecoder.rst

+5 −4
@@ -1,12 +1,13 @@
 Encoder Decoder Models
 ------------------------
 
-This class can wrap an encoder model, such as ``BertModel`` and a decoder modeling with a language modeling head, such as ``BertForMaskedLM`` into a encoder-decoder model.
+The :class:`~transformers.EncoderDecoderModel` can be used to initialize a sequence-to-sequence model with any pre-trained autoencoding model as the encoder and any pre-trained autoregressive model as the decoder.
 
-The ``EncoderDecoderModel`` class allows to instantiate a encoder decoder model using the ``from_encoder_decoder_pretrain`` class method taking a pretrained encoder and pretrained decoder model as an input.
-The ``EncoderDecoderModel`` is saved using the standard ``save_pretrained()`` method and can also again be loaded using the standard ``from_pretrained()`` method.
+The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
 
-An application of this architecture could be *summarization* using two pretrained Bert models as is shown in the paper: `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1910.13461>`_ by Yang Liu and Mirella Lapata.
+After such an :class:`~transformers.EncoderDecoderModel` has been trained / fine-tuned, it can be saved / loaded just like any other models (see Examples for more information).
+
+An application of this architecture could be to leverage two pre-trained :obj:`transformers.BertModel` models as the encoder and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1910.13461>`_ by Yang Liu and Mirella Lapata.
 
 
 ``EncoderDecoderConfig``

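The workflow described in the updated documentation can be sketched as follows; this is a hedged example, and the checkpoint names and the ./bert2bert directory are illustrative rather than part of this commit:

from transformers import EncoderDecoderModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# initialize a seq2seq model from two pre-trained BERT checkpoints
# (encoder = autoencoding model, decoder = autoregressive model)
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

# after training / fine-tuning, saving and loading works like for any other model
model.save_pretrained("./bert2bert")  # illustrative output directory
model = EncoderDecoderModel.from_pretrained("./bert2bert")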
src/transformers/generation_utils.py

+34 −36
@@ -20,6 +20,7 @@
 from torch import Tensor
 from torch.nn import functional as F
 
+from .file_utils import ModelOutput
 from .utils import logging
 
 
@@ -46,14 +47,6 @@ def adjust_logits_during_generation(self, logits, **kwargs):
         """
         return logits
 
-    def _use_cache(self, outputs, use_cache):
-        """During generation, decide whether to pass the `past` variable to the next forward pass."""
-        if len(outputs) <= 1 or use_cache is False:
-            return False
-        if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
-            return False
-        return True
-
     def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
         """
         Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
@@ -137,7 +130,7 @@ def generate(
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
-        **model_specific_kwargs
+        **model_kwargs
     ) -> torch.LongTensor:
         r"""
         Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
@@ -208,7 +201,7 @@ def generate(
             use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                 speed up decoding.
-            model_specific_kwargs:
+            model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
 
         Return:
@@ -400,7 +393,7 @@ def generate(
 
             # get encoder and store encoder outputs
             encoder = self.get_encoder()
-            encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
+            encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
 
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
@@ -428,8 +421,8 @@ def generate(
             cur_len = 1
 
             assert (
-                batch_size == encoder_outputs[0].shape[0]
-            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
+                batch_size == encoder_outputs.last_hidden_state.shape[0]
+            ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
 
             # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
             expanded_batch_idxs = (
@@ -439,11 +432,16 @@ def generate(
                 .view(-1)
                 .to(input_ids.device)
             )
+
             # expand encoder_outputs
-            encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
+            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
+                0, expanded_batch_idxs
+            )
+
+            # save encoder_outputs in `model_kwargs`
+            model_kwargs["encoder_outputs"] = encoder_outputs
 
         else:
-            encoder_outputs = None
             cur_len = input_ids.shape[-1]
 
         assert (
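To illustrate what the expansion in this hunk does, here is a small self-contained sketch; the sizes are made up for illustration. Each batch entry of last_hidden_state is repeated so every beam / return sequence sees the matching encoder output:

import torch

batch_size, num_beams, src_len, hidden = 2, 3, 5, 8           # illustrative sizes
last_hidden_state = torch.randn(batch_size, src_len, hidden)

# batch index 0 repeated num_beams times, then batch index 1, ...: [0, 0, 0, 1, 1, 1]
expanded_batch_idxs = torch.arange(batch_size).view(-1, 1).repeat(1, num_beams).view(-1)

expanded = last_hidden_state.index_select(0, expanded_batch_idxs)
assert expanded.shape == (batch_size * num_beams, src_len, hidden)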
@@ -471,10 +469,9 @@ def generate(
                 length_penalty=length_penalty,
                 num_beams=num_beams,
                 vocab_size=vocab_size,
-                encoder_outputs=encoder_outputs,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
-                model_specific_kwargs=model_specific_kwargs,
+                model_kwargs=model_kwargs,
             )
         else:
             output = self._generate_no_beam_search(
@@ -492,10 +489,9 @@ def generate(
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 batch_size=effective_batch_size,
-                encoder_outputs=encoder_outputs,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
-                model_specific_kwargs=model_specific_kwargs,
+                model_kwargs=model_kwargs,
             )
 
         return output
@@ -516,10 +512,9 @@ def _generate_no_beam_search(
         pad_token_id,
         eos_token_id,
         batch_size,
-        encoder_outputs,
         attention_mask,
         use_cache,
-        model_specific_kwargs,
+        model_kwargs,
     ):
         """Generate sequences for each example without beam search (num_beams == 1).
         All returned sequence are generated independantly.
@@ -528,15 +523,14 @@ def _generate_no_beam_search(
         unfinished_sents = input_ids.new(batch_size).fill_(1)
         sent_lengths = input_ids.new(batch_size).fill_(max_length)
 
-        past = (encoder_outputs, None) if encoder_outputs is not None else None
-
+        past = None
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
             )
 
-            outputs = self(**model_inputs)
-            next_token_logits = outputs[0][:, -1, :]
+            outputs = self(**model_inputs, return_dict=True)
+            next_token_logits = outputs.logits[:, -1, :]
 
             scores = self.postprocess_next_token_scores(
                 scores=next_token_logits,
@@ -553,8 +547,10 @@ def _generate_no_beam_search(
             )
 
             # if model has past, then set the past variable to speed up decoding
-            if self._use_cache(outputs, use_cache):
-                past = outputs[1]
+            if "past_key_values" in outputs:
+                past = outputs.past_key_values
+            elif "mems" in outputs:
+                past = outputs.mems
 
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
@@ -621,10 +617,9 @@ def _generate_beam_search(
         length_penalty,
         num_beams,
         vocab_size,
-        encoder_outputs,
         attention_mask,
         use_cache,
-        model_specific_kwargs,
+        model_kwargs,
     ):
         """Generate sequences for each example with beam search."""
 
@@ -643,21 +638,24 @@ def _generate_beam_search(
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
 
         # cache compute states
-        past = (encoder_outputs, None) if encoder_outputs is not None else None
+        past = None
 
         # done sentences
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
             )
-            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
-            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+            outputs = self(**model_inputs, return_dict=True)  # (batch_size * num_beams, cur_len, vocab_size)
+            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
 
             # if model has past, then set the past variable to speed up decoding
-            if self._use_cache(outputs, use_cache):
-                past = outputs[1]
+            if "past_key_values" in outputs:
+                past = outputs.past_key_values
+            elif "mems" in outputs:
+                past = outputs.mems
+
             if self.config.is_encoder_decoder and do_sample is False:
                 # TODO (PVP) still a bit hacky here - there might be a better solution
                 next_token_logits = self.adjust_logits_during_generation(

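From the caller's side nothing changes: generate() keeps its public API while internally building a ModelOutput-typed encoder_outputs and threading extra arguments through model_kwargs. A hedged end-to-end sketch of the beam search path; the BART checkpoint, article text, and generation settings are illustrative, not part of this commit:

from transformers import BartForConditionalGeneration, BartTokenizer

# checkpoint and article are illustrative
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

article = "The tower is 324 metres tall, about the same height as an 81-storey building."
inputs = tokenizer([article], return_tensors="pt")

# beam search path: generate() now computes encoder_outputs with return_dict=True
# and stores it in model_kwargs["encoder_outputs"] before the decoding loop
summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=4,
    max_length=40,
    early_stopping=True,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))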