Skip to content

Commit fdc98d5

Browse files
GWeale and copybara-github
authored and committed
fix: Convert unsupported inline artifact MIME types to text in LoadArtifactsTool
The LoadArtifactsTool now checks if an artifact's inline data MIME type is supported by Gemini. If not, it attempts to convert the artifact content into a text Part. Close #4028. Co-authored-by: George Weale <gweale@google.com>. PiperOrigin-RevId: 856404510
1 parent 7b035aa commit fdc98d5

File tree

2 files changed

+265
-2
lines changed

2 files changed

+265
-2
lines changed

src/google/adk/tools/load_artifacts_tool.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
import base64
18+
import binascii
1719
import json
1820
import logging
1921
from typing import Any
@@ -24,13 +26,99 @@
2426

2527
from .base_tool import BaseTool
2628

29+
# Inline-data MIME types that Gemini accepts in requests: every subtype under
# the prefixes below, plus the exact types in the frozenset that follows.
_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES = ('image/', 'audio/', 'video/')
_GEMINI_SUPPORTED_INLINE_MIME_TYPES = frozenset(['application/pdf'])

# MIME types outside `text/*` that are nevertheless handled as text.
_TEXT_LIKE_MIME_TYPES = frozenset([
    'application/csv',
    'application/json',
    'application/xml',
])
41+
2742
if TYPE_CHECKING:
2843
from ..models.llm_request import LlmRequest
2944
from .tool_context import ToolContext
3045

3146
logger = logging.getLogger('google_adk.' + __name__)
3247

3348

49+
def _normalize_mime_type(mime_type: str | None) -> str | None:
  """Returns the bare, lowercased MIME type, or None if there is none.

  Parameters such as `; charset=utf-8` are dropped and surrounding whitespace
  is stripped. The result is lowercased because MIME type and subtype names
  are case-insensitive, so `Text/CSV` and `text/csv` must compare equal in
  later prefix and membership checks.

  Args:
    mime_type: The raw MIME type string, possibly with parameters, or None.

  Returns:
    The normalized `type/subtype` string, or None when `mime_type` is None,
    empty, or has nothing but whitespace before the first `;`.
  """
  if not mime_type:
    return None
  normalized = mime_type.partition(';')[0].strip().lower()
  # Honor the `str | None` contract: never hand back an empty string.
  return normalized or None
54+
55+
56+
def _is_inline_mime_type_supported(mime_type: str | None) -> bool:
  """Reports whether Gemini accepts `mime_type` as inline request data.

  The type is normalized first (parameters removed). It is supported when it
  is one of the exact types in `_GEMINI_SUPPORTED_INLINE_MIME_TYPES` or starts
  with one of `_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES`.

  Args:
    mime_type: The raw MIME type of the inline data, or None.

  Returns:
    True if the type is supported; False otherwise, including when the type is
    missing or empty.
  """
  bare_type = _normalize_mime_type(mime_type)
  if not bare_type:
    return False
  if bare_type in _GEMINI_SUPPORTED_INLINE_MIME_TYPES:
    return True
  return bare_type.startswith(_GEMINI_SUPPORTED_INLINE_MIME_PREFIXES)
64+
65+
66+
def _maybe_base64_to_bytes(data: str) -> bytes | None:
  """Strictly decodes `data` as standard or URL-safe base64.

  Both attempts run with `validate=True`. A non-validating decode silently
  discards characters outside the base64 alphabet, which would turn ordinary
  text such as `'col1,col2'` into garbage bytes instead of reporting that the
  input is not base64.

  Args:
    data: The candidate base64 string.

  Returns:
    The decoded bytes, or None if `data` is valid in neither the standard
    (`+/`) nor the URL-safe (`-_`) alphabet.
  """
  try:
    return base64.b64decode(data, validate=True)
  except (binascii.Error, ValueError):
    pass  # Not standard base64; try the URL-safe alphabet next.
  try:
    # `altchars` maps `-_` onto `+/` while keeping validation strict, unlike
    # `base64.urlsafe_b64decode`, which never validates its input.
    return base64.b64decode(data, altchars=b'-_', validate=True)
  except (binascii.Error, ValueError):
    return None
