diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index bcdb97f5e..a80977502 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -92,8 +92,7 @@ def _vlm_request(page):
                 headers=self.vlm_options.headers,
                 **self.params,
             )
-
-            page_tags = self.vlm_options.decode_response(page_tags)
+            page_tags = self.vlm_options.decode_response(page_tags.text)
             page.predictions.vlm_response = VlmPrediction(text=page_tags)

             return page
diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py
index a3c0c2ee0..c04c3afba 100644
--- a/docling/models/picture_description_api_model.py
+++ b/docling/models/picture_description_api_model.py
@@ -12,7 +12,7 @@
 )
 from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
-from docling.utils.api_image_request import api_image_request
+from docling.utils.api_image_request import ApiImageResponse, api_image_request


 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@@ -47,7 +47,9 @@ def __init__(
                 "pipeline_options.enable_remote_services=True."
             )

-    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+    def _annotate_images(
+        self, images: Iterable[Image.Image]
+    ) -> Iterable[ApiImageResponse]:
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         def _api_request(image):
diff --git a/docling/models/picture_description_base_model.py b/docling/models/picture_description_base_model.py
index 055c74b1f..d48638456 100644
--- a/docling/models/picture_description_base_model.py
+++ b/docling/models/picture_description_base_model.py
@@ -14,6 +14,7 @@
 from PIL import Image

 from docling.datamodel.accelerator_options import AcceleratorOptions
+from docling.datamodel.base_models import OpenAiResponseUsage
 from docling.datamodel.pipeline_options import (
     PictureDescriptionBaseOptions,
 )
@@ -22,6 +23,13 @@
     BaseModelWithOptions,
     ItemAndImageEnrichmentElement,
 )
+from docling.utils.api_image_request import ApiImageResponse
+
+
+class DescriptionAnnotationWithUsage(PictureDescriptionData):
+    """DescriptionAnnotation with usage information."""
+
+    usage: Optional[OpenAiResponseUsage] = None


 class PictureDescriptionBaseModel(
@@ -45,7 +53,9 @@ def __init__(
     def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
         return self.enabled and isinstance(element, PictureItem)

-    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+    def _annotate_images(
+        self, images: Iterable[Image.Image]
+    ) -> Iterable[ApiImageResponse]:
         raise NotImplementedError

     def __call__(
@@ -77,11 +87,13 @@ def __call__(
                 elements.append(el.item)
                 images.append(el.image)

-        outputs = self._annotate_images(images)
+        outputs: List[ApiImageResponse] = list(self._annotate_images(images))

         for item, output in zip(elements, outputs):
             item.annotations.append(
-                PictureDescriptionData(text=output, provenance=self.provenance)
+                DescriptionAnnotationWithUsage(
+                    text=output.text, provenance=self.provenance, usage=output.usage
+                )
             )
             yield item

diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py
index 4b5007fae..3c750bf1f 100644
--- a/docling/models/picture_description_vlm_model.py
+++ b/docling/models/picture_description_vlm_model.py
@@ -1,9 +1,10 @@
 import threading
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import Optional, Type, Union, cast

 from PIL import Image
+from torch import LongTensor
 from transformers import AutoModelForImageTextToText

 from docling.datamodel.accelerator_options import AcceleratorOptions
@@ -16,6 +17,7 @@
     HuggingFaceModelDownloadMixin,
 )
 from docling.utils.accelerator_utils import decide_device
+from docling.utils.api_image_request import ApiImageResponse, OpenAiResponseUsage

 # Global lock for model initialization to prevent threading issues
 _model_init_lock = threading.Lock()
@@ -79,7 +81,9 @@ def __init__(

         self.provenance = f"{self.options.repo_id}"

-    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
+    def _annotate_images(
+        self, images: Iterable[Image.Image]
+    ) -> Iterable[ApiImageResponse]:
         from transformers import GenerationConfig

         # Create input messages
@@ -108,9 +112,23 @@ def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
                 **inputs,
                 generation_config=GenerationConfig(**self.options.generation_config),
             )
+
             generated_texts = self.processor.batch_decode(
                 generated_ids[:, inputs["input_ids"].shape[1] :],
                 skip_special_tokens=True,
             )

-            yield generated_texts[0].strip()
+            input_token_count = inputs["input_ids"].shape[1]
+
+            # Normalize generate output to support both tensor and dataclass return types
+            sequences = cast(
+                LongTensor, getattr(generated_ids, "sequences", generated_ids)
+            )
+            output_token_count = sequences.shape[1] - input_token_count
+            usage = OpenAiResponseUsage(
+                prompt_tokens=input_token_count,
+                completion_tokens=output_token_count,
+                total_tokens=input_token_count + output_token_count,
+            )
+
+            yield ApiImageResponse(text=generated_texts[0].strip(), usage=usage)
diff --git a/docling/utils/api_image_request.py b/docling/utils/api_image_request.py
index e85c6cad7..8e4f68ff6 100644
--- a/docling/utils/api_image_request.py
+++ b/docling/utils/api_image_request.py
@@ -1,6 +1,7 @@
 import base64
 import json
 import logging
+from dataclasses import dataclass
 from io import BytesIO
 from typing import Dict, List, Optional

@@ -8,20 +9,29 @@
 from PIL import Image
 from pydantic import AnyUrl

-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, OpenAiResponseUsage
 from docling.models.utils.generation_utils import GenerationStopper

 _log = logging.getLogger(__name__)


+@dataclass
+class ApiImageResponse:
+    """Generic response from image-based API calls."""
+
+    text: str
+    usage: OpenAiResponseUsage
+
+
 def api_image_request(
     image: Image.Image,
     prompt: str,
     url: AnyUrl,
     timeout: float = 20,
     headers: Optional[Dict[str, str]] = None,
+    token_extract_key: Optional[str] = None,
     **params,
-) -> str:
+) -> ApiImageResponse:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,18 +70,25 @@
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    if api_resp.usage is None:
+        usage = OpenAiResponseUsage(
+            prompt_tokens=0, completion_tokens=0, total_tokens=0
+        )
+    else:
+        usage = api_resp.usage
+
+    return ApiImageResponse(generated_text, usage)


 def api_image_request_streaming(
     image: Image.Image,
     prompt: str,
     url: AnyUrl,
     timeout: float = 20,
     *,
     headers: Optional[Dict[str, str]] = None,
     generation_stoppers: List[GenerationStopper] = [],
     **params,
-) -> str:
+) -> ApiImageResponse:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\n\n', terminated by 'data: [DONE]'.
@@ -124,6 +141,7 @@ def api_image_request_streaming(
         r.raise_for_status()

         full_text = []
+        usage_data = None
         for raw_line in r.iter_lines(decode_unicode=True):
             if not raw_line:  # keep-alives / blank lines
                 continue
@@ -141,6 +159,10 @@
                 _log.debug("Skipping non-JSON SSE chunk: %r", data[:200])
                 continue

+            # Try to extract usage if present (may be in final chunk)
+            if obj.get("usage"):
+                usage_data = obj["usage"]
+
             # OpenAI-compatible delta format
             # obj["choices"][0]["delta"]["content"] may be None or missing (e.g., tool calls)
             try:
@@ -162,6 +184,22 @@
                     # closing the connection when we exit the 'with' block.
                     # vLLM/OpenAI-compatible servers will detect the client disconnect
                     # and abort the request server-side.
-                    return "".join(full_text)
+                    return ApiImageResponse(
+                        text="".join(full_text),
+                        usage=OpenAiResponseUsage(
+                            prompt_tokens=0, completion_tokens=0, total_tokens=0
+                        ),
+                    )
+    if usage_data:
+        try:
+            usage = OpenAiResponseUsage(**usage_data)
+        except Exception:
+            usage = OpenAiResponseUsage(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            )
+    else:
+        usage = OpenAiResponseUsage(
+            prompt_tokens=0, completion_tokens=0, total_tokens=0
+        )

-    return "".join(full_text)
+    return ApiImageResponse(text="".join(full_text), usage=usage)
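
For reviewers, a minimal sketch of how a caller consumes the new return type: `api_image_request` now yields text plus token accounting instead of a bare `str`, with a zero-filled `usage` fallback when the server omits it. The `describe_image` helper and the endpoint URL are hypothetical, not part of this patch.

```python
from PIL import Image

from docling.utils.api_image_request import ApiImageResponse, api_image_request


def describe_image(image: Image.Image, url: str) -> ApiImageResponse:
    # Returns ApiImageResponse(text=..., usage=...) after this patch
    resp = api_image_request(image=image, prompt="Describe this picture.", url=url)
    # usage is always populated; servers that omit it produce zeroed counters
    print(
        f"prompt={resp.usage.prompt_tokens} "
        f"completion={resp.usage.completion_tokens} "
        f"total={resp.usage.total_tokens}"
    )
    return resp
```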
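The streaming path only records usage when the server actually emits it in an SSE chunk, and OpenAI-compatible servers generally do so only when `stream_options={"include_usage": true}` is requested. A sketch of opting in, assuming extra keyword arguments are forwarded into the request payload via `**params` (the endpoint URL is a hypothetical local vLLM server):

```python
from PIL import Image

from docling.utils.api_image_request import api_image_request_streaming

image = Image.new("RGB", (64, 64))  # stand-in image; a real picture crop in practice

resp = api_image_request_streaming(
    image=image,
    prompt="Describe this picture.",
    url="http://localhost:8000/v1/chat/completions",  # hypothetical endpoint
    stream_options={"include_usage": True},  # ask the server for a final usage chunk
)
print(resp.text, resp.usage.total_tokens)
```

Without that option, `usage_data` stays `None` and the function falls back to zeroed counters, as in the diff above.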
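Downstream, the usage lands on each picture via `DescriptionAnnotationWithUsage`. A sketch of aggregating document-level token spend, assuming `doc` is a converted `DoclingDocument` with picture description enabled; the `total_tokens` helper is illustrative only:

```python
from docling_core.types.doc import DoclingDocument

from docling.models.picture_description_base_model import (
    DescriptionAnnotationWithUsage,
)


def total_tokens(doc: DoclingDocument) -> int:
    total = 0
    for picture in doc.pictures:
        for annotation in picture.annotations:
            # usage is Optional: the API and VLM paths set it, but it may be None
            if isinstance(annotation, DescriptionAnnotationWithUsage) and annotation.usage:
                total += annotation.usage.total_tokens
    return total
```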