
Commit 727ab9d

patil-suraj authored and patrickvonplaten committed
[RAG] fix generate (#10094)
* fix rag generate and tests
* put back adjust_logits_during_generation
* tests are okay

Co-authored-by: Patrick von Platen <[email protected]>
1 parent c95fae6 commit 727ab9d

File tree

1 file changed: +6 −0 lines


src/transformers/models/rag/modeling_rag.py

```diff
@@ -1306,6 +1306,7 @@ def generate(
         eos_token_id=None,
         length_penalty=None,
         no_repeat_ngram_size=None,
+        encoder_no_repeat_ngram_size=None,
         repetition_penalty=None,
         bad_words_ids=None,
         num_return_sequences=None,
@@ -1372,6 +1373,9 @@ def generate(
                 order to encourage the model to produce longer sequences.
             no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
                 If set to int > 0, all ngrams of that size can only occur once.
+            encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+                If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
+                ``decoder_input_ids``.
             bad_words_ids(:obj:`List[int]`, `optional`):
                 List of token ids that are not allowed to be generated. In order to get the tokens of the words that
                 should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
@@ -1490,6 +1494,8 @@ def extend_enc_output(tensor, num_beams=None):
         pre_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
+            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            encoder_input_ids=context_input_ids,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
```
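
The new argument keeps the decoder from copying n-grams verbatim out of the retrieved passages. As a rough illustration of the banning rule the docstring describes, here is a single-sequence sketch (not the library's `EncoderNoRepeatNGramLogitsProcessor` implementation; the helper name is made up for clarity, and `ngram_size >= 2` is assumed):

```python
import torch

def ban_encoder_ngrams(encoder_input_ids, decoder_input_ids, scores, ngram_size):
    """Set the score of any token that would complete an n-gram already
    present in the encoder sequence to -inf. One sequence, ngram_size >= 2."""
    enc = encoder_input_ids.tolist()
    banned = {}  # maps an (n-1)-token prefix -> set of banned next tokens
    for i in range(len(enc) - ngram_size + 1):
        prefix = tuple(enc[i : i + ngram_size - 1])
        banned.setdefault(prefix, set()).add(enc[i + ngram_size - 1])
    # the last n-1 generated tokens form the prefix of the next n-gram
    prefix = tuple(decoder_input_ids.tolist()[-(ngram_size - 1) :])
    for token in banned.get(prefix, ()):
        scores[token] = float("-inf")
    return scores

scores = ban_encoder_ngrams(
    torch.tensor([5, 6, 7]), torch.tensor([1, 6]), torch.zeros(10), ngram_size=2
)
# scores[7] == -inf: emitting 7 after 6 would repeat the encoder bigram (6, 7)
```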

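End to end, the flag is simply forwarded through ``generate``; something like the following should exercise the new code path (a minimal sketch assuming the public ``facebook/rag-token-nq`` checkpoint with the dummy retrieval index, which avoids downloading the full wiki index):

```python
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
# forbid any 3-gram from the retrieved passages from reappearing in the answer
generated = model.generate(input_ids=inputs["input_ids"], encoder_no_repeat_ngram_size=3)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```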