75+
76+
77+
def _as_safe_part_for_llm(
    artifact: types.Part, artifact_name: str
) -> types.Part:
  """Converts an artifact into a Part that is safe to send to Gemini.

  Parts without inline data, or whose inline MIME type is supported, are
  returned unchanged (the same object). Anything else is replaced by a text
  Part:

  * no data: a placeholder naming the artifact and its MIME type;
  * string data that is not base64: the string itself;
  * `text/*` or text-like MIME types: the bytes decoded as UTF-8, with
    undecodable bytes replaced;
  * everything else: a placeholder with the artifact name, type and size.

  Args:
    artifact: The loaded artifact Part.
    artifact_name: The artifact's name, used in placeholder messages.

  Returns:
    The original Part, or a text Part standing in for it.
  """
  blob = artifact.inline_data
  if blob is None or _is_inline_mime_type_supported(blob.mime_type):
    return artifact

  mime_type = _normalize_mime_type(blob.mime_type)
  if not mime_type:
    mime_type = 'application/octet-stream'

  payload = blob.data
  if payload is None:
    return types.Part.from_text(
        text=(
            f'[Artifact: {artifact_name}, type: {mime_type}. '
            'No inline data was provided.]'
        )
    )

  if isinstance(payload, str):
    raw_text = payload
    payload = _maybe_base64_to_bytes(raw_text)
    if payload is None:
      # Not base64, so treat the string as the content itself.
      return types.Part.from_text(text=raw_text)

  is_textual = (
      mime_type.startswith('text/') or mime_type in _TEXT_LIKE_MIME_TYPES
  )
  if is_textual:
    # errors='replace' decodes valid UTF-8 exactly and substitutes U+FFFD for
    # undecodable bytes, so no separate strict attempt is needed.
    return types.Part.from_text(text=payload.decode('utf-8', errors='replace'))

  size_kb = len(payload) / 1024
  return types.Part.from_text(
      text=(
          f'[Binary artifact: {artifact_name}, '
          f'type: {mime_type}, size: {size_kb:.1f} KB. '
          'Content cannot be displayed inline.]'
      )
  )
120+
121+
34122
class LoadArtifactsTool(BaseTool):
35123
"""A tool that loads the artifacts and adds them to the session."""
36124

