
Commit 69f09b7

fix: t5 transformers compatibility
1 parent fc8f594 commit 69f09b7

File tree

1 file changed (+20, -8 lines)

rerankers/models/t5ranker.py

Lines changed: 20 additions & 8 deletions
@@ -255,14 +255,26 @@ def _greedy_decode(
     encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
     next_token_logits = None
     for _ in range(length):
-        model_inputs = model.prepare_inputs_for_generation(
-            decode_ids,
-            encoder_outputs=encoder_outputs,
-            past=None,
-            attention_mask=attention_mask,
-            use_cache=True,
-        )
-        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
+        try:
+            model_inputs = model.prepare_inputs_for_generation(
+                decode_ids,
+                encoder_outputs=encoder_outputs,
+                past=None,
+                attention_mask=attention_mask,
+                use_cache=True,
+            )
+            outputs = model(**model_inputs)
+        except TypeError:
+            # Newer transformers versions have deprecated `past`.
+            # Our aim is to maintain pipeline compatibility for as many people as possible,
+            # so currently we maintain a forking path with this error. Might need to do it more elegantly later on (TODO).
+            model_inputs = model.prepare_inputs_for_generation(
+                decode_ids,
+                encoder_outputs=encoder_outputs,
+                attention_mask=attention_mask,
+                use_cache=True,
+            )
+            outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
         next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
         decode_ids = torch.cat(
             [decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1
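
For context, the change amounts to a small compatibility shim around `prepare_inputs_for_generation`: older transformers releases expect a `past` keyword argument, newer releases have removed it and raise `TypeError`, so the call is simply retried without it. Below is a minimal sketch of the same pattern pulled out into a standalone helper; the helper name `_prepare_generation_inputs` is hypothetical, since the commit inlines this logic directly inside `_greedy_decode`.

# Minimal sketch of the compatibility shim, assuming the only difference
# between transformers versions is the removed `past` keyword.
# The helper name is hypothetical; the commit inlines this in `_greedy_decode`.
def _prepare_generation_inputs(model, decode_ids, encoder_outputs, attention_mask):
    try:
        # Older transformers releases accept the (now deprecated) `past` argument.
        return model.prepare_inputs_for_generation(
            decode_ids,
            encoder_outputs=encoder_outputs,
            past=None,
            attention_mask=attention_mask,
            use_cache=True,
        )
    except TypeError:
        # Newer releases removed `past`; retry without it.
        return model.prepare_inputs_for_generation(
            decode_ids,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            use_cache=True,
        )

# Usage inside the greedy decoding loop would then reduce to:
#     model_inputs = _prepare_generation_inputs(model, decode_ids, encoder_outputs, attention_mask)
#     outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)

The try/except keeps a single code path working across transformers versions without pinning a dependency; an alternative would be to branch on `transformers.__version__`, at the cost of tracking exactly which release changed the signature.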
