@@ -1306,6 +1306,7 @@ def generate(
         eos_token_id=None,
         length_penalty=None,
         no_repeat_ngram_size=None,
+        encoder_no_repeat_ngram_size=None,
         repetition_penalty=None,
         bad_words_ids=None,
         num_return_sequences=None,
@@ -1372,6 +1373,9 @@ def generate(
                 order to encourage the model to produce longer sequences.
             no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
                 If set to int > 0, all ngrams of that size can only occur once.
+            encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+                If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
+                ``decoder_input_ids``.
             bad_words_ids(:obj:`List[int]`, `optional`):
                 List of token ids that are not allowed to be generated. In order to get the tokens of the words that
                 should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
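
For context, a minimal usage sketch of the new argument; the checkpoint and values below are illustrative assumptions, not part of this diff:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer("A long article to summarize ...", return_tensors="pt")

# Forbid any 3-gram that occurs in the encoder input from being copied
# verbatim into the generated sequence (illustrative value).
summary_ids = model.generate(
    inputs["input_ids"],
    num_beams=4,
    no_repeat_ngram_size=3,
    encoder_no_repeat_ngram_size=3,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
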
@@ -1490,6 +1494,8 @@ def extend_enc_output(tensor, num_beams=None):
         pre_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
+            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            encoder_input_ids=context_input_ids,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
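
The processor that `_get_logits_processor` builds from `encoder_no_repeat_ngram_size` and `encoder_input_ids` presumably bans completions of n-grams already seen in the encoder input. A simplified, single-sequence sketch of that banning logic (an illustration of the idea, not the library's implementation):

def banned_next_tokens(encoder_token_ids, decoder_prefix_ids, ngram_size):
    """Token ids that would complete an ``ngram_size``-gram already present in the encoder input."""
    if ngram_size <= 0 or len(decoder_prefix_ids) < ngram_size - 1:
        return set()
    # Index every encoder n-gram by its first n-1 tokens.
    ngrams = {}
    for i in range(len(encoder_token_ids) - ngram_size + 1):
        prefix = tuple(encoder_token_ids[i : i + ngram_size - 1])
        ngrams.setdefault(prefix, set()).add(encoder_token_ids[i + ngram_size - 1])
    # The last n-1 generated tokens determine which completions must be masked.
    current = tuple(decoder_prefix_ids[-(ngram_size - 1):]) if ngram_size > 1 else ()
    return ngrams.get(current, set())

# Example: "11 12 13" occurs in the encoder input, so after generating "... 11 12"
# token 13 is banned; its score would be set to -inf before the next token is chosen.
print(banned_next_tokens([10, 11, 12, 13], [5, 11, 12], 3))  # {13}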