@@ -108,7 +196,8 @@ async def _append_artifacts_to_llm_request(
108196
if llm_request.contents and llm_request.contents[-1].parts:
109197
function_response = llm_request.contents[-1].parts[0].function_response
110198
if function_response and function_response.name == 'load_artifacts':
111-
artifact_names = function_response.response['artifact_names']
199+
response = function_response.response or {}
200+
artifact_names = response.get('artifact_names', [])
112201
for artifact_name in artifact_names:
113202
# Try session-scoped first (default behavior)
114203
artifact = await tool_context.load_artifact(artifact_name)
@@ -122,14 +211,26 @@ async def _append_artifacts_to_llm_request(
122211
if artifact is None:
123212
logger.warning('Artifact "%s" not found, skipping', artifact_name)
124213
continue
214+
215+
artifact_part = _as_safe_part_for_llm(artifact, artifact_name)
216+
if artifact_part is not artifact:
217+
mime_type = (
218+
artifact.inline_data.mime_type if artifact.inline_data else None
219+
)
220+
logger.debug(
221+
'Converted artifact "%s" (mime_type=%s) to text Part',
222+
artifact_name,
223+
mime_type,
224+
)
225+
125226
llm_request.contents.append(
126227
types.Content(
127228
role='user',
128229
parts=[
129230
types.Part.from_text(
130231
text=f'Artifact {artifact_name} is:'
131232
),
132-
artifact,
233+
artifact_part,
133234
],
134235
)
135236
)
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import base64
16+
17+
from google.adk.models.llm_request import LlmRequest
18+
from google.adk.tools.load_artifacts_tool import _maybe_base64_to_bytes
19+
from google.adk.tools.load_artifacts_tool import load_artifacts_tool
20+
from google.genai import types
21+
from pytest import mark
22+
23+
24+
class _StubToolContext:
  """In-memory stand-in for ToolContext used by the LoadArtifactsTool tests.

  Artifacts are served from the dict passed to the constructor.
  """

  def __init__(self, artifacts_by_name: dict[str, types.Part]):
    self._artifacts_by_name = artifacts_by_name

  async def list_artifacts(self) -> list[str]:
    """Returns the names of all stored artifacts."""
    return list(self._artifacts_by_name)

  async def load_artifact(self, name: str) -> types.Part | None:
    """Returns the artifact stored under `name`, or None if there is none."""
    try:
      return self._artifacts_by_name[name]
    except KeyError:
      return None
35+
36+
37+
@mark.asyncio
async def test_load_artifacts_converts_unsupported_mime_to_text():
  """A CSV artifact stored as raw bytes is appended as a text Part."""
  name = 'test.csv'
  csv_bytes = b'col1,col2\n1,2\n'
  csv_blob = types.Blob(data=csv_bytes, mime_type='application/csv')
  tool_context = _StubToolContext({name: types.Part(inline_data=csv_blob)})

  # The request ends with the `load_artifacts` function response that makes
  # the tool append the artifact contents.
  load_response = types.Part(
      function_response=types.FunctionResponse(
          name='load_artifacts',
          response={'artifact_names': [name]},
      )
  )
  llm_request = LlmRequest(
      contents=[types.Content(role='user', parts=[load_response])]
  )

  await load_artifacts_tool.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  appended = llm_request.contents[-1]
  assert appended.parts[0].text == f'Artifact {name} is:'
  artifact_part = appended.parts[1]
  assert artifact_part.inline_data is None
  assert artifact_part.text == csv_bytes.decode('utf-8')
73+
74+
75+
@mark.asyncio
async def test_load_artifacts_converts_base64_unsupported_mime_to_text():
  """A CSV artifact stored as a base64 string is appended as decoded text."""
  name = 'test.csv'
  csv_bytes = b'col1,col2\n1,2\n'
  encoded_csv = str(base64.b64encode(csv_bytes), 'ascii')
  csv_blob = types.Blob(data=encoded_csv, mime_type='application/csv')
  tool_context = _StubToolContext({name: types.Part(inline_data=csv_blob)})

  # The request ends with the `load_artifacts` function response that makes
  # the tool append the artifact contents.
  load_response = types.Part(
      function_response=types.FunctionResponse(
          name='load_artifacts',
          response={'artifact_names': [name]},
      )
  )
  llm_request = LlmRequest(
      contents=[types.Content(role='user', parts=[load_response])]
  )

  await load_artifacts_tool.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  artifact_part = llm_request.contents[-1].parts[1]
  assert artifact_part.inline_data is None
  assert artifact_part.text == csv_bytes.decode('utf-8')
109+
110+
111+
@mark.asyncio
async def test_load_artifacts_keeps_supported_mime_types():
  """A PDF artifact keeps its inline data and MIME type."""
  name = 'test.pdf'
  pdf_blob = types.Blob(data=b'%PDF-1.4', mime_type='application/pdf')
  tool_context = _StubToolContext({name: types.Part(inline_data=pdf_blob)})

  # The request ends with the `load_artifacts` function response that makes
  # the tool append the artifact contents.
  load_response = types.Part(
      function_response=types.FunctionResponse(
          name='load_artifacts',
          response={'artifact_names': [name]},
      )
  )
  llm_request = LlmRequest(
      contents=[types.Content(role='user', parts=[load_response])]
  )

  await load_artifacts_tool.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  artifact_part = llm_request.contents[-1].parts[1]
  assert artifact_part.inline_data is not None
  assert artifact_part.inline_data.mime_type == 'application/pdf'
143+
144+
145+
def test_maybe_base64_to_bytes_decodes_standard_base64():
  """A standard-alphabet base64 string round-trips to the original bytes."""
  payload = b'hello world'
  as_text = str(base64.b64encode(payload), 'ascii')
  assert _maybe_base64_to_bytes(as_text) == payload
150+
151+
152+
def test_maybe_base64_to_bytes_decodes_urlsafe_base64():
  """A URL-safe base64 string round-trips to the original bytes."""
  # These bytes encode to `+` and `/` in the standard alphabet, and therefore
  # to `-` and `_` in the URL-safe one.
  payload = b'\xfb\xff\xfe'
  as_text = str(base64.urlsafe_b64encode(payload), 'ascii')
  assert _maybe_base64_to_bytes(as_text) == payload
157+
158+
159+
def test_maybe_base64_to_bytes_returns_none_for_invalid():
  """A string that is not valid base64 yields None."""
  # A lone character is invalid: base64 input must have a length that is a
  # multiple of 4 once padded.
  result = _maybe_base64_to_bytes('x')
  assert result is None

0 commit comments

Comments
 (0)