
Commit f11559b

Faster version of Chipper (#252)

Authored by ajjimeno
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Antonio Jimeno Yepes <[email protected]>
Co-authored-by: Sebastian Laverde Alfonso <[email protected]>

1 parent c305d10 · commit f11559b

4 files changed: +85 additions, -14 deletions


CHANGELOG.md
Lines changed: 2 additions & 1 deletion

```diff
@@ -1,5 +1,6 @@
-## 0.7.4-dev1
+## 0.7.4
 
+* Dynamic beam search size has been implemented for Chipper, the decoding process starts with a size = 1 and changes to size = 3 if repetitions appear.
 * Fixed bug when PDFMiner predicts that an image text occupies the full page and removes annotations by Chipper.
 * Added random seed to Chipper text generation to avoid differences between calls to Chipper.
 * Allows user to use super-gradients model if they have a callback predict function, a yaml file with names field corresponding to classes and a path to the model weights
```
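The escalation rule the changelog entry describes is simple: decode greedily first, and only re-decode with a wider beam if the repetition criterion cut generation short. A toy sketch of that control flow, with placeholder `decode_*` callables (note the committed code below escalates to `num_beams=5`, not 3, in the fallback pass):

```python
# Toy sketch of dynamic beam-search escalation; decode_greedy and
# decode_beam are hypothetical callables standing in for generate().
def dynamic_beam_decode(decode_greedy, decode_beam, max_length, eos_token_id):
    tokens = decode_greedy()  # beam size 1, with a repetition stopping criterion
    # If decoding ended before max_length without emitting EOS, the
    # stopping criterion fired on a repetition loop: escalate.
    stopped_by_repetition = len(tokens) < max_length and tokens[-1] != eos_token_id
    if stopped_by_repetition:
        tokens = decode_beam()  # wider beam, sampling disabled
    return tokens
```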

test_unstructured_inference/models/test_chippermodel.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -39,6 +39,8 @@ def generate(*args, **kwargs):
 
 def mock_initialize(self, *arg, **kwargs):
     self.model = MockModel()
+    self.model.encoder = mock.MagicMock()
+    self.stopping_criteria = mock.MagicMock()
     self.processor = mock.MagicMock()
     self.logits_processor = mock.MagicMock()
     self.input_ids = mock.MagicMock()
```
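The two new mocks mirror the production change: `predict_tokens` now calls `self.model.encoder` directly and passes `self.stopping_criteria` to `generate`, so the mocked model must supply both attributes. A hypothetical sketch of how the patched initializer is exercised (the surrounding test file is only partially shown in this diff):

```python
# Hypothetical usage sketch; test name and setup are illustrative.
from unittest import mock

from unstructured_inference.models import chipper


def test_mock_initialize_provides_new_attributes(monkeypatch):
    monkeypatch.setattr(chipper.UnstructuredChipperModel, "initialize", mock_initialize)
    model = chipper.UnstructuredChipperModel()
    model.initialize()
    # predict_tokens now touches both of these, so the mock must supply them.
    assert isinstance(model.model.encoder, mock.MagicMock)
    assert isinstance(model.stopping_criteria, mock.MagicMock)
```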
unstructured_inference/__version__.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-__version__ = "0.7.4-dev1" # pragma: no cover
+__version__ = "0.7.4" # pragma: no cover
```

unstructured_inference/models/chipper.py
Lines changed: 80 additions & 12 deletions
```diff
@@ -10,6 +10,7 @@
 from PIL.Image import Image
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.stopping_criteria import StoppingCriteria
 
 from unstructured_inference.constants import Source
 from unstructured_inference.inference.elements import Rectangle
```
```diff
@@ -75,10 +76,19 @@ def initialize(
         self.source = source
         self.processor = DonutProcessor.from_pretrained(pre_trained_model_repo, token=auth_token)
         self.tokenizer = self.processor.tokenizer
-        self.logits_processor = NoRepeatNGramLogitsProcessor(
-            no_repeat_ngram_size,
-            get_table_token_ids(self.processor),
-        )
+        self.logits_processor = [
+            NoRepeatNGramLogitsProcessor(
+                no_repeat_ngram_size,
+                get_table_token_ids(self.processor),
+            ),
+        ]
+
+        self.stopping_criteria = [
+            NGramRepetitonStoppingCriteria(
+                repetition_window=30,
+                skip_tokens=get_table_token_ids(self.processor),
+            ),
+        ]
 
         self.model = VisionEncoderDecoderModel.from_pretrained(
             pre_trained_model_repo,
```
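Both attributes become lists so that several processors or criteria can be chained and handed to `generate()`, which merges user-supplied entries into its own containers. For reference, a sketch of the same construction using the wrapper types transformers ships for exactly this purpose (an equivalent alternative, not what this commit uses):

```python
# Sketch: equivalent construction with transformers' container types,
# as it would appear inside initialize(). The committed code uses plain
# Python lists, which generate() also accepts.
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

self.logits_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(
            no_repeat_ngram_size,
            get_table_token_ids(self.processor),
        ),
    ],
)
self.stopping_criteria = StoppingCriteriaList(
    [
        NGramRepetitonStoppingCriteria(
            repetition_window=30,
            skip_tokens=get_table_token_ids(self.processor),
        ),
    ],
)
```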
```diff
@@ -137,28 +147,45 @@ def predict_tokens(
         """Predict tokens from image."""
         transformers.set_seed(42)
         with torch.no_grad():
-            outputs = self.model.generate(
+            encoder_outputs = self.model.encoder(
                 self.processor(
                     np.array(
                         image,
                         np.float32,
                     ),
                     return_tensors="pt",
                 ).pixel_values.to(self.device),
-                decoder_input_ids=self.input_ids,
-                logits_processor=[self.logits_processor],
-                max_length=self.max_length,
-                do_sample=True,
-                top_p=0.92,
-                top_k=5,
+            )
+
+            outputs = self.model.generate(
+                encoder_outputs=encoder_outputs,
+                input_ids=self.input_ids,
                 no_repeat_ngram_size=0,
-                num_beams=3,
+                num_beams=1,
                 return_dict_in_generate=True,
                 output_attentions=True,
                 output_scores=True,
                 output_hidden_states=False,
+                stopping_criteria=self.stopping_criteria,
             )
 
+            if (
+                len(outputs["sequences"][0]) < self.max_length
+                and outputs["sequences"][0][-1] != self.processor.tokenizer.eos_token_id
+            ):
+                outputs = self.model.generate(
+                    encoder_outputs=encoder_outputs,
+                    input_ids=self.input_ids,
+                    logits_processor=self.logits_processor,
+                    do_sample=False,
+                    no_repeat_ngram_size=0,
+                    num_beams=5,
+                    return_dict_in_generate=True,
+                    output_attentions=True,
+                    output_scores=True,
+                    output_hidden_states=False,
+                )
+
         if "beam_indices" in outputs:
             offset = len(outputs["beam_indices"][0]) - len(outputs["cross_attentions"])
 
```
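This hunk is the heart of the speedup: the image is pushed through the vision encoder exactly once, and the cached `encoder_outputs` are reused by both decoding passes, so the beam-search fallback pays only for decoding, never for re-encoding the page. A minimal self-contained sketch of the same pattern, using an illustrative public Donut checkpoint rather than the Chipper weights:

```python
# Minimal sketch of encoder-output reuse with a VisionEncoderDecoderModel.
# "naver-clova-ix/donut-base" is an illustrative checkpoint, not Chipper.
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

image = Image.new("RGB", (960, 1280), "white")  # placeholder page image
pixel_values = processor(image, return_tensors="pt").pixel_values
prompt = processor.tokenizer("<s>", add_special_tokens=False, return_tensors="pt").input_ids

with torch.no_grad():
    # The expensive vision pass runs exactly once.
    encoder_outputs = model.encoder(pixel_values)

    # Fast pass: greedy decoding over the cached encoder states.
    fast = model.generate(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=prompt,
        num_beams=1,
        max_length=32,
        return_dict_in_generate=True,
    )

    # A fallback pass reuses the same states, paying only for beam decoding.
    careful = model.generate(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=prompt,
        num_beams=5,
        max_length=32,
        return_dict_in_generate=True,
    )
```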

```diff
@@ -459,6 +486,47 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         )
 
 
+class NGramRepetitonStoppingCriteria(StoppingCriteria):
+    def __init__(self, repetition_window: int, skip_tokens: set = set()):
+        self.repetition_window = repetition_window
+        self.skip_tokens = skip_tokens
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        """
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`]
+                and [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be scores for each
+                vocabulary token before SoftMax or scores for each vocabulary token after SoftMax.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional stopping criteria specific kwargs.
+
+        Return:
+            `bool`. `False` indicates we should continue, `True` indicates we should stop.
+
+        """
+        num_batch_hypotheses = input_ids.shape[0]
+        cur_len = input_ids.shape[-1]
+
+        for banned_tokens in _calc_banned_tokens(
+            input_ids,
+            num_batch_hypotheses,
+            self.repetition_window,
+            cur_len,
+        ):
+            for token in banned_tokens:
+                if token not in self.skip_tokens:
+                    return True
+
+        return False
+
+
 def _no_repeat_ngram_logits(
     input_ids: torch.LongTensor,
     cur_len: int,
```
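The criterion halts generation as soon as the trailing window contains a repeated n-gram whose tokens are not in `skip_tokens` (here the table markup token ids, which legitimately repeat). `_calc_banned_tokens` is the module's existing helper, shared with `NoRepeatNGramLogitsProcessor` and not shown in this diff. A tiny self-contained illustration of the same idea, using a simplified repetition check rather than the real helper:

```python
# Simplified illustration of n-gram repetition stopping (toy check,
# not the module's _calc_banned_tokens helper).
import torch


def repeats_ngram(input_ids: torch.LongTensor, n: int = 3, window: int = 30) -> bool:
    """Return True if the last n tokens already occurred in the trailing window."""
    seq = input_ids[0, -window:].tolist()
    if len(seq) < 2 * n:
        return False
    tail = tuple(seq[-n:])
    earlier = [tuple(seq[i : i + n]) for i in range(len(seq) - n)]
    return tail in earlier


# A degenerate sequence that keeps cycling "7 8 9" trips the check...
assert repeats_ngram(torch.tensor([[1, 2, 3, 7, 8, 9, 7, 8, 9]]))
# ...while a non-repeating sequence does not.
assert not repeats_ngram(torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]]))
```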
