Skip to content

Commit 2fd433e

Browse files
committed
Forward VLM convert model spec runtime settings
Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>
1 parent c761512 commit 2fd433e

5 files changed

Lines changed: 389 additions & 38 deletions

File tree

docling/datamodel/stage_model_specs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import annotations
1010

1111
import logging
12+
from copy import deepcopy
1213
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Set
1314

1415
from pydantic import BaseModel, Field
@@ -159,10 +160,24 @@ class VlmModelSpec(BaseModel):
159160
default_factory=list, description="Stop strings for generation"
160161
)
161162

163+
temperature: float = Field(
164+
default=0.0, description="Sampling temperature for generation"
165+
)
166+
162167
max_new_tokens: int = Field(
163168
default=4096, description="Maximum number of new tokens to generate"
164169
)
165170

171+
extra_generation_config: Dict[str, Any] = Field(
172+
default_factory=dict, description="Additional generation configuration"
173+
)
174+
175+
_RUNTIME_INPUT_OVERRIDE_KEYS: ClassVar[Set[str]] = {
176+
"transformers_prompt_style",
177+
"extra_processor_kwargs",
178+
"custom_stopping_criteria",
179+
}
180+
166181
def get_repo_id(self, engine_type: VlmEngineType) -> str:
167182
"""Get the repository ID for a specific engine.
168183
@@ -248,6 +263,34 @@ def get_engine_config(self, engine_type: VlmEngineType) -> EngineModelConfig:
248263
extra_config=extra_config,
249264
)
250265

266+
def get_runtime_input_extra_config(
267+
self, engine_type: VlmEngineType
268+
) -> Dict[str, Any]:
269+
"""Build runtime input config for a specific engine.
270+
271+
This returns only the subset of model/engine configuration that should
272+
flow into ``VlmEngineInput.extra_generation_config``. Load-time engine
273+
options such as ``torch_dtype`` or ``transformers_model_type`` remain in
274+
``EngineModelConfig.extra_config`` and are intentionally excluded.
275+
"""
276+
277+
runtime_config: Dict[str, Any] = deepcopy(self.extra_generation_config)
278+
279+
if engine_type not in self.engine_overrides:
280+
return runtime_config
281+
282+
override_config = self.engine_overrides[engine_type].extra_config
283+
nested_generation_config = override_config.get("extra_generation_config")
284+
285+
if isinstance(nested_generation_config, dict):
286+
runtime_config.update(deepcopy(nested_generation_config))
287+
288+
for key in self._RUNTIME_INPUT_OVERRIDE_KEYS:
289+
if key in override_config:
290+
runtime_config[key] = deepcopy(override_config[key])
291+
292+
return runtime_config
293+
251294
def has_explicit_engine_export(self, engine_type: VlmEngineType) -> bool:
252295
"""Check if this model has an explicit export for the given engine.
253296

