57 changes: 49 additions & 8 deletions src/google/adk/models/gemini_llm_connection.py
@@ -139,14 +139,19 @@ async def send_realtime(self, input: RealtimeInput):
else:
raise ValueError('Unsupported input type: %s' % type(input))

def __build_full_text_response(self, text: str):
def __build_full_text_response(
self,
text: str,
grounding_metadata: types.GroundingMetadata | None = None,
):
"""Builds a full text response.

The text should not be partial, and the returned LlmResponse will not
be partial.

Args:
text: The text to be included in the response.
grounding_metadata: Optional grounding metadata to include.

Returns:
An LlmResponse containing the full text.
@@ -156,6 +161,7 @@ def __build_full_text_response(self, text: str):
role='model',
parts=[types.Part.from_text(text=text)],
),
grounding_metadata=grounding_metadata,
)

async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -166,6 +172,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""

text = ''
last_grounding_metadata = None
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
@@ -179,17 +186,38 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
if message.server_content:
content = message.server_content.model_turn
# Extract grounding_metadata from server_content (for VertexAiSearchTool, etc.)
grounding_metadata = message.server_content.grounding_metadata
if grounding_metadata:
last_grounding_metadata = grounding_metadata
Comment on lines +194 to +195
Contributor

high

The PR description mentions that last_grounding_metadata is used to 'accumulate grounding across messages'. However, the current implementation last_grounding_metadata = grounding_metadata overwrites the previous value rather than accumulating or merging it.

If the Live API can send grounding_metadata in parts across multiple messages within a single turn (e.g., one message with retrieval_queries and a subsequent one with grounding_chunks), this implementation would lead to data loss.

If accumulation is the intended behavior, the metadata from different messages should be merged. For example:

current_grounding_metadata = message.server_content.grounding_metadata
if current_grounding_metadata:
  if not last_grounding_metadata:
    last_grounding_metadata = current_grounding_metadata
  else:
    # Example merge logic; the actual implementation may vary based on
    # the GroundingMetadata structure and API contract. Note that the
    # fields always exist on the pydantic model and default to None, so
    # check for None rather than using hasattr().
    if current_grounding_metadata.retrieval_queries:
      if last_grounding_metadata.retrieval_queries is None:
        last_grounding_metadata.retrieval_queries = []
      last_grounding_metadata.retrieval_queries.extend(
          current_grounding_metadata.retrieval_queries
      )
    if current_grounding_metadata.grounding_chunks:
      if last_grounding_metadata.grounding_chunks is None:
        last_grounding_metadata.grounding_chunks = []
      last_grounding_metadata.grounding_chunks.extend(
          current_grounding_metadata.grounding_chunks
      )
    # ... etc. for other fields

Please clarify if the API guarantees that grounding_metadata is always sent completely in a single message, or if it's sent incrementally. If it's incremental, this logic should be updated to properly merge the data. If not, perhaps the PR description could be updated from 'accumulate' to 'persist' or 'carry over' to avoid confusion.
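For what it's worth, a self-contained merge helper along these lines could keep the receive loop readable. This is a sketch only: the helper name, the merge-by-concatenation policy, and the exact field list are assumptions, not part of the existing API; the field names are taken from google.genai.types.GroundingMetadata.

def merge_grounding_metadata(
    base: types.GroundingMetadata | None,
    new: types.GroundingMetadata | None,
) -> types.GroundingMetadata | None:
  """Concatenates the list fields of two GroundingMetadata objects.

  Sketch under the assumption that the list fields are pydantic model
  fields defaulting to None; merge-by-concatenation is an assumed
  policy, not a documented API contract.
  """
  if new is None:
    return base
  if base is None:
    return new
  merged = base.model_copy(deep=True)
  for field in (
      'retrieval_queries',
      'grounding_chunks',
      'grounding_supports',
      'web_search_queries',
  ):
    new_items = getattr(new, field)
    if new_items:
      existing = getattr(merged, field) or []
      setattr(merged, field, list(existing) + list(new_items))
  return merged

With such a helper, the loop body would reduce to last_grounding_metadata = merge_grounding_metadata(last_grounding_metadata, message.server_content.grounding_metadata).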

# Warn if grounding_metadata is incomplete (has queries but no chunks)
# This helps identify backend issues with Vertex AI Search
if (
grounding_metadata.retrieval_queries
and not grounding_metadata.grounding_chunks
):
logger.warning(
'Incomplete grounding_metadata received: retrieval_queries=%s'
' but grounding_chunks is empty. This may indicate a'
' transient issue with the Vertex AI Search backend.',
grounding_metadata.retrieval_queries,
)
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
content=content,
interrupted=message.server_content.interrupted,
grounding_metadata=grounding_metadata,
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata
)
text = ''
last_grounding_metadata = None
yield llm_response
# Note: in some cases, tool_call may arrive before
# generation_complete, causing transcription to appear after
@@ -266,32 +294,45 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
self._output_transcription_text = ''
if message.server_content.turn_complete:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata
)
text = ''
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
grounding_metadata=last_grounding_metadata,
)
last_grounding_metadata = None # Reset after yielding
break
# In case of empty content or parts, we still surface it.
# If it's an interrupted message, we merge the previous partial
# text. Otherwise we don't merge, because content can be None when the
# model safety threshold is triggered.
if message.server_content.interrupted:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata
)
text = ''
else:
yield LlmResponse(interrupted=message.server_content.interrupted)
yield LlmResponse(
interrupted=message.server_content.interrupted,
grounding_metadata=last_grounding_metadata,
)
last_grounding_metadata = None # Reset after yielding
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(text, last_grounding_metadata)
text = ''
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
yield LlmResponse(
content=types.Content(role='model', parts=parts),
grounding_metadata=last_grounding_metadata,
)
if message.session_resumption_update:
logger.debug('Received session resumption message: %s', message)
yield (
217 changes: 217 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
@@ -202,6 +202,7 @@ async def test_receive_usage_metadata_and_server_content(
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.turn_complete = False
mock_server_content.grounding_metadata = None

mock_message = mock.AsyncMock()
mock_message.usage_metadata = usage_metadata
@@ -261,6 +262,7 @@ async def test_receive_transcript_finished_on_interrupt(
message1.server_content.output_transcription = None
message1.server_content.turn_complete = False
message1.server_content.generation_complete = False
message1.server_content.grounding_metadata = None
message1.tool_call = None
message1.session_resumption_update = None

@@ -275,6 +277,7 @@
)
message2.server_content.turn_complete = False
message2.server_content.generation_complete = False
message2.server_content.grounding_metadata = None
message2.tool_call = None
message2.session_resumption_update = None

@@ -287,6 +290,7 @@
message3.server_content.output_transcription = None
message3.server_content.turn_complete = False
message3.server_content.generation_complete = False
message3.server_content.grounding_metadata = None
message3.tool_call = None
message3.session_resumption_update = None

@@ -408,6 +412,7 @@ async def test_receive_transcript_finished_on_turn_complete(
message1.server_content.output_transcription = None
message1.server_content.turn_complete = False
message1.server_content.generation_complete = False
message1.server_content.grounding_metadata = None
message1.tool_call = None
message1.session_resumption_update = None

@@ -422,6 +427,7 @@
)
message2.server_content.turn_complete = False
message2.server_content.generation_complete = False
message2.server_content.grounding_metadata = None
message2.tool_call = None
message2.session_resumption_update = None

@@ -434,6 +440,7 @@
message3.server_content.output_transcription = None
message3.server_content.turn_complete = True
message3.server_content.generation_complete = False
message3.server_content.grounding_metadata = None
message3.tool_call = None
message3.session_resumption_update = None

@@ -774,3 +781,213 @@ async def test_send_history_filters_various_audio_mime_types(

# No content should be sent since the only part is audio
mock_gemini_session.send.assert_not_called()


@pytest.mark.asyncio
async def test_receive_extracts_grounding_metadata(
gemini_connection, mock_gemini_session
):
Comment on lines +787 to +789
Contributor

medium

While this test and the subsequent new tests are thorough, they introduce a significant amount of boilerplate for creating mock objects. This repetition across test_receive_extracts_grounding_metadata, test_receive_grounding_metadata_at_turn_complete, and others makes the test suite harder to read and maintain.

To improve maintainability, consider creating one or more helper functions to build the mock message and server_content objects. This would centralize the mock setup, reduce code duplication, and make the tests more focused on the specific behavior they are verifying.

For example:

def _create_mock_server_content(
    model_turn=None,
    grounding_metadata=None,
    interrupted=False,
    turn_complete=False,
    **kwargs
):
    return mock.Mock(
        model_turn=model_turn,
        grounding_metadata=grounding_metadata,
        interrupted=interrupted,
        turn_complete=turn_complete,
        input_transcription=None,
        output_transcription=None,
        generation_complete=False,
        **kwargs
    )

def _create_mock_message(server_content=None, tool_call=None, **kwargs):
    return mock.Mock(
        server_content=server_content,
        tool_call=tool_call,
        usage_metadata=None,
        session_resumption_update=None,
        **kwargs
    )

This would make the setup for each test much more concise.
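As a concrete illustration (hypothetical, since it assumes the helper signatures sketched above rather than existing test utilities), the mock setup in test_receive_extracts_grounding_metadata would then reduce to:

# Hypothetical rewrite using the suggested helpers; _create_mock_server_content
# and _create_mock_message are the proposed names, not existing utilities.
mock_server_content = _create_mock_server_content(
    model_turn=mock_content,
    grounding_metadata=mock_grounding_metadata,
    turn_complete=True,
)
mock_message = _create_mock_message(server_content=mock_server_content)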

"""Test that grounding_metadata is extracted from server_content and included in LlmResponse."""
mock_content = types.Content(
role='model', parts=[types.Part.from_text(text='response text')]
)
mock_grounding_metadata = types.GroundingMetadata(
retrieval_queries=['test query'],
web_search_queries=['web search query'],
)

mock_server_content = mock.Mock()
mock_server_content.model_turn = mock_content
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.turn_complete = True
mock_server_content.generation_complete = False
mock_server_content.grounding_metadata = mock_grounding_metadata

mock_message = mock.Mock()
mock_message.usage_metadata = None
mock_message.server_content = mock_server_content
mock_message.tool_call = None
mock_message.session_resumption_update = None

async def mock_receive_generator():
yield mock_message

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Should have at least 2 responses: content with grounding and turn_complete
assert len(responses) >= 2

# Find response with content
content_response = next((r for r in responses if r.content), None)
assert content_response is not None
assert content_response.grounding_metadata == mock_grounding_metadata
assert content_response.grounding_metadata.retrieval_queries == ['test query']
assert content_response.grounding_metadata.web_search_queries == [
'web search query'
]


@pytest.mark.asyncio
async def test_receive_grounding_metadata_at_turn_complete(
gemini_connection, mock_gemini_session
):
"""Test that grounding_metadata is included in turn_complete response if no text was built."""
mock_grounding_metadata = types.GroundingMetadata(
retrieval_queries=['test query'],
)

# First message with grounding but no content
mock_server_content1 = mock.Mock()
mock_server_content1.model_turn = None
mock_server_content1.interrupted = False
mock_server_content1.input_transcription = None
mock_server_content1.output_transcription = None
mock_server_content1.turn_complete = False
mock_server_content1.generation_complete = False
mock_server_content1.grounding_metadata = mock_grounding_metadata

message1 = mock.Mock()
message1.usage_metadata = None
message1.server_content = mock_server_content1
message1.tool_call = None
message1.session_resumption_update = None

# Second message with turn_complete
mock_server_content2 = mock.Mock()
mock_server_content2.model_turn = None
mock_server_content2.interrupted = False
mock_server_content2.input_transcription = None
mock_server_content2.output_transcription = None
mock_server_content2.turn_complete = True
mock_server_content2.generation_complete = False
mock_server_content2.grounding_metadata = None

message2 = mock.Mock()
message2.usage_metadata = None
message2.server_content = mock_server_content2
message2.tool_call = None
message2.session_resumption_update = None

async def mock_receive_generator():
yield message1
yield message2

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Find turn_complete response
turn_complete_response = next((r for r in responses if r.turn_complete), None)
assert turn_complete_response is not None
# The grounding_metadata should be carried over to turn_complete
assert turn_complete_response.grounding_metadata == mock_grounding_metadata


@pytest.mark.asyncio
async def test_receive_grounding_metadata_with_text_and_turn_complete(
gemini_connection, mock_gemini_session
):
"""Test that grounding_metadata is preserved when text content is followed by turn_complete."""
mock_content = types.Content(
role='model', parts=[types.Part.from_text(text='response text')]
)
mock_grounding_metadata = types.GroundingMetadata(
retrieval_queries=['test query'],
)

# Message with both content and grounding, followed by turn_complete
mock_server_content = mock.Mock()
mock_server_content.model_turn = mock_content
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.turn_complete = True
mock_server_content.generation_complete = False
mock_server_content.grounding_metadata = mock_grounding_metadata

mock_message = mock.Mock()
mock_message.usage_metadata = None
mock_message.server_content = mock_server_content
mock_message.tool_call = None
mock_message.session_resumption_update = None

async def mock_receive_generator():
yield mock_message

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Find content response with grounding
content_response = next((r for r in responses if r.content), None)
assert content_response is not None
assert content_response.grounding_metadata == mock_grounding_metadata

# Find turn_complete response - should also have grounding_metadata
turn_complete_response = next((r for r in responses if r.turn_complete), None)
assert turn_complete_response is not None
assert turn_complete_response.grounding_metadata == mock_grounding_metadata


@pytest.mark.asyncio
async def test_receive_grounding_metadata_with_tool_call(
gemini_connection, mock_gemini_session
):
"""Test that grounding_metadata is propagated with tool_call responses."""
mock_grounding_metadata = types.GroundingMetadata(
retrieval_queries=['test query'],
)

# First message with grounding metadata
mock_server_content1 = mock.Mock()
mock_server_content1.model_turn = None
mock_server_content1.interrupted = False
mock_server_content1.input_transcription = None
mock_server_content1.output_transcription = None
mock_server_content1.turn_complete = False
mock_server_content1.generation_complete = False
mock_server_content1.grounding_metadata = mock_grounding_metadata

message1 = mock.Mock()
message1.usage_metadata = None
message1.server_content = mock_server_content1
message1.tool_call = None
message1.session_resumption_update = None

# Second message with tool_call
mock_function_call = types.FunctionCall(
name='test_function', args={'param': 'value'}
)
mock_tool_call = mock.Mock()
mock_tool_call.function_calls = [mock_function_call]

message2 = mock.Mock()
message2.usage_metadata = None
message2.server_content = None
message2.tool_call = mock_tool_call
message2.session_resumption_update = None

async def mock_receive_generator():
yield message1
yield message2

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

# Find tool_call response
tool_call_response = next(
(r for r in responses if r.content and r.content.parts[0].function_call),
None,
)
assert tool_call_response is not None
# The grounding_metadata should be carried over to tool_call
assert tool_call_response.grounding_metadata == mock_grounding_metadata