Skip to content

Commit 712b5a3

Browse files
seanzhougooglecopybara-github
authored and committed
fix: Only filter out audio content when sending history
Audio is already transcribed, so it does not need to be sent, but other blobs (e.g. images) should still be sent. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 856422986
1 parent 89bed43 commit 712b5a3

File tree

3 files changed

+224
-4
lines changed

3 files changed

+224
-4
lines changed

src/google/adk/models/gemini_llm_connection.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from google.genai import types
2222

23+
from ..utils.content_utils import filter_audio_parts
2324
from ..utils.context_utils import Aclosing
2425
from ..utils.variant_utils import GoogleLLMVariant
2526
from .base_llm_connection import BaseLlmConnection
@@ -63,15 +64,22 @@ async def send_history(self, history: list[types.Content]):
6364
# TODO: Remove this filter and translate unary contents to streaming
6465
# contents properly.
6566

66-
# We ignore any audio from user during the agent transfer phase
67+
# Filter out audio parts from history because:
68+
# 1. audio has already been transcribed.
69+
# 2. sending audio via connection.send or connection.send_live_content is
70+
# not supported by LIVE API (session will be corrupted).
71+
# This method is called when:
72+
# 1. Agent transfer to a new agent
73+
# 2. Establishing a new live connection with previous ADK session history
74+
6775
contents = [
68-
content
76+
filtered
6977
for content in history
70-
if content.parts and content.parts[0].text
78+
if (filtered := filter_audio_parts(content)) is not None
7179
]
72-
logger.debug('Sending history to live connection: %s', contents)
7380

