Skip to content

Commit a04828d

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: Persist user input content to session in live mode
Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 859207592
1 parent 5d94146 commit a04828d

File tree

3 files changed

+111
-2
lines changed

3 files changed

+111
-2
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,25 @@ async def _send_to_model(
272272
await llm_connection.send_realtime(live_request.blob)
273273

274274
if live_request.content:
275+
content = live_request.content
276+
# Persist user text content to session (similar to non-live mode)
277+
# Skip function responses - they are already handled separately
278+
is_function_response = content.parts and any(
279+
part.function_response for part in content.parts
280+
)
281+
if not is_function_response:
282+
if not content.role:
283+
content.role = 'user'
284+
user_content_event = Event(
285+
id=Event.new_id(),
286+
invocation_id=invocation_context.invocation_id,
287+
author='user',
288+
content=content,
289+
)
290+
await invocation_context.session_service.append_event(
291+
session=invocation_context.session,
292+
event=user_content_event,
293+
)
275294
await llm_connection.send_content(live_request.content)
276295

277296
async def _receive_from_model(
@@ -391,8 +410,8 @@ async def _run_one_step_async(
391410
current_invocation=True, current_branch=True
392411
)
393412

394-
# Long-running tool calls should have been handled before this point.
395-
# If there are still long-running tool calls, it means the agent is paused
413+
# Long running tool calls should have been handled before this point.
414+
# If there are still long running tool calls, it means the agent is paused
396415
# before, and its branch hasn't been resumed yet.
397416
if (
398417
invocation_context.is_resumable

tests/unittests/streaming/test_streaming.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,89 @@ async def consume_responses(session: testing_utils.Session):
11201120
assert (
11211121
function_response_found
11221122
), 'Buffered function_response event was not yielded.'
1123+
1124+
1125+
def test_live_streaming_text_content_persisted_in_session():
1126+
"""Test that user text content sent via send_content is persisted in session."""
1127+
response1 = LlmResponse(
1128+
content=types.Content(
1129+
role='model', parts=[types.Part(text='Hello! How can I help you?')]
1130+
),
1131+
turn_complete=True,
1132+
)
1133+
1134+
mock_model = testing_utils.MockModel.create([response1])
1135+
1136+
root_agent = Agent(
1137+
name='root_agent',
1138+
model=mock_model,
1139+
tools=[],
1140+
)
1141+
1142+
class CustomTestRunner(testing_utils.InMemoryRunner):
1143+
1144+
def run_live_and_get_session(
1145+
self,
1146+
live_request_queue: LiveRequestQueue,
1147+
run_config: testing_utils.RunConfig = None,
1148+
) -> tuple[list[testing_utils.Event], testing_utils.Session]:
1149+
collected_responses = []
1150+
1151+
async def consume_responses(session: testing_utils.Session):
1152+
run_res = self.runner.run_live(
1153+
session=session,
1154+
live_request_queue=live_request_queue,
1155+
run_config=run_config or testing_utils.RunConfig(),
1156+
)
1157+
async for response in run_res:
1158+
collected_responses.append(response)
1159+
if len(collected_responses) >= 1:
1160+
return
1161+
1162+
try:
1163+
session = self.session
1164+
loop = asyncio.new_event_loop()
1165+
asyncio.set_event_loop(loop)
1166+
try:
1167+
loop.run_until_complete(
1168+
asyncio.wait_for(consume_responses(session), timeout=5.0)
1169+
)
1170+
finally:
1171+
loop.close()
1172+
except (asyncio.TimeoutError, asyncio.CancelledError):
1173+
pass
1174+
1175+
# Get the updated session
1176+
updated_session = self.runner.session_service.get_session_sync(
1177+
app_name=self.app_name,
1178+
user_id=session.user_id,
1179+
session_id=session.id,
1180+
)
1181+
return collected_responses, updated_session
1182+
1183+
runner = CustomTestRunner(root_agent=root_agent)
1184+
live_request_queue = LiveRequestQueue()
1185+
1186+
# Send text content (not audio blob)
1187+
user_text = 'Hello, this is a test message'
1188+
live_request_queue.send_content(
1189+
types.Content(role='user', parts=[types.Part(text=user_text)])
1190+
)
1191+
1192+
res_events, session = runner.run_live_and_get_session(live_request_queue)
1193+
1194+
assert res_events is not None, 'Expected a list of events, got None.'
1195+
1196+
# Check that user text content was persisted in the session
1197+
user_content_found = False
1198+
for event in session.events:
1199+
if event.author == 'user' and event.content:
1200+
for part in event.content.parts:
1201+
if part.text and user_text in part.text:
1202+
user_content_found = True
1203+
break
1204+
1205+
assert user_content_found, (
1206+
f'Expected user text content "{user_text}" to be persisted in session. '
1207+
f'Session events: {[e.content for e in session.events]}'
1208+
)

tests/unittests/testing_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ async def send_realtime(self, blob: types.Blob):
409409
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
410410
"""Yield each of the pre-defined LlmResponses."""
411411
for response in self.llm_responses:
412+
# Yield control to allow other tasks (like send_task) to run first.
413+
# This ensures user content gets persisted before the mock response
414+
# is yielded.
415+
await asyncio.sleep(0)
412416
yield response
413417

414418
async def close(self):

0 commit comments

Comments
 (0)