Skip to content

Commit 67cc408

Browse files
committed
Support image tool outputs in chat completions
1 parent 69ab416 commit 67cc408

5 files changed

Lines changed: 399 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"omegaconf>=2.3.0",
2727
"open-clip-torch>=2.20.0",
2828
"openai>=2.8.0",
29-
"openai-agents>=0.6.4",
29+
"openai-agents>=0.6.5,<0.7",
3030
"opencv-python>=4.10.0.84,<4.12",
3131
"pybind11>=3.0.1",
3232
"pygltflib>=1.16.5",

scenesmith/agent_utils/base_stateful_agent.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
from openai.types.shared import Reasoning
3333

3434
from scenesmith.agent_utils.action_logger import log_scene_action
35+
from scenesmith.agent_utils.chat_completions_image_filter import (
36+
ChatCompletionsToolImageFilter,
37+
CompositeCallModelInputFilter,
38+
)
3539
from scenesmith.agent_utils.checkpoint_state import initialize_checkpoint_attributes
3640
from scenesmith.agent_utils.intra_turn_image_filter import IntraTurnImageFilter
3741
from scenesmith.agent_utils.physics_tools import check_physics_violations
@@ -372,20 +376,26 @@ def _create_sessions(self, session_prefix: str = "") -> tuple[Session, Session]:
372376
return designer_session, critic_session
373377

