docling/models/api_vlm_model.py (3 changes: 1 addition & 2 deletions)

```diff
@@ -92,8 +92,7 @@ def _vlm_request(page):
                 headers=self.vlm_options.headers,
                 **self.params,
             )
-
-            page_tags = self.vlm_options.decode_response(page_tags)
+            page_tags = self.vlm_options.decode_response(page_tags.text)
             page.predictions.vlm_response = VlmPrediction(text=page_tags)
             return page
 
```
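Note: this call site changes because the API helper invoked just above (not shown in the hunk) now returns an `ApiImageResponse` dataclass instead of a plain string, so the caller unwraps `.text` before `decode_response`. A minimal sketch of the new shape, with illustrative values rather than a real request (see `docling/utils/api_image_request.py` below for the definition):

```python
from docling.datamodel.base_models import OpenAiResponseUsage
from docling.utils.api_image_request import ApiImageResponse

# What the helper now hands back: generated text plus token accounting.
resp = ApiImageResponse(
    text="<doctag>...</doctag>",
    usage=OpenAiResponseUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
page_tags = resp.text  # decode_response() receives the raw string, as in the diff above
```
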
docling/models/picture_description_api_model.py (6 changes: 4 additions & 2 deletions)

```diff
@@ -12,7 +12,7 @@
 )
 from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
-from docling.utils.api_image_request import api_image_request
+from docling.utils.api_image_request import ApiImageResponse, api_image_request
 
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@@ -47,7 +47,9 @@ def __init__(
                 "pipeline_options.enable_remote_services=True."
             )
 
-    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+    def _annotate_images(
+        self, images: Iterable[Image.Image]
+    ) -> Iterable[ApiImageResponse]:
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         def _api_request(image):
```

docling/models/picture_description_base_model.py (18 changes: 15 additions & 3 deletions)

```diff
@@ -14,6 +14,7 @@
 from PIL import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
+from docling.datamodel.base_models import OpenAiResponseUsage
 from docling.datamodel.pipeline_options import (
     PictureDescriptionBaseOptions,
 )
@@ -22,6 +23,13 @@
     BaseModelWithOptions,
     ItemAndImageEnrichmentElement,
 )
+from docling.utils.api_image_request import ApiImageResponse
+
+
+class DescriptionAnnotationWithUsage(PictureDescriptionData):
+    """DescriptionAnnotation with usage information."""
+
+    usage: Optional[OpenAiResponseUsage] = None
 
 
 class PictureDescriptionBaseModel(
@@ -45,7 +53,9 @@ def __init__(
     def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
         return self.enabled and isinstance(element, PictureItem)
 
-    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+    def _annotate_images(
+        self, images: Iterable[Image.Image]
+    ) -> Iterable[ApiImageResponse]:
         raise NotImplementedError
 
     def __call__(
@@ -77,11 +87,13 @@ def __call__(
                 elements.append(el.item)
                 images.append(el.image)
 
-        outputs = self._annotate_images(images)
+        outputs: List[ApiImageResponse] = list(self._annotate_images(images))
 
         for item, output in zip(elements, outputs):
             item.annotations.append(
-                PictureDescriptionData(text=output, provenance=self.provenance)
+                DescriptionAnnotationWithUsage(
+                    text=output.text, provenance=self.provenance, usage=output.usage
+                )
             )
             yield item
 
```
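Note: because `DescriptionAnnotationWithUsage` subclasses `PictureDescriptionData`, existing consumers of picture annotations keep working, and new ones can opt into token accounting. A hedged sketch of reading it back, assuming `doc` is a `DoclingDocument` whose pictures were annotated by this pipeline:

```python
from docling.models.picture_description_base_model import (
    DescriptionAnnotationWithUsage,
)

# `doc` is assumed to be a DoclingDocument converted with picture description enabled.
for picture in doc.pictures:
    for annotation in picture.annotations:
        if isinstance(annotation, DescriptionAnnotationWithUsage):
            print(annotation.text)
            if annotation.usage is not None:  # usage is Optional
                print("tokens used:", annotation.usage.total_tokens)
```
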
docling/models/picture_description_vlm_model.py (24 changes: 21 additions & 3 deletions)

```diff
@@ -1,9 +1,10 @@
 import threading
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Type, Union, cast
 
 from PIL import Image
+from torch import LongTensor
 from transformers import AutoModelForImageTextToText
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
@@ -16,6 +17,7 @@
     HuggingFaceModelDownloadMixin,
 )
 from docling.utils.accelerator_utils import decide_device
+from docling.utils.api_image_request import ApiImageResponse, OpenAiResponseUsage
 
 # Global lock for model initialization to prevent threading issues
 _model_init_lock = threading.Lock()
@@ -79,7 +81,9 @@ def __init__(
 
         self.provenance = f"{self.options.repo_id}"
 
-    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+    def _annotate_images(
+        self, images: Iterable[Image.Image]
+    ) -> Iterable[ApiImageResponse]:
         from transformers import GenerationConfig
 
         # Create input messages
@@ -108,9 +112,23 @@ def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
                 **inputs,
                 generation_config=GenerationConfig(**self.options.generation_config),
             )
+
             generated_texts = self.processor.batch_decode(
                 generated_ids[:, inputs["input_ids"].shape[1] :],
                 skip_special_tokens=True,
             )
 
-            yield generated_texts[0].strip()
+            input_token_count = inputs["input_ids"].shape[1]
+
+            # Normalize generate output to support both tensor and dataclass return types
+            sequences = cast(
+                LongTensor, getattr(generated_ids, "sequences", generated_ids)
+            )
+            output_token_count = sequences.shape[1] - input_token_count
+            usage = OpenAiResponseUsage(
+                prompt_tokens=input_token_count,
+                completion_tokens=output_token_count,
+                total_tokens=input_token_count + output_token_count,
+            )
+
+            yield ApiImageResponse(text=generated_texts[0].strip(), usage=usage)
```
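Note on the `getattr` normalization: Hugging Face `generate()` returns a bare `LongTensor` of token ids by default, but an output object exposing `.sequences` when `return_dict_in_generate=True` is set (e.g., through `generation_config`). A self-contained sketch of the pattern, using stand-in objects rather than a real model:

```python
from types import SimpleNamespace

import torch

# Stand-ins for the two shapes generate() can return (no real model involved).
bare_tensor = torch.zeros((1, 12), dtype=torch.long)
dict_style = SimpleNamespace(sequences=torch.zeros((1, 12), dtype=torch.long))

for generated_ids in (bare_tensor, dict_style):
    # Same normalization as in the diff: prefer .sequences when present.
    sequences = getattr(generated_ids, "sequences", generated_ids)
    assert sequences.shape == (1, 12)
```

One caveat a reviewer might flag: `batch_decode` above still slices `generated_ids` directly, so it would need the same normalization if a dict-style output ever reached it, since the slice applies to the tensor, not the wrapper.
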
docling/utils/api_image_request.py (50 changes: 44 additions & 6 deletions)

```diff
@@ -1,27 +1,37 @@
 import base64
 import json
 import logging
+from dataclasses import dataclass
 from io import BytesIO
 from typing import Dict, List, Optional
 
 import requests
 from PIL import Image
 from pydantic import AnyUrl
 
-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, OpenAiResponseUsage
 from docling.models.utils.generation_utils import GenerationStopper
 
 _log = logging.getLogger(__name__)
 
 
+@dataclass
+class ApiImageResponse:
+    """Generic response from image-based API calls."""
+
+    text: str
+    usage: OpenAiResponseUsage
+
+
 def api_image_request(
     image: Image.Image,
     prompt: str,
     url: AnyUrl,
     timeout: float = 20,
     headers: Optional[Dict[str, str]] = None,
+    token_extract_key: Optional[str] = None,
     **params,
-) -> str:
+) -> ApiImageResponse:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +70,14 @@ def api_image_request(
 
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    if api_resp.usage is None:
+        usage = OpenAiResponseUsage(
+            prompt_tokens=0, completion_tokens=0, total_tokens=0
+        )
+    else:
+        usage = api_resp.usage
+
+    return ApiImageResponse(generated_text, usage)
 
 
 def api_image_request_streaming(
```
```diff
@@ -72,7 +89,7 @@ def api_image_request_streaming(
     headers: Optional[Dict[str, str]] = None,
     generation_stoppers: List[GenerationStopper] = [],
     **params,
-) -> str:
+) -> ApiImageResponse:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -124,6 +141,7 @@ def api_image_request_streaming(
         r.raise_for_status()
 
         full_text = []
+        usage_data = None
         for raw_line in r.iter_lines(decode_unicode=True):
             if not raw_line:  # keep-alives / blank lines
                 continue
@@ -141,6 +159,10 @@ def api_image_request_streaming(
                 _log.debug("Skipping non-JSON SSE chunk: %r", data[:200])
                 continue
 
+            # Try to extract usage if present (may be in final chunk)
+            if obj.get("usage"):
+                usage_data = obj["usage"]
+
             # OpenAI-compatible delta format
             # obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls)
             try:
```
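Note: OpenAI-compatible servers typically only emit a `usage` block on a stream (in a final chunk with empty `choices`) when the request asks for it via `stream_options`; otherwise `usage_data` stays `None` and the zero-filled fallback in the next hunk applies. A hedged call-site sketch; the endpoint is a placeholder, and forwarding `stream_options` through `**params` into the JSON payload is an assumption:

```python
from PIL import Image
from pydantic import AnyUrl

from docling.utils.api_image_request import api_image_request_streaming

# Hypothetical local endpoint; stream_options is assumed to ride along in
# **params into the request payload, as extra params do in the non-streaming helper.
resp = api_image_request_streaming(
    image=Image.new("RGB", (64, 64)),
    prompt="Describe this image.",
    url=AnyUrl("http://localhost:8000/v1/chat/completions"),
    stream_options={"include_usage": True},
)
print(resp.text, resp.usage.total_tokens)
```
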
```diff
@@ -162,6 +184,22 @@ def api_image_request_streaming(
                     # closing the connection when we exit the 'with' block.
                     # vLLM/OpenAI-compatible servers will detect the client disconnect
                     # and abort the request server-side.
-                    return "".join(full_text)
+                    return ApiImageResponse(
+                        text="".join(full_text),
+                        usage=OpenAiResponseUsage(
+                            prompt_tokens=0, completion_tokens=0, total_tokens=0
+                        ),
+                    )
+    if usage_data:
+        try:
+            usage = OpenAiResponseUsage(**usage_data)
+        except Exception:
+            usage = OpenAiResponseUsage(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            )
+    else:
+        usage = OpenAiResponseUsage(
+            prompt_tokens=0, completion_tokens=0, total_tokens=0
+        )
 
-    return "".join(full_text)
+    return ApiImageResponse(text="".join(full_text), usage=usage)
```
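Note: end to end, callers of the non-streaming helper now get text and usage in one object, and `usage` is zero-filled rather than `None` when the server omits it, so field access is always safe. A minimal call-site sketch against a hypothetical OpenAI-compatible endpoint:

```python
from PIL import Image
from pydantic import AnyUrl

from docling.utils.api_image_request import api_image_request

# Hypothetical local endpoint; swap in a real OpenAI-compatible server.
resp = api_image_request(
    image=Image.new("RGB", (64, 64)),
    prompt="Describe this image.",
    url=AnyUrl("http://localhost:8000/v1/chat/completions"),
    timeout=60,
)
print(resp.text)
print(resp.usage.prompt_tokens, resp.usage.completion_tokens, resp.usage.total_tokens)
```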