src/google/adk/models/gemini_llm_connection.py (65 changes: 57 additions & 8 deletions)
@@ -139,14 +139,21 @@ async def send_realtime(self, input: RealtimeInput):
else:
raise ValueError('Unsupported input type: %s' % type(input))

def __build_full_text_response(self, text: str):
def __build_full_text_response(
self,
text: str,
grounding_metadata: types.GroundingMetadata | None = None,
interrupted: bool = False,
):
"""Builds a full text response.

The text should not be partial and the returned LlmResponse is not
partial.

Args:
text: The text to be included in the response.
grounding_metadata: Optional grounding metadata to include.
interrupted: Whether this response was interrupted.

Returns:
An LlmResponse containing the full text.
@@ -156,6 +163,8 @@ def __build_full_text_response(self, text: str):
role='model',
parts=[types.Part.from_text(text=text)],
),
grounding_metadata=grounding_metadata,
interrupted=interrupted,
)

async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -166,6 +175,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""

text = ''
last_grounding_metadata = None
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
@@ -179,17 +189,38 @@
)
if message.server_content:
content = message.server_content.model_turn
# Extract grounding_metadata from server_content (for VertexAiSearchTool, etc.)
grounding_metadata = message.server_content.grounding_metadata
if grounding_metadata:
last_grounding_metadata = grounding_metadata
Comment on lines +194 to +195
Contributor (severity: high)

The PR description mentions that last_grounding_metadata is used to 'accumulate grounding across messages'. However, the current implementation last_grounding_metadata = grounding_metadata overwrites the previous value rather than accumulating or merging it.

If the Live API can send grounding_metadata in parts across multiple messages within a single turn (e.g., one message with retrieval_queries and a subsequent one with grounding_chunks), this implementation would lead to data loss.

If accumulation is the intended behavior, the metadata from different messages should be merged. For example:

current_grounding_metadata = message.server_content.grounding_metadata
if current_grounding_metadata:
  if not last_grounding_metadata:
    last_grounding_metadata = current_grounding_metadata
  else:
    # Example merge logic; the actual implementation may vary based on
    # the GroundingMetadata structure and API contract. GroundingMetadata
    # is a pydantic model, so its list fields always exist as attributes
    # but default to None, which is why we check for None rather than
    # using hasattr.
    if current_grounding_metadata.retrieval_queries:
      if last_grounding_metadata.retrieval_queries is None:
        last_grounding_metadata.retrieval_queries = []
      last_grounding_metadata.retrieval_queries.extend(
          current_grounding_metadata.retrieval_queries
      )
    if current_grounding_metadata.grounding_chunks:
      if last_grounding_metadata.grounding_chunks is None:
        last_grounding_metadata.grounding_chunks = []
      last_grounding_metadata.grounding_chunks.extend(
          current_grounding_metadata.grounding_chunks
      )
    # ... etc. for other fields

Please clarify whether the API guarantees that grounding_metadata is always sent completely in a single message, or whether it is sent incrementally. If it is incremental, this logic should be updated to properly merge the data; if not, perhaps the PR description could be updated from 'accumulate' to 'persist' or 'carry over' to avoid confusion.
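
If the API does deliver grounding incrementally, a field-generic helper would keep the merge logic in one place. Below is a minimal sketch, assuming GroundingMetadata is a pydantic model (as in google.genai.types) whose list fields default to None; the set of merged field names is an assumption to verify against the installed SDK:

from google.genai import types


def merge_grounding_metadata(
    base: types.GroundingMetadata | None,
    new: types.GroundingMetadata | None,
) -> types.GroundingMetadata | None:
  """Merges the list fields of two GroundingMetadata objects."""
  if new is None:
    return base
  if base is None:
    return new.model_copy(deep=True)
  merged = base.model_copy(deep=True)
  # Assumed field names; verify against the installed google-genai version.
  for field in ('retrieval_queries', 'grounding_chunks', 'grounding_supports'):
    incoming = getattr(new, field, None)
    if incoming:
      existing = getattr(merged, field, None) or []
      setattr(merged, field, list(existing) + list(incoming))
  return merged

The receive() loop would then use last_grounding_metadata = merge_grounding_metadata(last_grounding_metadata, current_grounding_metadata) in place of the plain assignment.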

# Warn if grounding_metadata is incomplete (has queries but no chunks)
# This helps identify backend issues with Vertex AI Search
if (
grounding_metadata.retrieval_queries
and not grounding_metadata.grounding_chunks
):
logger.warning(
'Incomplete grounding_metadata received: retrieval_queries=%s'
' but grounding_chunks is empty. This may indicate a'
' transient issue with the Vertex AI Search backend.',
grounding_metadata.retrieval_queries,
)
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
content=content,
interrupted=message.server_content.interrupted,
grounding_metadata=grounding_metadata,
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata
)
text = ''

yield llm_response
# Note: in some cases, tool_call may arrive before
# generation_complete, causing transcription to appear after
@@ -266,32 +297,50 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
self._output_transcription_text = ''
if message.server_content.turn_complete:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text,
last_grounding_metadata,
interrupted=message.server_content.interrupted,
)
text = ''
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
grounding_metadata=last_grounding_metadata,
)
last_grounding_metadata = None # Reset after yielding
break
# In case of empty content or parts, we still surface it.
# If it's an interrupted message, we merge the previous partial text;
# otherwise we don't merge, because content can be None when the model
# safety threshold is triggered.
if message.server_content.interrupted:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata, interrupted=True
)
text = ''
else:
yield LlmResponse(interrupted=message.server_content.interrupted)
yield LlmResponse(
interrupted=message.server_content.interrupted,
grounding_metadata=last_grounding_metadata,
)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(text, last_grounding_metadata)
text = ''
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
yield LlmResponse(
content=types.Content(role='model', parts=parts),
grounding_metadata=last_grounding_metadata,
)
# Note: last_grounding_metadata is NOT reset here because tool_call
# is part of an ongoing turn. The metadata persists until turn_complete
# (or an interruption that ends the turn), ensuring subsequent messages
# in the same turn can access the grounding information.
if message.session_resumption_update:
logger.debug('Received session resumption message: %s', message)
yield (
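
For illustration, here is how a downstream consumer of receive() would observe the carried-over grounding metadata. A minimal sketch, assuming connection is an already-connected GeminiLlmConnection instance (the setup is omitted and hypothetical):

async def log_grounding(connection) -> None:
  # Hypothetical consumer; `connection` is an already-connected
  # GeminiLlmConnection instance.
  async for response in connection.receive():
    metadata = response.grounding_metadata
    if metadata and metadata.retrieval_queries:
      # With this change, grounding metadata persists across partial text,
      # tool_call, and turn_complete responses within the same turn.
      print('retrieval queries:', metadata.retrieval_queries)
    if response.turn_complete:
      break  # last_grounding_metadata is reset after this response.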