-
Notifications
You must be signed in to change notification settings - Fork 2.8k
fix: Extract grounding_metadata from Live API server_content #4213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
30af2f1
377a632
b0bd310
494141d
2fabdb4
8e9fee6
dfe53f7
0cd44c8
27f3391
e944f08
6bc5c0f
e57fa04
e908a0e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -202,6 +202,7 @@ async def test_receive_usage_metadata_and_server_content( | |
| mock_server_content.input_transcription = None | ||
| mock_server_content.output_transcription = None | ||
| mock_server_content.turn_complete = False | ||
| mock_server_content.grounding_metadata = None | ||
|
|
||
| mock_message = mock.AsyncMock() | ||
| mock_message.usage_metadata = usage_metadata | ||
|
|
@@ -261,6 +262,7 @@ async def test_receive_transcript_finished_on_interrupt( | |
| message1.server_content.output_transcription = None | ||
| message1.server_content.turn_complete = False | ||
| message1.server_content.generation_complete = False | ||
| message1.server_content.grounding_metadata = None | ||
| message1.tool_call = None | ||
| message1.session_resumption_update = None | ||
|
|
||
|
|
@@ -275,6 +277,7 @@ async def test_receive_transcript_finished_on_interrupt( | |
| ) | ||
| message2.server_content.turn_complete = False | ||
| message2.server_content.generation_complete = False | ||
| message2.server_content.grounding_metadata = None | ||
| message2.tool_call = None | ||
| message2.session_resumption_update = None | ||
|
|
||
|
|
@@ -287,6 +290,7 @@ async def test_receive_transcript_finished_on_interrupt( | |
| message3.server_content.output_transcription = None | ||
| message3.server_content.turn_complete = False | ||
| message3.server_content.generation_complete = False | ||
| message3.server_content.grounding_metadata = None | ||
| message3.tool_call = None | ||
| message3.session_resumption_update = None | ||
|
|
||
|
|
@@ -408,6 +412,7 @@ async def test_receive_transcript_finished_on_turn_complete( | |
| message1.server_content.output_transcription = None | ||
| message1.server_content.turn_complete = False | ||
| message1.server_content.generation_complete = False | ||
| message1.server_content.grounding_metadata = None | ||
| message1.tool_call = None | ||
| message1.session_resumption_update = None | ||
|
|
||
|
|
@@ -422,6 +427,7 @@ async def test_receive_transcript_finished_on_turn_complete( | |
| ) | ||
| message2.server_content.turn_complete = False | ||
| message2.server_content.generation_complete = False | ||
| message2.server_content.grounding_metadata = None | ||
| message2.tool_call = None | ||
| message2.session_resumption_update = None | ||
|
|
||
|
|
@@ -434,6 +440,7 @@ async def test_receive_transcript_finished_on_turn_complete( | |
| message3.server_content.output_transcription = None | ||
| message3.server_content.turn_complete = True | ||
| message3.server_content.generation_complete = False | ||
| message3.server_content.grounding_metadata = None | ||
| message3.tool_call = None | ||
| message3.session_resumption_update = None | ||
|
|
||
|
|
@@ -774,3 +781,213 @@ async def test_send_history_filters_various_audio_mime_types( | |
|
|
||
| # No content should be sent since the only part is audio | ||
| mock_gemini_session.send.assert_not_called() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_receive_extracts_grounding_metadata( | ||
| gemini_connection, mock_gemini_session | ||
| ): | ||
|
Comment on lines
+787
to
+789
Contributor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this test and the subsequent new tests are thorough, they introduce a significant amount of boilerplate for creating mock objects. This repetition across the new tests makes the test file harder to read and maintain.
To improve maintainability, consider creating one or more helper functions to build the mock `server_content` and message objects. For example:
def _create_mock_server_content(
model_turn=None,
grounding_metadata=None,
interrupted=False,
turn_complete=False,
**kwargs
):
return mock.Mock(
model_turn=model_turn,
grounding_metadata=grounding_metadata,
interrupted=interrupted,
turn_complete=turn_complete,
input_transcription=None,
output_transcription=None,
generation_complete=False,
**kwargs
)
def _create_mock_message(server_content=None, tool_call=None, **kwargs):
return mock.Mock(
server_content=server_content,
tool_call=tool_call,
usage_metadata=None,
session_resumption_update=None,
**kwargs
)
This would make the setup for each test much more concise. |
||
| """Test that grounding_metadata is extracted from server_content and included in LlmResponse.""" | ||
| mock_content = types.Content( | ||
| role='model', parts=[types.Part.from_text(text='response text')] | ||
| ) | ||
| mock_grounding_metadata = types.GroundingMetadata( | ||
| retrieval_queries=['test query'], | ||
| web_search_queries=['web search query'], | ||
| ) | ||
|
|
||
| mock_server_content = mock.Mock() | ||
| mock_server_content.model_turn = mock_content | ||
| mock_server_content.interrupted = False | ||
| mock_server_content.input_transcription = None | ||
| mock_server_content.output_transcription = None | ||
| mock_server_content.turn_complete = True | ||
| mock_server_content.generation_complete = False | ||
| mock_server_content.grounding_metadata = mock_grounding_metadata | ||
|
|
||
| mock_message = mock.Mock() | ||
| mock_message.usage_metadata = None | ||
| mock_message.server_content = mock_server_content | ||
| mock_message.tool_call = None | ||
| mock_message.session_resumption_update = None | ||
|
|
||
| async def mock_receive_generator(): | ||
| yield mock_message | ||
|
|
||
| receive_mock = mock.Mock(return_value=mock_receive_generator()) | ||
| mock_gemini_session.receive = receive_mock | ||
|
|
||
| responses = [resp async for resp in gemini_connection.receive()] | ||
|
|
||
| # Should have at least 2 responses: content with grounding and turn_complete | ||
| assert len(responses) >= 2 | ||
|
|
||
| # Find response with content | ||
| content_response = next((r for r in responses if r.content), None) | ||
| assert content_response is not None | ||
| assert content_response.grounding_metadata == mock_grounding_metadata | ||
| assert content_response.grounding_metadata.retrieval_queries == ['test query'] | ||
| assert content_response.grounding_metadata.web_search_queries == [ | ||
| 'web search query' | ||
| ] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_receive_grounding_metadata_at_turn_complete( | ||
| gemini_connection, mock_gemini_session | ||
| ): | ||
| """Test that grounding_metadata is included in turn_complete response if no text was built.""" | ||
| mock_grounding_metadata = types.GroundingMetadata( | ||
| retrieval_queries=['test query'], | ||
| ) | ||
|
|
||
| # First message with grounding but no content | ||
| mock_server_content1 = mock.Mock() | ||
| mock_server_content1.model_turn = None | ||
| mock_server_content1.interrupted = False | ||
| mock_server_content1.input_transcription = None | ||
| mock_server_content1.output_transcription = None | ||
| mock_server_content1.turn_complete = False | ||
| mock_server_content1.generation_complete = False | ||
| mock_server_content1.grounding_metadata = mock_grounding_metadata | ||
|
|
||
| message1 = mock.Mock() | ||
| message1.usage_metadata = None | ||
| message1.server_content = mock_server_content1 | ||
| message1.tool_call = None | ||
| message1.session_resumption_update = None | ||
|
|
||
| # Second message with turn_complete | ||
| mock_server_content2 = mock.Mock() | ||
| mock_server_content2.model_turn = None | ||
| mock_server_content2.interrupted = False | ||
| mock_server_content2.input_transcription = None | ||
| mock_server_content2.output_transcription = None | ||
| mock_server_content2.turn_complete = True | ||
| mock_server_content2.generation_complete = False | ||
| mock_server_content2.grounding_metadata = None | ||
|
|
||
| message2 = mock.Mock() | ||
| message2.usage_metadata = None | ||
| message2.server_content = mock_server_content2 | ||
| message2.tool_call = None | ||
| message2.session_resumption_update = None | ||
|
|
||
| async def mock_receive_generator(): | ||
| yield message1 | ||
| yield message2 | ||
|
|
||
| receive_mock = mock.Mock(return_value=mock_receive_generator()) | ||
| mock_gemini_session.receive = receive_mock | ||
|
|
||
| responses = [resp async for resp in gemini_connection.receive()] | ||
|
|
||
| # Find turn_complete response | ||
| turn_complete_response = next((r for r in responses if r.turn_complete), None) | ||
| assert turn_complete_response is not None | ||
| # The grounding_metadata should be carried over to turn_complete | ||
| assert turn_complete_response.grounding_metadata == mock_grounding_metadata | ||
VedantMadane marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_receive_grounding_metadata_with_text_and_turn_complete( | ||
| gemini_connection, mock_gemini_session | ||
| ): | ||
| """Test that grounding_metadata is preserved when text content is followed by turn_complete.""" | ||
| mock_content = types.Content( | ||
| role='model', parts=[types.Part.from_text(text='response text')] | ||
| ) | ||
| mock_grounding_metadata = types.GroundingMetadata( | ||
| retrieval_queries=['test query'], | ||
| ) | ||
|
|
||
| # Message with both content and grounding, followed by turn_complete | ||
| mock_server_content = mock.Mock() | ||
| mock_server_content.model_turn = mock_content | ||
| mock_server_content.interrupted = False | ||
| mock_server_content.input_transcription = None | ||
| mock_server_content.output_transcription = None | ||
| mock_server_content.turn_complete = True | ||
| mock_server_content.generation_complete = False | ||
| mock_server_content.grounding_metadata = mock_grounding_metadata | ||
|
|
||
| mock_message = mock.Mock() | ||
| mock_message.usage_metadata = None | ||
| mock_message.server_content = mock_server_content | ||
| mock_message.tool_call = None | ||
| mock_message.session_resumption_update = None | ||
|
|
||
| async def mock_receive_generator(): | ||
| yield mock_message | ||
|
|
||
| receive_mock = mock.Mock(return_value=mock_receive_generator()) | ||
| mock_gemini_session.receive = receive_mock | ||
|
|
||
| responses = [resp async for resp in gemini_connection.receive()] | ||
|
|
||
| # Find content response with grounding | ||
| content_response = next((r for r in responses if r.content), None) | ||
| assert content_response is not None | ||
| assert content_response.grounding_metadata == mock_grounding_metadata | ||
|
|
||
| # Find turn_complete response - should also have grounding_metadata | ||
| turn_complete_response = next((r for r in responses if r.turn_complete), None) | ||
| assert turn_complete_response is not None | ||
| assert turn_complete_response.grounding_metadata == mock_grounding_metadata | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_receive_grounding_metadata_with_tool_call( | ||
| gemini_connection, mock_gemini_session | ||
| ): | ||
| """Test that grounding_metadata is propagated with tool_call responses.""" | ||
| mock_grounding_metadata = types.GroundingMetadata( | ||
| retrieval_queries=['test query'], | ||
| ) | ||
|
|
||
| # First message with grounding metadata | ||
| mock_server_content1 = mock.Mock() | ||
| mock_server_content1.model_turn = None | ||
| mock_server_content1.interrupted = False | ||
| mock_server_content1.input_transcription = None | ||
| mock_server_content1.output_transcription = None | ||
| mock_server_content1.turn_complete = False | ||
| mock_server_content1.generation_complete = False | ||
| mock_server_content1.grounding_metadata = mock_grounding_metadata | ||
|
|
||
| message1 = mock.Mock() | ||
| message1.usage_metadata = None | ||
| message1.server_content = mock_server_content1 | ||
| message1.tool_call = None | ||
| message1.session_resumption_update = None | ||
|
|
||
| # Second message with tool_call | ||
| mock_function_call = types.FunctionCall( | ||
| name='test_function', args={'param': 'value'} | ||
| ) | ||
| mock_tool_call = mock.Mock() | ||
| mock_tool_call.function_calls = [mock_function_call] | ||
|
|
||
| message2 = mock.Mock() | ||
| message2.usage_metadata = None | ||
| message2.server_content = None | ||
| message2.tool_call = mock_tool_call | ||
| message2.session_resumption_update = None | ||
|
|
||
| async def mock_receive_generator(): | ||
| yield message1 | ||
| yield message2 | ||
|
|
||
| receive_mock = mock.Mock(return_value=mock_receive_generator()) | ||
| mock_gemini_session.receive = receive_mock | ||
|
|
||
| responses = [resp async for resp in gemini_connection.receive()] | ||
|
|
||
| # Find tool_call response | ||
| tool_call_response = next( | ||
| (r for r in responses if r.content and r.content.parts[0].function_call), | ||
| None, | ||
| ) | ||
| assert tool_call_response is not None | ||
| # The grounding_metadata should be carried over to tool_call | ||
| assert tool_call_response.grounding_metadata == mock_grounding_metadata | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR description mentions that `last_grounding_metadata` is used to 'accumulate grounding across messages'. However, the current implementation, `last_grounding_metadata = grounding_metadata`, overwrites the previous value rather than accumulating or merging it.
If the Live API can send `grounding_metadata` in parts across multiple messages within a single turn (e.g., one message with `retrieval_queries` and a subsequent one with `grounding_chunks`), this implementation would lead to data loss.
If accumulation is the intended behavior, the metadata from different messages should be merged. For example, list-valued fields such as `retrieval_queries` and `grounding_chunks` from successive messages could be concatenated onto the stored value instead of replacing it.
Please clarify whether the API guarantees that `grounding_metadata` is always sent completely in a single message, or whether it is sent incrementally. If it is incremental, this logic should be updated to properly merge the data. If not, perhaps the PR description could be updated from 'accumulate' to 'persist' or 'carry over' to avoid confusion.