374378
def _create_run_config(self) -> RunConfig:
375-
"""Create RunConfig with intra-turn image filter if enabled.
379+
"""Create RunConfig with model input filters.
376380
377-
The filter strips images from older observe_scene outputs within a turn,
378-
keeping only the last N observations with images intact. This reduces
379-
token usage when agents call observe_scene multiple times within a turn.
381+
Intra-turn stripping reduces token usage when agents call observe_scene
382+
multiple times within a turn. The Chat Completions image filter keeps
383+
image-returning tools usable if the SDK is configured away from the
384+
default Responses API.
380385
381386
Returns:
382-
RunConfig with call_model_input_filter set if enabled, empty otherwise.
387+
RunConfig with call_model_input_filter set.
383388
"""
389+
input_filters = []
384390
intra_cfg = self.cfg.session_memory.intra_turn_observation_stripping
385391
if intra_cfg.enabled:
386-
return RunConfig(call_model_input_filter=IntraTurnImageFilter(cfg=self.cfg))
392+
input_filters.append(IntraTurnImageFilter(cfg=self.cfg))
393+
394+
input_filters.append(ChatCompletionsToolImageFilter())
387395

388-
return RunConfig()
396+
return RunConfig(
397+
call_model_input_filter=CompositeCallModelInputFilter(input_filters)
398+
)
389399

390400
def _should_reset_to_checkpoint(
391401
self,
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""Chat Completions compatibility for image-returning tools."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any
7+
8+
from agents import OpenAIChatCompletionsModel
9+
from agents.items import TResponseInputItem
10+
from agents.models import _openai_shared
11+
from agents.run import CallModelData, ModelInputData
12+
13+
from scenesmith.agent_utils.turn_trimming_session import _is_image_content
14+
15+
16+
def _uses_openai_chat_completions(data: CallModelData[Any]) -> bool:
17+
"""Return true when the current run will use OpenAI Chat Completions."""
18+
if isinstance(data.agent.model, OpenAIChatCompletionsModel):
19+
return True
20+
21+
# SceneSmith creates agents with model names. In that case, the Agents SDK
22+
# resolves the backend through the default OpenAI provider.
23+
if isinstance(data.agent.model, str) or data.agent.model is None:
24+
return not _openai_shared.get_use_responses_by_default()
25+
26+
return False
27+
28+
29+
@dataclass
30+
class ChatCompletionsToolImageFilter:
31+
"""Expose tool-output images to Chat Completions models.
32+
33+
OpenAI Chat Completions only allows text content in tool messages. When
34+
tools return images, keep tool messages text-only and add a synthetic user
35+
message with the images after each contiguous block of tool results.
36+
"""
37+
38+
force_enable: bool = False
39+
40+
def __call__(self, data: CallModelData[Any]) -> ModelInputData:
41+
if not self.force_enable and not _uses_openai_chat_completions(data):
42+
return data.model_data
43+
44+
transformed: list[TResponseInputItem] = []
45+
pending_image_parts: list[dict[str, Any]] = []
46+
changed = False
47+
48+
for item in data.model_data.input:
49+
if not self._is_function_call_output(item):
50+
self._flush_pending_images(transformed, pending_image_parts)
51+
transformed.append(item)
52+
continue
53+
54+
tool_item, image_parts = self._split_image_tool_output(item)
55+
transformed.append(tool_item)
56+
pending_image_parts.extend(image_parts)
57+
changed = changed or tool_item is not item or bool(image_parts)
58+
59+
self._flush_pending_images(transformed, pending_image_parts)
60+
61+
if not changed:
62+
return data.model_data
63+
64+
return ModelInputData(
65+
input=transformed, instructions=data.model_data.instructions
66+
)
67+
68+
def _is_function_call_output(self, item: TResponseInputItem) -> bool:
69+
if not isinstance(item, dict):
70+
return False
71+
return item.get("type") == "function_call_output"
72+
73+
def _is_list_tool_output(self, item: TResponseInputItem) -> bool:
74+
if not self._is_function_call_output(item):
75+
return False
76+
assert isinstance(item, dict)
77+
output = item.get("output")
78+
return isinstance(output, list)
79+
80+
def _split_image_tool_output(
81+
self, item: TResponseInputItem
82+
) -> tuple[TResponseInputItem, list[dict[str, Any]]]:
83+
if not self._is_list_tool_output(item):
84+
return item, []
85+
86+
assert isinstance(item, dict) # Narrowed by _is_list_tool_output.
87+
output = item.get("output")
88+
assert isinstance(output, list)
89+
90+
image_parts = [
91+
part
92+
for part in output
93+
if isinstance(part, dict) and _is_image_content(part)
94+
]
95+
text_parts = [
96+
part
97+
for part in output
98+
if not (isinstance(part, dict) and _is_image_content(part))
99+
]
100+
101+
tool_item = dict(item)
102+
if text_parts:
103+
tool_item["output"] = self._text_parts_to_string(text_parts)
104+
else:
105+
tool_item["output"] = (
106+
"The tool returned image output. The image content is attached in "
107+
"the following user message."
108+
)
109+
110+
return tool_item, image_parts
111+
112+
def _text_parts_to_string(self, text_parts: list[Any]) -> str:
113+
text_segments = []
114+
for part in text_parts:
115+
if isinstance(part, str):
116+
text_segments.append(part)
117+
elif isinstance(part, dict):
118+
text = part.get("text")
119+
if isinstance(text, str):
120+
text_segments.append(text)
121+
else:
122+
text_segments.append(str(part))
123+
else:
124+
text_segments.append(str(part))
125+
126+
return (
127+
"\n".join(segment for segment in text_segments if segment)
128+
or "[Tool output]"
129+
)
130+
131+
def _flush_pending_images(
132+
self,
133+
transformed: list[TResponseInputItem],
134+
pending_image_parts: list[dict[str, Any]],
135+
) -> None:
136+
if not pending_image_parts:
137+
return
138+
139+
image_message: TResponseInputItem = {
140+
"role": "user",
141+
"content": [
142+
{
143+
"type": "input_text",
144+
"text": "Images returned by the previous tool call(s):",
145+
},
146+
*pending_image_parts,
147+
],
148+
}
149+
transformed.append(image_message)
150+
pending_image_parts.clear()
151+
152+
153+
@dataclass
154+
class CompositeCallModelInputFilter:
155+
"""Apply multiple call_model_input_filter functions in order."""
156+
157+
filters: list[Any]
158+
159+
def __call__(self, data: CallModelData[Any]) -> ModelInputData:
160+
model_data = data.model_data
161+
for input_filter in self.filters:
162+
model_data = input_filter(
163+
CallModelData(
164+
model_data=model_data,
165+
agent=data.agent,
166+
context=data.context,
167+
)
168+
)
169+
return model_data

0 commit comments

Comments
 (0)