diff --git a/docling/cli/main.py b/docling/cli/main.py
index 82c57efb4..67afe69a6 100644
--- a/docling/cli/main.py
+++ b/docling/cli/main.py
@@ -63,6 +63,8 @@
     GOT2_TRANSFORMERS,
     GRANITE_VISION_OLLAMA,
     GRANITE_VISION_TRANSFORMERS,
+    GRANITEDOCLING_MLX,
+    GRANITEDOCLING_TRANSFORMERS,
     SMOLDOCLING_MLX,
     SMOLDOCLING_TRANSFORMERS,
     SMOLDOCLING_VLLM,
@@ -655,6 +657,18 @@ def convert(  # noqa: C901
                         "To run SmolDocling faster, please install mlx-vlm:\n"
                         "pip install mlx-vlm"
                     )
+        elif vlm_model == VlmModelType.GRANITEDOCLING:
+            pipeline_options.vlm_options = GRANITEDOCLING_TRANSFORMERS
+            if sys.platform == "darwin":
+                try:
+                    import mlx_vlm
+
+                    pipeline_options.vlm_options = GRANITEDOCLING_MLX
+                except ImportError:
+                    _log.warning(
+                        "To run Granite-Docling faster, please install mlx-vlm:\n"
+                        "pip install mlx-vlm"
+                    )
         elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
             pipeline_options.vlm_options = SMOLDOCLING_VLLM

diff --git a/docling/datamodel/vlm_model_specs.py b/docling/datamodel/vlm_model_specs.py
index 54d819780..dd796fba0 100644
--- a/docling/datamodel/vlm_model_specs.py
+++ b/docling/datamodel/vlm_model_specs.py
@@ -18,6 +18,34 @@

 _log = logging.getLogger(__name__)

+# Granite-Docling
+GRANITEDOCLING_TRANSFORMERS = InlineVlmOptions(
+    repo_id="ds4sd/granite-docling-258m-2-9-2025-v2",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+    max_new_tokens=8192,
+    stop_strings=["</doctag>", "<|end_of_text|>"],
+)
+
+GRANITEDOCLING_MLX = InlineVlmOptions(
+    repo_id="ds4sd/granite-docling-258m-2-9-2025-v2-mlx-bf16",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.MLX,
+    supported_devices=[AcceleratorDevice.MPS],
+    scale=2.0,
+    temperature=0.0,
+    max_new_tokens=8192,
+    stop_strings=["</doctag>", "<|end_of_text|>"],
+)
+
 # SmolDocling
 SMOLDOCLING_MLX = InlineVlmOptions(
     repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
@@ -272,3 +300,4 @@ class VlmModelType(str, Enum):
     GRANITE_VISION_VLLM = "granite_vision_vllm"
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
     GOT_OCR_2 = "got_ocr_2"
+    GRANITEDOCLING = "granite_docling"
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 52e786f2a..1f610b6b9 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import threading
 import time
 from collections.abc import Iterable
@@ -6,6 +7,7 @@
 from typing import Optional, Union

 import numpy as np
+from docling_core.types.doc import BoundingBox, CoordOrigin
 from PIL.Image import Image

 from docling.datamodel.accelerator_options import (
@@ -27,6 +29,37 @@

 _MLX_GLOBAL_LOCK = threading.Lock()

+
+class DoclingStopping:
+    """Flags when a newly generated DocTags bounding box overlaps one emitted
+    earlier in the same generation (a common degeneration pattern)."""
+
+    def __init__(self):
+        # Match an element tag followed by its four <loc_*> values at the end
+        # of the generated text, e.g. `<text><loc_10><loc_20><loc_30><loc_40>`;
+        # the optional trailing `<` tolerates the start of the next token.
+        self.pattern = re.compile(
+            r"<([a-z_\-]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>(<)?$"
+        )
+
+        self.bboxs: list[BoundingBox] = []
+
+    def overlaps(self, text: str) -> bool:
+        match = re.search(self.pattern, text)
+        if match:
+            tag_name = match.group(1)  # element name, e.g. "text"
+            loc1 = float(match.group(2))  # first location value
+            loc2 = float(match.group(3))  # second location value
+            loc3 = float(match.group(4))  # third location value
+            loc4 = float(match.group(5))  # fourth location value
+
+            bbox = BoundingBox(
+                l=loc1, b=loc2, r=loc3, t=loc4, coord_origin=CoordOrigin.BOTTOMLEFT
+            )
+
+            # Any non-trivial intersection with a previous bbox stops generation.
+            for seen in self.bboxs:
+                if bbox.intersection_over_self(seen) > 1.0e-6:
+                    _log.info(f"<{tag_name}> bbox {bbox} overlaps with {seen}")
+                    return True
+
+            self.bboxs.append(bbox)
+
+        return False
+
+
 class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
@@ -68,6 +101,26 @@ def __init__(
                 self.vlm_model, self.processor = load(artifacts_path)
                 self.config = load_config(artifacts_path)

+        self._find_doctags_labels()
+
+    def _find_doctags_labels(self):
+        """Collect DocTags special tokens by iterating over the vocabulary."""
+        tokenizer = (
+            self.processor.tokenizer
+            if hasattr(self.processor, "tokenizer")
+            else self.processor
+        )
+
+        self.special_tokens: dict[str, int] = {}
+        if hasattr(tokenizer, "vocab"):
+            # vocab is usually a dict mapping token_text -> token_id
+            for token_text, token_id in tokenizer.vocab.items():
+                if re.match(r"^<[a-z_\-\d]+>$", token_text):
+                    _log.debug(f"Token ID: {token_id:6d} | Text: '{token_text}'")
+                    self.special_tokens[token_text] = token_id
+        else:
+            _log.warning("Tokenizer doesn't have a 'vocab' attribute")
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -199,6 +252,8 @@ def process_images(
                 tokens: list[VlmPredictionToken] = []
                 output = ""

+                stopping_criteria = DoclingStopping()
+
                 # Use stream_generate for proper stop string handling
                 for token in self.stream_generate(
                     self.vlm_model,
@@ -209,6 +264,10 @@
                     verbose=False,
                     temp=self.temperature,
                 ):
+                    _log.debug(
+                        f"logprobs.shape: {token.logprobs.shape} with token: {token}"
+                    )
+
                     # Collect token information
                     if len(token.logprobs.shape) == 1:
                         tokens.append(
@@ -218,6 +277,26 @@
                                 logprob=token.logprobs[token.token],
                             )
                         )
+                        if token.text in self.special_tokens:
+                            # Get logprobs for all special tokens
+                            special_token_logprobs = []
+                            for token_text, token_id in self.special_tokens.items():
+                                logprob = token.logprobs[token_id]
+                                special_token_logprobs.append(
+                                    (token_text, token_id, logprob)
+                                )
+
+                            # Sort by logprob (highest first) and take the top 5
+                            top_5_special = sorted(
+                                special_token_logprobs, key=lambda x: x[2], reverse=True
+                            )[:5]
+
+                            _log.debug("Top 5 special tokens by logprob:")
+                            for rank, (t, token_id, logprob) in enumerate(
+                                top_5_special, 1
+                            ):
+                                _log.debug(f"  {rank}. {t}: {logprob:0.3f}")
+
                     elif (
                         len(token.logprobs.shape) == 2
                         and token.logprobs.shape[0] == 1
                     ):
@@ -228,6 +307,11 @@
                                 logprob=token.logprobs[0, token.token],
                             )
                         )
+
+                        if token.text in self.special_tokens:
+                            for t, i in self.special_tokens.items():
+                                _log.debug(f"{t}: {token.logprobs[0, i]:0.3f}")
+
                     else:
                         _log.warning(
                             f"incompatible shape for logprobs: {token.logprobs.shape}"
@@ -235,6 +319,10 @@
                         )

                     output += token.text
+
+                    if stopping_criteria.overlaps(output):
+                        _log.debug("Stopping generation due to overlapping bbox")
+                        break

                     # Check for any configured stop strings
                     if self.vlm_options.stop_strings:
                         if any(
@@ -246,7 +334,7 @@

                 generation_time = time.time() - start_time

-                _log.debug(
+                _log.info(
                     f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
                 )
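
Note for reviewers: below is a minimal usage sketch of how the new enum value and model specs are expected to be exercised end to end, following the existing VlmPipeline pattern used for SmolDocling. It is not part of this patch; the sample input path is hypothetical, and on macOS with mlx-vlm installed the CLI path above swaps in `GRANITEDOCLING_MLX` automatically.

```python
from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Select the new Granite-Docling spec explicitly (the CLI does this for
# --vlm-model granite_docling, preferring the MLX variant on Apple Silicon).
pipeline_options = VlmPipelineOptions(
    vlm_options=vlm_model_specs.GRANITEDOCLING_TRANSFORMERS,
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("sample.pdf")  # hypothetical input file
print(result.document.export_to_markdown())
```

The equivalent CLI invocation would be `docling --pipeline vlm --vlm-model granite_docling sample.pdf`.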