-
Notifications
You must be signed in to change notification settings - Fork 179
contrib/google_adk_agents: stream LlmResponse chunks via Workflow Streams #1498
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jssmith
wants to merge
4
commits into
main
Choose a base branch
from
contrib/google-adk-streaming
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+311
−9
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
d817542
contrib: google_adk_agents streaming integration
jssmith 89eae56
Update tests/contrib/google_adk_agents/test_adk_streaming.py
brianstrauch 6f2fa04
Update tests/contrib/google_adk_agents/test_adk_streaming.py
brianstrauch 347b74c
rename params
brianstrauch File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,195 @@ | ||
| """Integration tests for ADK streaming support. | ||
|
|
||
| Verifies that the streaming model activity publishes raw ``LlmResponse`` | ||
| chunks via the WorkflowStream broker. Non-streaming behavior is covered | ||
| by ``test_google_adk_agents.py``. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import uuid | ||
| from collections.abc import AsyncGenerator | ||
| from datetime import timedelta | ||
|
|
||
| import pytest | ||
| from google.adk import Agent | ||
| from google.adk.agents.run_config import RunConfig, StreamingMode | ||
| from google.adk.models import BaseLlm, LLMRegistry | ||
| from google.adk.models.llm_request import LlmRequest | ||
| from google.adk.models.llm_response import LlmResponse | ||
| from google.adk.runners import InMemoryRunner | ||
| from google.genai.types import Content, Part | ||
|
|
||
| from temporalio import workflow | ||
| from temporalio.client import Client, WorkflowFailureError | ||
| from temporalio.contrib.google_adk_agents import GoogleAdkPlugin, TemporalModel | ||
| from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient | ||
| from temporalio.worker import Worker | ||
|
|
||
|
|
||
class StreamingTestModel(BaseLlm):
    """Fake LLM that emits two partial responses, mimicking an SSE stream."""

    @classmethod
    def supported_models(cls) -> list[str]:
        """Model identifiers this fake serves in the LLMRegistry."""
        return ["streaming_test_model"]

    async def generate_content_async(
        self, llm_request: LlmRequest, stream: bool = False
    ) -> AsyncGenerator[LlmResponse, None]:
        """Yield two text chunks; refuse to run unless streaming was requested.

        The streaming activity is expected to call with ``stream=True``;
        failing loudly here catches any regression that drops the flag.
        """
        if not stream:
            raise AssertionError(
                "StreamingTestModel.generate_content_async requires stream=True"
            )
        for chunk_text in ("Hello ", "world!"):
            yield LlmResponse(
                content=Content(role="model", parts=[Part(text=chunk_text)])
            )
|
|
||
|
|
||
| @workflow.defn | ||
| class StreamingAdkWorkflow: | ||
| """Test workflow that opts into streaming via RunConfig.streaming_mode.""" | ||
|
|
||
| @workflow.init | ||
| def __init__(self, prompt: str) -> None: | ||
| self.stream = WorkflowStream() | ||
|
|
||
| @workflow.run | ||
| async def run(self, prompt: str) -> str: | ||
| model = TemporalModel("streaming_test_model", streaming_topic="events") | ||
| agent = Agent( | ||
| name="test_agent", | ||
| model=model, | ||
| instruction="You are a test agent.", | ||
| ) | ||
|
|
||
| runner = InMemoryRunner(agent=agent, app_name="test-app") | ||
| session = await runner.session_service.create_session( | ||
| app_name="test-app", user_id="test" | ||
| ) | ||
|
|
||
| final_text = "" | ||
| async for event in runner.run_async( | ||
| user_id="test", | ||
| session_id=session.id, | ||
| new_message=Content(role="user", parts=[Part(text=prompt)]), | ||
| run_config=RunConfig(streaming_mode=StreamingMode.SSE), | ||
| ): | ||
| if event.content and event.content.parts: | ||
| for part in event.content.parts: | ||
| if part.text: | ||
| final_text = part.text | ||
|
|
||
| return final_text | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_streaming_publishes_events(client: Client): | ||
| """Streaming activity publishes raw LlmResponse chunks to the topic.""" | ||
| LLMRegistry.register(StreamingTestModel) | ||
|
|
||
| new_config = client.config() | ||
| new_config["plugins"] = [GoogleAdkPlugin()] | ||
| client = Client(**new_config) | ||
|
|
||
| workflow_id = f"adk-streaming-test-{uuid.uuid4()}" | ||
|
|
||
| async with Worker( | ||
| client, | ||
| task_queue="adk-streaming-test", | ||
| workflows=[StreamingAdkWorkflow], | ||
| max_cached_workflows=0, | ||
| ): | ||
| handle = await client.start_workflow( | ||
| StreamingAdkWorkflow.run, | ||
| "Hello", | ||
| id=workflow_id, | ||
| task_queue="adk-streaming-test", | ||
| execution_timeout=timedelta(seconds=30), | ||
| ) | ||
|
|
||
| stream = WorkflowStreamClient.create(client, workflow_id) | ||
| responses: list[LlmResponse] = [] | ||
|
|
||
| async def collect_events() -> None: | ||
| async for item in stream.subscribe( | ||
| ["events"], | ||
| from_offset=0, | ||
| result_type=LlmResponse, | ||
| poll_cooldown=timedelta(milliseconds=50), | ||
| ): | ||
| responses.append(item.data) | ||
| if len(responses) >= 2: | ||
| break | ||
|
|
||
| collect_task = asyncio.create_task(collect_events()) | ||
| result = await handle.result() | ||
| await asyncio.wait_for(collect_task, timeout=10.0) | ||
|
|
||
| # Workflow assembles streamed parts; the last part it observes is "world!". | ||
| assert result == "world!" | ||
|
|
||
| texts: list[str] = [] | ||
| for r in responses: | ||
| if r.content and r.content.parts: | ||
| for part in r.content.parts: | ||
| if part.text: | ||
| texts.append(part.text) | ||
| assert texts == ["Hello ", "world!"], f"Unexpected text deltas: {texts}" | ||
|
|
||
|
|
||
| @workflow.defn | ||
| class StreamingAdkRequiresTopicWorkflow: | ||
| """Calls ``generate_content_async(stream=True)`` without configuring | ||
| ``streaming_topic``; the call must raise before any activity | ||
| is scheduled.""" | ||
|
|
||
| @workflow.run | ||
| async def run(self, prompt: str) -> str: | ||
| model = TemporalModel("streaming_test_model") | ||
| agent = Agent( | ||
| name="test_agent", | ||
| model=model, | ||
| instruction="You are a test agent.", | ||
| ) | ||
| runner = InMemoryRunner(agent=agent, app_name="test-app") | ||
| session = await runner.session_service.create_session( | ||
| app_name="test-app", user_id="test" | ||
| ) | ||
| async for _ in runner.run_async( | ||
| user_id="test", | ||
| session_id=session.id, | ||
| new_message=Content(role="user", parts=[Part(text=prompt)]), | ||
| run_config=RunConfig(streaming_mode=StreamingMode.SSE), | ||
| ): | ||
| pass | ||
| return "should not reach" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_streaming_requires_topic(client: Client): | ||
| """``stream=True`` fails fast when no streaming topic was configured | ||
| on ``TemporalModel``. The error is raised in the workflow before any | ||
| streaming activity is scheduled.""" | ||
| LLMRegistry.register(StreamingTestModel) | ||
|
|
||
| new_config = client.config() | ||
| new_config["plugins"] = [GoogleAdkPlugin()] | ||
| client = Client(**new_config) | ||
|
|
||
| async with Worker( | ||
| client, | ||
| task_queue="adk-streaming-requires-topic", | ||
| workflows=[StreamingAdkRequiresTopicWorkflow], | ||
| max_cached_workflows=0, | ||
| ): | ||
| with pytest.raises(WorkflowFailureError) as exc_info: | ||
| await client.execute_workflow( | ||
| StreamingAdkRequiresTopicWorkflow.run, | ||
| "Hi", | ||
| id=f"adk-streaming-requires-topic-{uuid.uuid4()}", | ||
| task_queue="adk-streaming-requires-topic", | ||
| execution_timeout=timedelta(seconds=30), | ||
| ) | ||
|
|
||
| assert "streaming_topic" in str(exc_info.value.cause) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.