Skip to content

Commit fcb7e48

Browse files
committed
fix: openai vision
1 parent 7e1f46a commit fcb7e48

File tree

2 files changed

+80
-22
lines changed

2 files changed

+80
-22
lines changed

example.png

19.5 KB
Loading

llm_dialog_manager/agent.py

Lines changed: 80 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
import uuid
5-
from typing import List, Dict, Optional, Union
5+
from typing import List, Dict, Union, Optional, Any
66
import logging
77
from pathlib import Path
88
import random
@@ -97,13 +97,30 @@ def completion(model: str, messages: List[Dict[str, Union[str, List[Union[str, I
9797
api_key = os.getenv(f"{service.upper()}_API_KEY")
9898
base_url = os.getenv(f"{service.upper()}_BASE_URL")
9999

100-
def format_messages_for_api(model, messages):
101-
"""Convert ChatHistory messages to the format required by the specific API."""
100+
def format_messages_for_api(
101+
model: str,
102+
messages: List[Dict[str, Union[str, List[Union[str, Image.Image, Dict]]]]]
103+
) -> tuple[Optional[str], List[Dict[str, Any]]]:
104+
"""
105+
Convert ChatHistory messages to the format required by the specific API.
106+
107+
Args:
108+
model: The model name (e.g., "claude", "gemini", "gpt")
109+
messages: List of message dictionaries with role and content
110+
111+
Returns:
112+
tuple: (system_message, formatted_messages)
113+
- system_message is extracted system message for Claude, None for others
114+
- formatted_messages is the list of formatted message dictionaries
115+
"""
102116
if "claude" in model and "openai" not in model:
103117
formatted = []
104118
system_msg = ""
119+
120+
# Extract system message if present
105121
if messages and messages[0]["role"] == "system":
106122
system_msg = messages.pop(0)["content"]
123+
107124
for msg in messages:
108125
content = msg["content"]
109126
if isinstance(content, str):
@@ -113,9 +130,12 @@ def format_messages_for_api(model, messages):
113130
combined_content = []
114131
for block in content:
115132
if isinstance(block, str):
116-
combined_content.append({"type": "text", "text": block})
133+
combined_content.append({
134+
"type": "text",
135+
"text": block
136+
})
117137
elif isinstance(block, Image.Image):
118-
# For Claude, convert PIL.Image to base64
138+
# Convert PIL.Image to base64
119139
buffered = io.BytesIO()
120140
block.save(buffered, format="PNG")
121141
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -145,9 +165,12 @@ def format_messages_for_api(model, messages):
145165
"data": block["image_base64"]["data"]
146166
}
147167
})
148-
formatted.append({"role": msg["role"], "content": combined_content})
168+
formatted.append({
169+
"role": msg["role"],
170+
"content": combined_content
171+
})
149172
return system_msg, formatted
150-
173+
151174
elif ("gemini" in model or "gpt" in model or "grok" in model) and "openai" not in model:
152175
formatted = []
153176
for msg in messages:
@@ -160,40 +183,75 @@ def format_messages_for_api(model, messages):
160183
if isinstance(block, str):
161184
parts.append(block)
162185
elif isinstance(block, Image.Image):
186+
# Keep PIL.Image objects as is for Gemini
163187
parts.append(block)
164188
elif isinstance(block, dict):
165189
if block.get("type") == "image_url":
166-
parts.append({"type": "image_url", "image_url": {"url": block["image_url"]["url"]}})
190+
parts.append({
191+
"type": "image_url",
192+
"image_url": {
193+
"url": block["image_url"]["url"]
194+
}
195+
})
167196
elif block.get("type") == "image_base64":
168-
parts.append({"type": "image_base64", "image_base64": {"data": block["image_base64"]["data"], "media_type": block["image_base64"]["media_type"]}})
169-
formatted.append({"role": msg["role"], "parts": parts})
197+
parts.append({
198+
"type": "image_base64",
199+
"image_base64": {
200+
"data": block["image_base64"]["data"],
201+
"media_type": block["image_base64"]["media_type"]
202+
}
203+
})
204+
formatted.append({
205+
"role": msg["role"],
206+
"parts": parts
207+
})
170208
return None, formatted
171-
209+
172210
else: # OpenAI models
173211
formatted = []
174212
for msg in messages:
175213
content = msg["content"]
176214
if isinstance(content, str):
177-
formatted.append({"role": msg["role"], "content": content})
215+
formatted.append({
216+
"role": msg["role"],
217+
"content": content
218+
})
178219
elif isinstance(content, list):
179-
# OpenAI expects 'content' as string; images are not directly supported
180-
# You can convert images to URLs or descriptions if needed
181-
combined_content = ""
220+
formatted_content = []
182221
for block in content:
183222
if isinstance(block, str):
184-
combined_content += block + "\n"
223+
formatted_content.append({
224+
"type": "text",
225+
"text": block
226+
})
185227
elif isinstance(block, Image.Image):
186-
# Convert PIL.Image to base64 or upload and use URL
228+
# Convert PIL.Image to base64
187229
buffered = io.BytesIO()
188230
block.save(buffered, format="PNG")
189231
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
190-
combined_content += f"[Image Base64: {image_base64[:30]}...]\n"
232+
formatted_content.append({
233+
"type": "image_url",
234+
"image_url": {
235+
"url": f"data:image/jpeg;base64,{image_base64}"
236+
}
237+
})
191238
elif isinstance(block, dict):
192239
if block.get("type") == "image_url":
193-
combined_content += f"[Image: {block['image_url']['url']}]\n"
240+
formatted_content.append({
241+
"type": "image_url",
242+
"image_url": block["image_url"]
243+
})
194244
elif block.get("type") == "image_base64":
195-
combined_content += f"[Image Base64: {block['image_base64']['data'][:30]}...]\n"
196-
formatted.append({"role": msg["role"], "content": combined_content.strip()})
245+
formatted_content.append({
246+
"type": "image_url",
247+
"image_url": {
248+
"url": f"data:image/jpeg;base64,{block['image_base64']['data']}"
249+
}
250+
})
251+
formatted.append({
252+
"role": msg["role"],
253+
"content": formatted_content
254+
})
197255
return None, formatted
198256

199257
system_msg, formatted_messages = format_messages_for_api(model, messages.copy())
@@ -546,7 +604,7 @@ def add_repo(self, repo_url: Optional[str] = None, username: Optional[str] = Non
546604
if __name__ == "__main__":
547605
# Example Usage
548606
# Create an Agent instance (Gemini model)
549-
agent = Agent("gemini-1.5-flash", "you are Jack101", memory_enabled=True)
607+
agent = Agent("gemini-1.5-flash-openai", "you are Jack101", memory_enabled=True)
550608

551609
# Add an image
552610
agent.add_image(image_path="example.png")

0 commit comments

Comments
 (0)