14 changes: 14 additions & 0 deletions docling/cli/main.py
@@ -63,6 +63,8 @@
    GOT2_TRANSFORMERS,
    GRANITE_VISION_OLLAMA,
    GRANITE_VISION_TRANSFORMERS,
    GRANITEDOCLING_MLX,
    GRANITEDOCLING_TRANSFORMERS,
    SMOLDOCLING_MLX,
    SMOLDOCLING_TRANSFORMERS,
    SMOLDOCLING_VLLM,
@@ -655,6 +657,18 @@ def convert(  # noqa: C901
"To run SmolDocling faster, please install mlx-vlm:\n"
"pip install mlx-vlm"
)
elif vlm_model == VlmModelType.GRANITEDOCLING:
pipeline_options.vlm_options = GRANITEDOCLING_TRANSFORMERS
if sys.platform == "darwin":
try:
import mlx_vlm

pipeline_options.vlm_options = GRANITEDOCLING_MLX
except ImportError:
_log.warning(
"To run SmolDocling faster, please install mlx-vlm:\n"
"pip install mlx-vlm"
)
elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
pipeline_options.vlm_options = SMOLDOCLING_VLLM

29 changes: 29 additions & 0 deletions docling/datamodel/vlm_model_specs.py
@@ -18,6 +18,34 @@
_log = logging.getLogger(__name__)


# Granite-Docling
GRANITEDOCLING_TRANSFORMERS = InlineVlmOptions(
    repo_id="ds4sd/granite-docling-258m-2-9-2025-v2",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.TRANSFORMERS,
    supported_devices=[
        AcceleratorDevice.CPU,
        AcceleratorDevice.CUDA,
    ],
    scale=2.0,
    temperature=0.0,
    max_new_tokens=8192,
    stop_strings=["</doctag>", "<|end_of_text|>"],
)

GRANITEDOCLING_MLX = InlineVlmOptions(
    repo_id="ds4sd/granite-docling-258m-2-9-2025-v2-mlx-bf16",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.MLX,
    supported_devices=[AcceleratorDevice.MPS],
    scale=2.0,
    temperature=0.0,
    max_new_tokens=8192,
    stop_strings=["</doctag>", "<|end_of_text|>"],
)

# SmolDocling
SMOLDOCLING_MLX = InlineVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
@@ -272,3 +300,4 @@ class VlmModelType(str, Enum):
    GRANITE_VISION_VLLM = "granite_vision_vllm"
    GRANITE_VISION_OLLAMA = "granite_vision_ollama"
    GOT_OCR_2 = "got_ocr_2"
    GRANITEDOCLING = "granite_docling"
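
For context, a minimal usage sketch (not part of this diff), following docling's documented VLM-pipeline pattern; the input file name is hypothetical:

# Sketch only: run the VLM pipeline with the new GraniteDocling spec.
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.vlm_model_specs import GRANITEDOCLING_TRANSFORMERS
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(vlm_options=GRANITEDOCLING_TRANSFORMERS)
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)
result = converter.convert("document.pdf")  # hypothetical input file
print(result.document.export_to_markdown())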
90 changes: 89 additions & 1 deletion docling/models/vlm_models_inline/mlx_model.py
@@ -1,11 +1,13 @@
import logging
import re
import threading
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional, Union

import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin, DocItem
from PIL.Image import Image

from docling.datamodel.accelerator_options import (
@@ -27,6 +29,37 @@
_MLX_GLOBAL_LOCK = threading.Lock()


class DoclingStopping:
    """Stop-criterion helper: flags generation loops by detecting when a newly
    predicted bounding box overlaps one that was already emitted."""

    def __init__(self):
        self.pattern = re.compile(
            r"<([a-z\_\-]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>(<)?$"
        )

        self.bboxs: list[BoundingBox] = []

    def overlaps(self, text: str) -> bool:
        match = self.pattern.search(text)
        if match:
            tag_name = match.group(1)  # doctags element name, e.g. "text"
            loc1 = float(match.group(2))  # first location coordinate
            loc2 = float(match.group(3))  # second location coordinate
            loc3 = float(match.group(4))  # third location coordinate
            loc4 = float(match.group(5))  # fourth location coordinate

            bbox = BoundingBox(
                l=loc1, b=loc2, r=loc3, t=loc4, coord_origin=CoordOrigin.BOTTOMLEFT
            )

            for prev in self.bboxs:
                if bbox.intersection_over_self(prev) > 1.0e-6:
                    _log.info(f"{bbox} ({tag_name}) overlaps with {prev}")
                    return True

            self.bboxs.append(bbox)

        return False
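
A minimal behavioral sketch (coordinates invented for illustration): the first occurrence of a location sequence is recorded, and a later box that overlaps a recorded one trips the check, which process_images below uses to break out of generation.

# Sketch only, not part of the diff: invented coordinates demonstrating the check.
stopper = DoclingStopping()
assert stopper.overlaps("<text><loc_10><loc_20><loc_90><loc_80>") is False  # first box: recorded
assert stopper.overlaps("...<text><loc_10><loc_20><loc_90><loc_80>") is True  # overlapping repeat: loop detected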


class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
@@ -68,6 +101,26 @@ def __init__(
        self.vlm_model, self.processor = load(artifacts_path)
        self.config = load_config(artifacts_path)

        self._find_doctags_labels()

    def _find_doctags_labels(self):
        """Collect the doctags special tokens from the tokenizer vocabulary."""
        tokenizer = (
            self.processor.tokenizer
            if hasattr(self.processor, "tokenizer")
            else self.processor
        )

        self.special_tokens: dict[str, int] = {}
        if hasattr(tokenizer, "vocab"):
            # vocab is usually a dict mapping token_text -> token_id
            for token_text, token_id in tokenizer.vocab.items():
                if re.match(r"^<[a-z\_\-\d]+>$", token_text):
                    _log.debug(f"Token ID: {token_id:6d} | Text: '{token_text}'")
                    self.special_tokens[token_text] = token_id
        else:
            _log.warning("Tokenizer doesn't have a 'vocab' attribute")

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
@@ -199,6 +252,8 @@ def process_images(
        tokens: list[VlmPredictionToken] = []
        output = ""

        stopping_criteria = DoclingStopping()

        # Use stream_generate for proper stop string handling
        for token in self.stream_generate(
            self.vlm_model,
@@ -209,6 +264,10 @@
            verbose=False,
            temp=self.temperature,
        ):
            _log.debug(
                f"logprobs.shape: {token.logprobs.shape} with token: {token}"
            )

            # Collect token information
            if len(token.logprobs.shape) == 1:
                tokens.append(
@@ -218,6 +277,26 @@
                        logprob=token.logprobs[token.token],
                    )
                )
                if token.text in self.special_tokens:
                    # Get logprobs for all special tokens
                    special_token_logprobs = []
                    for token_text, token_id in self.special_tokens.items():
                        logprob = token.logprobs[token_id]
                        special_token_logprobs.append(
                            (token_text, token_id, logprob)
                        )

                    # Sort by logprob (highest first) and take top 5
                    top_5_special = sorted(
                        special_token_logprobs, key=lambda x: x[2], reverse=True
                    )[:5]

                    _log.debug("Top 5 special tokens by logprob:")
                    for rank, (t, token_id, logprob) in enumerate(
                        top_5_special, 1
                    ):
                        _log.debug(f"  {rank}. {t}: {logprob:0.3f}")

            elif (
                len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
            ):
@@ -228,13 +307,22 @@
                        logprob=token.logprobs[0, token.token],
                    )
                )

                if token.text in self.special_tokens:
                    for t, i in self.special_tokens.items():
                        _log.debug(f"{t}: {token.logprobs[0, i]:0.3f}")

            else:
                _log.warning(
                    f"incompatible shape for logprobs: {token.logprobs.shape}"
                )

            output += token.text

            if stopping_criteria.overlaps(output):
                _log.debug("Stopping generation due to overlapping bbox")
                break

            # Check for any configured stop strings
            if self.vlm_options.stop_strings:
                if any(
Expand All @@ -246,7 +334,7 @@ def process_images(

        generation_time = time.time() - start_time

        _log.debug(
        _log.info(
            f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
        )
