|
10 | 10 | from PIL.Image import Image |
11 | 11 | from transformers import DonutProcessor, VisionEncoderDecoderModel |
12 | 12 | from transformers.generation.logits_process import LogitsProcessor |
| 13 | +from transformers.generation.stopping_criteria import StoppingCriteria |
13 | 14 |
|
14 | 15 | from unstructured_inference.constants import Source |
15 | 16 | from unstructured_inference.inference.elements import Rectangle |
@@ -75,10 +76,19 @@ def initialize( |
75 | 76 | self.source = source |
76 | 77 | self.processor = DonutProcessor.from_pretrained(pre_trained_model_repo, token=auth_token) |
77 | 78 | self.tokenizer = self.processor.tokenizer |
78 | | - self.logits_processor = NoRepeatNGramLogitsProcessor( |
79 | | - no_repeat_ngram_size, |
80 | | - get_table_token_ids(self.processor), |
81 | | - ) |
| 79 | + self.logits_processor = [ |
| 80 | + NoRepeatNGramLogitsProcessor( |
| 81 | + no_repeat_ngram_size, |
| 82 | + get_table_token_ids(self.processor), |
| 83 | + ), |
| 84 | + ] |
| 85 | + |
| 86 | + self.stopping_criteria = [ |
| 87 | + NGramRepetitonStoppingCriteria( |
| 88 | + repetition_window=30, |
| 89 | + skip_tokens=get_table_token_ids(self.processor), |
| 90 | + ), |
| 91 | + ] |
82 | 92 |
|
83 | 93 | self.model = VisionEncoderDecoderModel.from_pretrained( |
84 | 94 | pre_trained_model_repo, |
@@ -137,28 +147,45 @@ def predict_tokens( |
137 | 147 | """Predict tokens from image.""" |
138 | 148 | transformers.set_seed(42) |
139 | 149 | with torch.no_grad(): |
140 | | - outputs = self.model.generate( |
| 150 | + encoder_outputs = self.model.encoder( |
141 | 151 | self.processor( |
142 | 152 | np.array( |
143 | 153 | image, |
144 | 154 | np.float32, |
145 | 155 | ), |
146 | 156 | return_tensors="pt", |
147 | 157 | ).pixel_values.to(self.device), |
148 | | - decoder_input_ids=self.input_ids, |
149 | | - logits_processor=[self.logits_processor], |
150 | | - max_length=self.max_length, |
151 | | - do_sample=True, |
152 | | - top_p=0.92, |
153 | | - top_k=5, |
| 158 | + ) |
| 159 | + |
| 160 | + outputs = self.model.generate( |
| 161 | + encoder_outputs=encoder_outputs, |
| 162 | + input_ids=self.input_ids, |
154 | 163 | no_repeat_ngram_size=0, |
155 | | - num_beams=3, |
| 164 | + num_beams=1, |
156 | 165 | return_dict_in_generate=True, |
157 | 166 | output_attentions=True, |
158 | 167 | output_scores=True, |
159 | 168 | output_hidden_states=False, |
| 169 | + stopping_criteria=self.stopping_criteria, |
160 | 170 | ) |
161 | 171 |
|
| 172 | + if ( |
| 173 | + len(outputs["sequences"][0]) < self.max_length |
| 174 | + and outputs["sequences"][0][-1] != self.processor.tokenizer.eos_token_id |
| 175 | + ): |
| 176 | + outputs = self.model.generate( |
| 177 | + encoder_outputs=encoder_outputs, |
| 178 | + input_ids=self.input_ids, |
| 179 | + logits_processor=self.logits_processor, |
| 180 | + do_sample=False, |
| 181 | + no_repeat_ngram_size=0, |
| 182 | + num_beams=5, |
| 183 | + return_dict_in_generate=True, |
| 184 | + output_attentions=True, |
| 185 | + output_scores=True, |
| 186 | + output_hidden_states=False, |
| 187 | + ) |
| 188 | + |
162 | 189 | if "beam_indices" in outputs: |
163 | 190 | offset = len(outputs["beam_indices"][0]) - len(outputs["cross_attentions"]) |
164 | 191 |
|
@@ -459,6 +486,47 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to |
459 | 486 | ) |
460 | 487 |
|
461 | 488 |
|
class NGramRepetitonStoppingCriteria(StoppingCriteria):
    # NOTE: class name keeps the original "Repetiton" spelling because callers
    # elsewhere in this file instantiate it by that name; renaming would break them.
    def __init__(self, repetition_window: int, skip_tokens: frozenset = frozenset()):
        """Stopping criteria that halts generation on repeated n-grams.

        Args:
            repetition_window (`int`):
                Size (in tokens) of the n-gram window checked for repetition.
            skip_tokens (`frozenset`, *optional*):
                Token ids that are allowed to repeat without triggering a stop
                (e.g. structural table markup tokens).
        """
        # FIX: default was the mutable `set()` — a shared mutable default
        # argument. `frozenset()` is immutable and behaves identically for the
        # membership tests below; callers may still pass any set-like object.
        self.repetition_window = repetition_window
        self.skip_tokens = skip_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`]
                and [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be scores for each
                vocabulary token before SoftMax or scores for each vocabulary token after SoftMax.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional stopping criteria specific kwargs.

        Return:
            `bool`. `False` indicates we should continue, `True` indicates we should stop.

        """
        num_batch_hypotheses = input_ids.shape[0]
        cur_len = input_ids.shape[-1]

        # Stop as soon as any hypothesis would emit a token that repeats an
        # n-gram inside the window, unless that token is explicitly skippable.
        for banned_tokens in _calc_banned_tokens(
            input_ids,
            num_batch_hypotheses,
            self.repetition_window,
            cur_len,
        ):
            for token in banned_tokens:
                if token not in self.skip_tokens:
                    return True

        return False
| 528 | + |
| 529 | + |
462 | 530 | def _no_repeat_ngram_logits( |
463 | 531 | input_ids: torch.LongTensor, |
464 | 532 | cur_len: int, |
|
0 commit comments