From e202ef86a74efc367b8082318c26fb4342a81d90 Mon Sep 17 00:00:00 2001 From: Matteo-Omenetti Date: Mon, 10 Mar 2025 10:48:37 +0100 Subject: [PATCH 1/2] updated tests to use new versions of code_formula model and figure classiier + changes in the CodeFormulaPredictor class Signed-off-by: Matteo-Omenetti --- .../code_formula_predictor.py | 62 ++++--------------- tests/test_code_formula_predictor.py | 2 +- tests/test_document_figure_classifier.py | 2 +- 3 files changed, 14 insertions(+), 52 deletions(-) diff --git a/docling_ibm_models/code_formula_model/code_formula_predictor.py b/docling_ibm_models/code_formula_model/code_formula_predictor.py index 1768fd8..292034e 100644 --- a/docling_ibm_models/code_formula_model/code_formula_predictor.py +++ b/docling_ibm_models/code_formula_model/code_formula_predictor.py @@ -18,21 +18,19 @@ _log = logging.getLogger(__name__) -class StopOnString(StoppingCriteria): - def __init__(self, tokenizer, stop_string): - self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False) +from transformers import StoppingCriteria - def __call__(self, input_ids, scores, **kwargs): - for sequence in input_ids: - sequence_list = sequence.tolist() - for i in range(len(sequence_list) - len(self.stop_token_ids) + 1): - if ( - sequence_list[i : i + len(self.stop_token_ids)] - == self.stop_token_ids - ): - return True - return False +class RepeatTokenStoppingCriteria(StoppingCriteria): + def __init__(self, repeat_count=100): + self.repeat_count = repeat_count + + def __call__(self, input_ids, scores, **kwargs): + if input_ids.shape[-1] < self.repeat_count: + return False + # Check the last `repeat_count` tokens + last_tokens = input_ids[0, -self.repeat_count :] + return bool((last_tokens == last_tokens[0]).all()) class CodeFormulaPredictor: """ @@ -143,31 +141,6 @@ def _get_prompt(self, label: str) -> str: return prompt - def _strip(self, text: str): - """ - Removes any occurrences of the substrings in remove_list from the end of text. - - Parameters - ---------- - text : str - The original string. - - Returns - ------- - str - The trimmed string. - """ - remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"] - changed = True - while changed: - changed = False - for substr in remove_list: - if text.endswith(substr): - text = text[: -len(substr)] - changed = True - - return text.strip() - @torch.inference_mode() def predict( self, @@ -239,15 +212,7 @@ def predict( prompt_ids = tokenized["input_ids"] attention_mask = tokenized["attention_mask"] - stopping_criteria = StoppingCriteriaList( - [ - StopOnString(self._tokenizer, r" \quad \quad \quad \quad"), - StopOnString(self._tokenizer, r" \\ \\ \\ \\"), - StopOnString(self._tokenizer, r" \, \, \, \,"), - StopOnString(self._tokenizer, r" c c c c c c c c c c c c c c c c"), - StopOnString(self._tokenizer, r" l l l l l l l l l l l l l l l l l"), - ] - ) + stopping_criteria = StoppingCriteriaList([RepeatTokenStoppingCriteria()]) if self._device == "cpu": output_ids_list = self._model.generate( @@ -258,7 +223,6 @@ def predict( temperature=temperature, max_new_tokens=4096 - prompt_ids.shape[1], use_cache=True, - no_repeat_ngram_size=200, stopping_criteria=stopping_criteria, ) else: @@ -270,13 +234,11 @@ def predict( temperature=temperature, max_new_tokens=4096 - prompt_ids.shape[1], use_cache=True, - no_repeat_ngram_size=200, stopping_criteria=stopping_criteria, ) outputs = self._tokenizer.batch_decode( output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True ) - outputs = [self._strip(output) for output in outputs] return outputs diff --git a/tests/test_code_formula_predictor.py b/tests/test_code_formula_predictor.py index ef8633a..e6a52a8 100644 --- a/tests/test_code_formula_predictor.py +++ b/tests/test_code_formula_predictor.py @@ -37,7 +37,7 @@ def init() -> dict: } # Download models from HF - artifact_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.1") + artifact_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.2") init["artifact_path"] = artifact_path diff --git a/tests/test_document_figure_classifier.py b/tests/test_document_figure_classifier.py index 92c44f9..f3918dc 100644 --- a/tests/test_document_figure_classifier.py +++ b/tests/test_document_figure_classifier.py @@ -38,7 +38,7 @@ def init() -> dict: # Download models from HF init["artifact_path"] = snapshot_download( - repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.0" + repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.1" ) return init From 5d3cd3bbf63bcc35891b717f106c0e2ed5f31d0e Mon Sep 17 00:00:00 2001 From: Matteo-Omenetti Date: Mon, 10 Mar 2025 10:54:53 +0100 Subject: [PATCH 2/2] updated tests to use new versions of code_formula model and figure classiier + changes in the CodeFormulaPredictor class Signed-off-by: Matteo-Omenetti --- .../code_formula_model/code_formula_predictor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docling_ibm_models/code_formula_model/code_formula_predictor.py b/docling_ibm_models/code_formula_model/code_formula_predictor.py index 292034e..a7c5b3f 100644 --- a/docling_ibm_models/code_formula_model/code_formula_predictor.py +++ b/docling_ibm_models/code_formula_model/code_formula_predictor.py @@ -32,6 +32,7 @@ def __call__(self, input_ids, scores, **kwargs): last_tokens = input_ids[0, -self.repeat_count :] return bool((last_tokens == last_tokens[0]).all()) + class CodeFormulaPredictor: """ Code and Formula Predictor using a multi-modal vision-language model. @@ -228,7 +229,8 @@ def predict( else: with torch.autocast(device_type=self._device, dtype=torch.bfloat16): output_ids_list = self._model.generate( - prompt_ids, + input_ids=prompt_ids, + attention_mask=attention_mask, images=images_tensor, do_sample=do_sample, temperature=temperature,