docling/models/inference_engines/vlm/api_openai_compatible_engine.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,28 +56,21 @@ def __init__(
5656
super().__init__(options, model_config=model_config)
5757
self.enable_remote_services = enable_remote_services
5858
self.options: ApiVlmEngineOptions = options
59+
self.model_api_params: dict[str, object] = {}
60+
self.user_params: dict[str, object] = self.options.params.copy()
5961

6062
if not self.enable_remote_services:
6163
raise OperationNotAllowed(
6264
"Connections to remote services is only allowed when set explicitly. "
6365
"pipeline_options.enable_remote_services=True."
6466
)
6567

66-
# Merge model_config extra_config (which contains API params from model spec)
67-
# with runtime options params. Runtime options take precedence.
68+
# Keep model spec API params as defaults only when the user has not
69+
# provided explicit API params; explicit runtime params are treated as
70+
# complete overrides to avoid vendor-specific conflicts.
6871
if model_config and "api_params" in model_config.extra_config:
69-
# Model spec provides API params (e.g., model name)
70-
model_api_params = model_config.extra_config["api_params"]
71-
72-
# Only use model spec params if user hasn't provided any params
73-
# This prevents conflicts when users provide custom params (e.g., model_id for watsonx)
7472
if not self.options.params:
75-
self.merged_params = model_api_params.copy()
76-
else:
77-
# User provided params - use them as-is (don't merge with model spec)
78-
self.merged_params = self.options.params.copy()
79-
else:
80-
self.merged_params = self.options.params.copy()
73+
self.model_api_params = model_config.extra_config["api_params"].copy()
8174

8275
def initialize(self) -> None:
8376
"""Initialize the API engine.
@@ -122,19 +115,21 @@ def _process_single_input(input_data: VlmEngineInput) -> VlmEngineOutput:
122115
images = preprocess_image_batch([input_data.image])
123116
image = images[0]
124117

125-
# Prepare API parameters: engine defaults first, then user/model
126-
# params override. This allows users to set Azure-specific params
127-
# like max_completion_tokens or override temperature (#3112).
128-
api_params: dict[str, object] = {
129-
"temperature": input_data.temperature,
130-
}
118+
# Apply precedence in this order:
119+
# 1. model spec API defaults
120+
# 2. per-request generation settings from VlmEngineInput
121+
# 3. explicit user API params from engine_options.params
122+
api_params: dict[str, object] = self.model_api_params.copy()
123+
api_params["temperature"] = input_data.temperature
131124

132125
# Add max_tokens if specified
133126
if input_data.max_new_tokens:
134127
api_params["max_tokens"] = input_data.max_new_tokens
135128

136-
# User/model spec params take precedence over engine defaults
137-
api_params.update(self.merged_params)
129+
# Explicit user params take precedence over per-request defaults.
130+
# This allows users to set Azure-specific params like
131+
# max_completion_tokens or override temperature (#3112).
132+
api_params.update(self.user_params)
138133

139134
# If user specified max_completion_tokens, remove conflicting
140135
# max_tokens (required for Azure OpenAI compatibility)

docling/models/stages/vlm_convert/vlm_convert_model.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
"""
66

77
import logging
8+
import time
89
from collections.abc import Iterable
910
from pathlib import Path
10-
from typing import Optional, Union
1111

1212
from PIL import Image as PILImage
1313

@@ -19,6 +19,7 @@
1919
from docling.models.inference_engines.vlm import (
2020
BaseVlmEngine,
2121
VlmEngineInput,
22+
VlmEngineType,
2223
create_vlm_engine,
2324
)
2425
from docling.utils.profiling import TimeRecorder
@@ -42,7 +43,7 @@ def __init__(
4243
self,
4344
enabled: bool,
4445
enable_remote_services: bool,
45-
artifacts_path: Optional[Union[Path, str]],
46+
artifacts_path: Path | str | None,
4647
options: VlmConvertOptions,
4748
accelerator_options: AcceleratorOptions,
4849
):
@@ -81,6 +82,26 @@ def __init__(
8182

8283
_log.info("VlmConvertModel initialized successfully")
8384

85+
def _get_runtime_engine_type(self) -> VlmEngineType:
86+
selected_engine_type = getattr(self.engine, "selected_engine_type", None)
87+
if selected_engine_type is not None:
88+
return selected_engine_type
89+
return self.options.engine_options.engine_type
90+
91+
def _build_engine_input(self, image: PILImage.Image, prompt: str) -> VlmEngineInput:
92+
model_spec = self.options.model_spec
93+
runtime_engine_type = self._get_runtime_engine_type()
94+
return VlmEngineInput(
95+
image=image,
96+
prompt=prompt,
97+
temperature=model_spec.temperature,
98+
max_new_tokens=model_spec.max_new_tokens,
99+
stop_strings=list(model_spec.stop_strings),
100+
extra_generation_config=model_spec.get_runtime_input_extra_config(
101+
runtime_engine_type
102+
),
103+
)
104+
84105
def __call__(
85106
self, conv_res: ConversionResult, page_batch: Iterable[Page]
86107
) -> Iterable[Page]:
@@ -106,33 +127,43 @@ def __call__(
106127
images = []
107128
prompts = []
108129
valid_pages = []
130+
rasterize_time = 0.0
131+
scale_resize_time = 0.0
132+
max_size_resize_time = 0.0
109133

110134
for page in page_list:
111-
if page.image is None:
135+
rasterize_start = time.perf_counter()
136+
image = page.image
137+
rasterize_time += time.perf_counter() - rasterize_start
138+
139+
if image is None:
112140
_log.warning(
113141
f"Page {page.page_no} has no image, skipping VLM conversion"
114142
)
115143
continue
116144

117145
# Scale image if needed
118-
image = page.image
119146
if self.options.scale != 1.0:
147+
resize_start = time.perf_counter()
120148
new_size = (
121149
int(image.width * self.options.scale),
122150
int(image.height * self.options.scale),
123151
)
124152
image = image.resize(new_size, PILImage.Resampling.LANCZOS)
153+
scale_resize_time += time.perf_counter() - resize_start
125154

126155
# Apply max_size constraint if specified
127156
if self.options.max_size is not None:
128157
max_dim = max(image.width, image.height)
129158
if max_dim > self.options.max_size:
159+
resize_start = time.perf_counter()
130160
scale_factor = self.options.max_size / max_dim
131161
new_size = (
132162
int(image.width * scale_factor),
133163
int(image.height * scale_factor),
134164
)
135165
image = image.resize(new_size, PILImage.Resampling.LANCZOS)
166+
max_size_resize_time += time.perf_counter() - resize_start
136167

137168
images.append(image)
138169
prompts.append(self.options.model_spec.prompt)
@@ -143,22 +174,29 @@ def __call__(
143174
return
144175

145176
# Process through runtime using batch prediction
146-
_log.debug(f"Processing {len(images)} pages through VLM engine (batched)")
177+
_log.debug(
178+
"Prepared %s pages for VLM engine: rasterize=%.3fs, scale_resize=%.3fs, max_size_resize=%.3fs",
179+
len(images),
180+
rasterize_time,
181+
scale_resize_time,
182+
max_size_resize_time,
183+
)
147184

148185
try:
149186
# Create batch of runtime inputs
150187
engine_inputs = [
151-
VlmEngineInput(
152-
image=img,
153-
prompt=prompt,
154-
temperature=0.0, # Use from options if needed
155-
max_new_tokens=4096, # Use from options if needed
156-
)
188+
self._build_engine_input(image=img, prompt=prompt)
157189
for img, prompt in zip(images, prompts)
158190
]
159191

160192
# Run batch inference
193+
batch_start = time.perf_counter()
161194
outputs = self.engine.predict_batch(engine_inputs)
195+
_log.debug(
196+
"Processed %s pages through VLM engine in %.3fs",
197+
len(engine_inputs),
198+
time.perf_counter() - batch_start,
199+
)
162200

163201
# Attach predictions to pages
164202
for page, output in zip(valid_pages, outputs):
@@ -226,12 +264,7 @@ def process_images(
226264

227265
# Process batch of images
228266
engine_inputs = [
229-
VlmEngineInput(
230-
image=img,
231-
prompt=p,
232-
temperature=0.0,
233-
max_new_tokens=4096,
234-
)
267+
self._build_engine_input(image=img, prompt=p)
235268
for img, p in zip(images, prompts)
236269
]
237270

tests/test_api_vlm_engine.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from PIL import Image
2+
3+
from docling.datamodel.stage_model_specs import EngineModelConfig
4+
from docling.datamodel.vlm_engine_options import ApiVlmEngineOptions
5+
from docling.models.inference_engines.vlm.api_openai_compatible_engine import (
6+
ApiVlmEngine,
7+
)
8+
from docling.models.inference_engines.vlm.base import VlmEngineInput, VlmEngineType
9+
10+
11+
def test_api_vlm_engine_uses_request_generation_settings_over_model_defaults(
12+
monkeypatch,
13+
) -> None:
14+
captured = {}
15+
16+
def _fake_api_image_request(**kwargs):
17+
captured.update(kwargs)
18+
return "ok", 1, "stop"
19+
20+
monkeypatch.setattr(
21+
"docling.models.inference_engines.vlm.api_openai_compatible_engine.api_image_request",
22+
_fake_api_image_request,
23+
)
24+
25+
engine = ApiVlmEngine(
26+
enable_remote_services=True,
27+
options=ApiVlmEngineOptions(
28+
engine_type=VlmEngineType.API_OPENAI,
29+
url="http://localhost:11434/v1/chat/completions",
30+
),
31+
model_config=EngineModelConfig(
32+
extra_config={
33+
"api_params": {
34+
"model": "test-model",
35+
"max_tokens": 4096,
36+
"temperature": 0.0,
37+
}
38+
}
39+
),
40+
)
41+
42+
outputs = engine.predict_batch(
43+
[
44+
VlmEngineInput(
45+
image=Image.new("RGB", (8, 8), "white"),
46+
prompt="Prompt",
47+
temperature=0.4,
48+
max_new_tokens=128,
49+
stop_strings=["</doctag>"],
50+
)
51+
]
52+
)
53+
54+
assert [output.text for output in outputs] == ["ok"]
55+
assert captured["model"] == "test-model"
56+
assert captured["temperature"] == 0.4
57+
assert captured["max_tokens"] == 128
58+
assert captured["stop"] == ["</doctag>"]
59+
60+
61+
def test_api_vlm_engine_allows_explicit_user_params_to_override_request_settings(
62+
monkeypatch,
63+
) -> None:
64+
captured = {}
65+
66+
def _fake_api_image_request(**kwargs):
67+
captured.update(kwargs)
68+
return "ok", 1, "stop"
69+
70+
monkeypatch.setattr(
71+
"docling.models.inference_engines.vlm.api_openai_compatible_engine.api_image_request",
72+
_fake_api_image_request,
73+
)
74+
75+
engine = ApiVlmEngine(
76+
enable_remote_services=True,
77+
options=ApiVlmEngineOptions(
78+
engine_type=VlmEngineType.API_OPENAI,
79+
url="http://localhost:11434/v1/chat/completions",
80+
params={
81+
"model": "override-model",
82+
"temperature": 0.8,
83+
"max_completion_tokens": 256,
84+
},
85+
),
86+
model_config=EngineModelConfig(
87+
extra_config={"api_params": {"model": "default-model", "max_tokens": 4096}}
88+
),
89+
)
90+
91+
outputs = engine.predict_batch(
92+
[
93+
VlmEngineInput(
94+
image=Image.new("RGB", (8, 8), "white"),
95+
prompt="Prompt",
96+
temperature=0.4,
97+
max_new_tokens=128,
98+
)
99+
]
100+
)
101+
102+
assert [output.text for output in outputs] == ["ok"]
103+
assert captured["model"] == "override-model"
104+
assert captured["temperature"] == 0.8
105+
assert captured["max_completion_tokens"] == 256
106+
assert "max_tokens" not in captured

0 commit comments

Comments
 (0)