@@ -255,14 +255,26 @@ def _greedy_decode(
     encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
     next_token_logits = None
     for _ in range(length):
-        model_inputs = model.prepare_inputs_for_generation(
-            decode_ids,
-            encoder_outputs=encoder_outputs,
-            past=None,
-            attention_mask=attention_mask,
-            use_cache=True,
-        )
-        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
+        try:
+            model_inputs = model.prepare_inputs_for_generation(
+                decode_ids,
+                encoder_outputs=encoder_outputs,
+                past=None,
+                attention_mask=attention_mask,
+                use_cache=True,
+            )
+            outputs = model(**model_inputs)
+        except TypeError:
+            # Newer transformers versions no longer accept the deprecated `past` argument.
+            # To keep the pipeline compatible with as many versions as possible, we keep
+            # this fallback path for now. TODO: handle this more elegantly.
+            model_inputs = model.prepare_inputs_for_generation(
+                decode_ids,
+                encoder_outputs=encoder_outputs,
+                attention_mask=attention_mask,
+                use_cache=True,
+            )
+            outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
     next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
     decode_ids = torch.cat(
         [decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1
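If catching a `TypeError` on every decoding step ever becomes a concern, one possible alternative (only a sketch, not part of this change) is to inspect the method's signature once up front and branch on the result. The helper name `accepts_past_kwarg` is hypothetical, and the check assumes the keyword appears explicitly in the signature rather than being swallowed by `**kwargs`:

```python
import inspect


def accepts_past_kwarg(prepare_fn) -> bool:
    """Return True if `prepare_fn` still exposes the legacy `past` keyword."""
    try:
        params = inspect.signature(prepare_fn).parameters
    except (TypeError, ValueError):
        # Some callables do not expose a Python-level signature; assume the newer API.
        return False
    return "past" in params
```

The flag could then be computed once before the decoding loop and used to decide whether to pass `past=None`, instead of triggering and catching the error on every iteration.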