7481
if contents:
82+
logger.debug('Sending history to live connection: %s', contents)
7583
await self._gemini_session.send(
7684
input=types.LiveClientContent(
7785
turns=contents,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from google.genai import types
18+
19+
20+
def is_audio_part(part: types.Part) -> bool:
  """Report whether a part carries audio data.

  A part counts as audio when either its ``inline_data`` or its ``file_data``
  is set and has a MIME type that starts with ``audio/``.

  Args:
    part: The content part to inspect.

  Returns:
    True if the part is inline or file-referenced audio, False otherwise.
    The result is always a real ``bool``; a bare ``and``/``or`` chain would
    leak ``None`` or an empty string to callers for non-audio parts.
  """
  for payload in (part.inline_data, part.file_data):
    # A payload may be absent, and a present payload may lack a MIME type.
    mime_type = payload.mime_type if payload else None
    if mime_type and mime_type.startswith('audio/'):
      return True
  return False
30+
31+
32+
def filter_audio_parts(content: types.Content) -> types.Content | None:
  """Build a version of ``content`` that has no audio parts.

  Args:
    content: The content whose parts should be screened.

  Returns:
    A new ``types.Content`` with the same role and only the parts for which
    ``is_audio_part`` is falsy, or None when ``content`` has no parts or
    every part is audio.
  """
  non_audio_parts = []
  for part in content.parts or ():
    if is_audio_part(part):
      continue
    non_audio_parts.append(part)

  if non_audio_parts:
    return types.Content(role=content.role, parts=non_audio_parts)
  # Nothing survived the filter (or there was nothing to begin with).
  return None

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,3 +600,177 @@ async def mock_receive_generator():
600600
assert responses[2].output_transcription.text == 'How can I help?'
601601
assert responses[2].output_transcription.finished is True
602602
assert responses[2].partial is False
603+
604+
605+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'audio_part',
    [
        types.Part(
            inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
        ),
        types.Part(
            file_data=types.FileData(
                file_uri='artifact://app/user/session/_adk_live/audio.pcm#1',
                mime_type='audio/pcm',
            )
        ),
    ],
)
async def test_send_history_filters_audio(mock_gemini_session, audio_part):
  """An audio-only turn (inline or file-backed) is dropped from history."""
  user_turn = types.Content(role='user', parts=[audio_part])
  model_turn = types.Content(
      role='model', parts=[types.Part.from_text(text='I heard you')]
  )
  live_connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await live_connection.send_history([user_turn, model_turn])

  mock_gemini_session.send.assert_called_once()
  sent_turns = mock_gemini_session.send.call_args.kwargs['input'].turns
  # The user's audio turn is gone; only the model reply is forwarded.
  assert [turn.role for turn in sent_turns] == ['model']
  assert sent_turns[0].parts == [types.Part.from_text(text='I heard you')]
644+
645+
646+
@pytest.mark.asyncio
async def test_send_history_keeps_image_data(mock_gemini_session):
  """A non-audio blob such as an image survives the history filter."""
  png_blob = types.Blob(data=b'\x89PNG\r\n', mime_type='image/png')
  image_turn = types.Content(
      role='user', parts=[types.Part(inline_data=png_blob)]
  )
  reply_turn = types.Content(
      role='model', parts=[types.Part.from_text(text='Nice image!')]
  )
  live_connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await live_connection.send_history([image_turn, reply_turn])

  mock_gemini_session.send.assert_called_once()
  sent_turns = mock_gemini_session.send.call_args.kwargs['input'].turns
  # Neither turn is filtered, and the image payload is passed through intact.
  assert len(sent_turns) == 2
  assert sent_turns[0].parts[0].inline_data == png_blob
671+
672+
673+
@pytest.mark.asyncio
async def test_send_history_mixed_content_filters_only_audio(
    mock_gemini_session,
):
  """In a turn mixing audio and text, only the audio part is stripped."""
  wav_part = types.Part(
      inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/wav')
  )
  text_part = types.Part.from_text(text='transcribed text')
  mixed_turn = types.Content(role='user', parts=[wav_part, text_part])
  live_connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await live_connection.send_history([mixed_turn])

  mock_gemini_session.send.assert_called_once()
  sent_turns = mock_gemini_session.send.call_args.kwargs['input'].turns
  # The turn is still sent, reduced to its single text part.
  assert len(sent_turns) == 1
  remaining_parts = sent_turns[0].parts
  assert len(remaining_parts) == 1
  assert remaining_parts[0].text == 'transcribed text'
704+
705+
706+
@pytest.mark.asyncio
async def test_send_history_all_audio_content_not_sent(mock_gemini_session):
  """A turn made up solely of audio parts results in nothing being sent."""
  inline_audio = types.Part(
      inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
  )
  file_audio = types.Part(
      file_data=types.FileData(
          file_uri='artifact://audio.pcm#1',
          mime_type='audio/wav',
      )
  )
  audio_only_turn = types.Content(
      role='user', parts=[inline_audio, file_audio]
  )
  live_connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await live_connection.send_history([audio_only_turn])

  # Every part was audio, so there is no history left to forward.
  mock_gemini_session.send.assert_not_called()
735+
736+
737+
@pytest.mark.asyncio
async def test_send_history_empty_history_not_sent(mock_gemini_session):
  """Sending an empty history list never reaches the session's send()."""
  no_history = []
  live_connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await live_connection.send_history(no_history)

  mock_gemini_session.send.assert_not_called()
747+
748+
749+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'audio_mime_type',
    ['audio/pcm', 'audio/wav', 'audio/mp3', 'audio/ogg'],
)
async def test_send_history_filters_various_audio_mime_types(
    mock_gemini_session,
    audio_mime_type,
):
  """Each parametrized ``audio/*`` MIME type is filtered out of history."""
  audio_blob = types.Blob(data=b'', mime_type=audio_mime_type)
  audio_turn = types.Content(
      role='user', parts=[types.Part(inline_data=audio_blob)]
  )
  live_connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await live_connection.send_history([audio_turn])

  # The lone part is audio, so the whole turn is dropped and nothing is sent.
  mock_gemini_session.send.assert_not_called()

0 commit comments

Comments
 (0)