Skip to content

Commit d817542

Browse files
jssmithclaude
authored and committed
contrib: google_adk_agents streaming integration
Re-applies the google_adk_agents streaming integration originally split out of PR #1423 on commit 59c7582. The bridge honors `stream=True` and publishes raw `LlmResponse` chunks through a typed topic handle. Opt in via the plugin's `streaming_event_topic` parameter. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 85dde16 commit d817542

3 files changed

Lines changed: 314 additions & 9 deletions

File tree

temporalio/contrib/google_adk_agents/_model.py

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import AsyncGenerator, Callable
2+
from dataclasses import dataclass
23
from datetime import timedelta
34

45
from google.adk.models import BaseLlm, LLMRegistry
@@ -7,6 +8,8 @@
78

89
import temporalio.workflow
910
from temporalio import activity, workflow
11+
from temporalio.contrib.workflow_streams import WorkflowStreamClient
12+
from temporalio.exceptions import ApplicationError
1013
from temporalio.workflow import ActivityConfig
1114

1215

@@ -36,6 +39,58 @@ async def invoke_model(llm_request: LlmRequest) -> list[LlmResponse]:
3639
]
3740

3841

@dataclass
class StreamingInvokeInput:
    """Input for :func:`invoke_model_streaming`."""

    # The ADK request forwarded verbatim to the underlying LLM.
    llm_request: LlmRequest
    # Stream topic that each raw ``LlmResponse`` chunk is published to.
    streaming_event_topic: str
    # Interval between automatic flushes of the stream publisher used by
    # the streaming activity.
    streaming_event_batch_interval: timedelta
49+
50+
51+
@activity.defn
async def invoke_model_streaming(
    input: StreamingInvokeInput,
) -> list[LlmResponse]:
    """Streaming-aware model activity.

    .. warning::
        Streaming support is experimental and may change in future
        versions.

    Calls the LLM with ``stream=True`` and returns the collected list of
    raw ``LlmResponse`` chunks. The workflow's ``TemporalModel.generate_content_async``
    yields these to the caller.

    Each response is also published to the workflow's stream on
    ``streaming_event_topic`` so external consumers (UIs, tracing, etc.)
    can observe responses as they arrive.

    Raises:
        ValueError: If the request carries no model name, or the registry
            cannot construct an LLM for that name.
    """
    llm_request = input.llm_request
    if llm_request.model is None:
        raise ValueError("No model name provided, could not create LLM.")

    # Resolve the concrete LLM implementation from the registry by name.
    llm = LLMRegistry.new_llm(llm_request.model)
    if not llm:
        raise ValueError(f"Failed to create LLM for model: {llm_request.model}")

    responses: list[LlmResponse] = []

    # Publisher tied to the workflow that scheduled this activity; publishes
    # are batched and flushed on the configured interval.
    stream = WorkflowStreamClient.from_within_activity(
        batch_interval=input.streaming_event_batch_interval,
    )
    events = stream.topic(input.streaming_event_topic, type=LlmResponse)
    # The context manager scope ensures the stream is flushed/closed even if
    # the model iteration raises partway through.
    async with stream:
        async for response in llm.generate_content_async(
            llm_request=llm_request, stream=True
        ):
            # Report liveness once per received chunk.
            activity.heartbeat()
            responses.append(response)
            events.publish(response)

    return responses
92+
93+
3994
class TemporalModel(BaseLlm):
4095
"""A Temporal-based LLM model that executes model invocations as activities."""
4196

@@ -45,9 +100,15 @@ def __init__(
45100
activity_config: ActivityConfig | None = None,
46101
*,
47102
summary_fn: Callable[[LlmRequest], str | None] | None = None,
103+
streaming_event_topic: str | None = None,
104+
streaming_event_batch_interval: timedelta = timedelta(milliseconds=100),
48105
) -> None:
49106
"""Initialize the TemporalModel.
50107
108+
Streaming is selected by the caller via the ADK
109+
``generate_content_async(stream=True)`` argument; no plugin-level
110+
flag is needed.
111+
51112
Args:
52113
model_name: The name of the model to use.
53114
activity_config: Configuration options for the activity execution.
@@ -56,13 +117,28 @@ def __init__(
56117
deterministic as it is called during workflow execution. If
57118
the callable raises, the exception will propagate and fail
58119
the workflow task.
120+
streaming_event_topic: Stream topic to publish raw
121+
``LlmResponse`` chunks to when streaming. Required when
122+
callers invoke ``generate_content_async(stream=True)``;
123+
if ``None``, the streaming call raises before scheduling
124+
an activity. The workflow must host a
125+
:class:`temporalio.contrib.workflow_streams.WorkflowStream`
126+
to receive the publishes; otherwise the signals are
127+
unhandled and dropped. Streaming support is
128+
experimental and may change in future versions.
129+
streaming_event_batch_interval: Interval between automatic
130+
flushes for the stream publisher used by the streaming
131+
activity. Streaming support is experimental and may
132+
change in future versions.
59133
60134
Raises:
61135
ValueError: If both ``ActivityConfig["summary"]`` and ``summary_fn`` are set.
62136
"""
63137
super().__init__(model=model_name)
64138
self._model_name = model_name
65139
self._summary_fn = summary_fn
140+
self._streaming_event_topic = streaming_event_topic
141+
self._streaming_event_batch_interval = streaming_event_batch_interval
66142
self._activity_config = ActivityConfig(
67143
start_to_close_timeout=timedelta(seconds=60)
68144
)
@@ -80,7 +156,10 @@ async def generate_content_async(
80156
81157
Args:
82158
llm_request: The LLM request containing model parameters and content.
83-
stream: Whether to stream the response (currently ignored).
159+
stream: Whether to use the streaming activity. When ``True``,
160+
each chunk is also published to ``streaming_event_topic``
161+
(if set) for external consumers. Streaming support is
162+
experimental and may change in future versions.
84163
85164
Yields:
86165
The responses from the model.
@@ -103,10 +182,28 @@ async def generate_content_async(
103182
agent_name = llm_request.config.labels.get("adk_agent_name")
104183
if agent_name:
105184
config["summary"] = agent_name
106-
responses = await workflow.execute_activity(
107-
invoke_model,
108-
args=[llm_request],
109-
**config,
110-
)
185+
186+
if stream:
187+
if self._streaming_event_topic is None:
188+
raise ApplicationError(
189+
"generate_content_async(stream=True) requires "
190+
"TemporalModel(streaming_event_topic=...) to be set.",
191+
non_retryable=True,
192+
)
193+
responses = await workflow.execute_activity(
194+
invoke_model_streaming,
195+
StreamingInvokeInput(
196+
llm_request=llm_request,
197+
streaming_event_topic=self._streaming_event_topic,
198+
streaming_event_batch_interval=self._streaming_event_batch_interval,
199+
),
200+
**config,
201+
)
202+
else:
203+
responses = await workflow.execute_activity(
204+
invoke_model,
205+
args=[llm_request],
206+
**config,
207+
)
111208
for response in responses:
112209
yield response

temporalio/contrib/google_adk_agents/_plugin.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@
33
import dataclasses
44
import time
55
import uuid
6-
from collections.abc import AsyncIterator
6+
from collections.abc import AsyncIterator, Callable
77
from contextlib import asynccontextmanager
8+
from typing import Any
89

910
from temporalio import workflow
1011
from temporalio.contrib.google_adk_agents._mcp import TemporalMcpToolSetProvider
11-
from temporalio.contrib.google_adk_agents._model import invoke_model
12+
from temporalio.contrib.google_adk_agents._model import (
13+
invoke_model,
14+
invoke_model_streaming,
15+
)
1216
from temporalio.contrib.pydantic import (
1317
PydanticPayloadConverter,
1418
ToJsonOptions,
@@ -95,7 +99,13 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
9599
)
96100
return runner
97101

98-
new_activities = [invoke_model]
102+
# Annotate as list[Callable[..., Any]] because invoke_model
103+
# and invoke_model_streaming have different signatures, so the
104+
# inferred list type would not satisfy SimplePlugin's parameter.
105+
new_activities: list[Callable[..., Any]] = [
106+
invoke_model,
107+
invoke_model_streaming,
108+
]
99109
if toolset_providers is not None:
100110
for toolset_provider in toolset_providers:
101111
new_activities.extend(toolset_provider._get_activities())
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""Integration tests for ADK streaming support.
2+
3+
Verifies that the streaming model activity publishes raw ``LlmResponse``
4+
chunks via the WorkflowStream broker. Non-streaming behavior is covered
5+
by ``test_google_adk_agents.py``.
6+
"""
7+
8+
import asyncio
9+
import logging
10+
import uuid
11+
from collections.abc import AsyncGenerator
12+
from datetime import timedelta
13+
14+
import pytest
15+
from google.adk import Agent
16+
from google.adk.agents.run_config import RunConfig, StreamingMode
17+
from google.adk.models import BaseLlm, LLMRegistry
18+
from google.adk.models.llm_request import LlmRequest
19+
from google.adk.models.llm_response import LlmResponse
20+
from google.adk.runners import InMemoryRunner
21+
from google.genai.types import Content, Part
22+
23+
from temporalio import workflow
24+
from temporalio.client import Client, WorkflowFailureError
25+
from temporalio.contrib.google_adk_agents import GoogleAdkPlugin, TemporalModel
26+
from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
27+
from temporalio.worker import Worker
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class StreamingTestModel(BaseLlm):
    """Fake LLM for streaming tests: emits two partial text chunks."""

    @classmethod
    def supported_models(cls) -> list[str]:
        return ["streaming_test_model"]

    async def generate_content_async(
        self, llm_request: LlmRequest, stream: bool = False
    ) -> AsyncGenerator[LlmResponse, None]:
        # The streaming activity is expected to pass stream=True; failing
        # loudly here catches a regression that drops the flag.
        if not stream:
            raise AssertionError(
                "StreamingTestModel.generate_content_async requires stream=True"
            )
        for chunk in ("Hello ", "world!"):
            yield LlmResponse(
                content=Content(role="model", parts=[Part(text=chunk)])
            )
50+
51+
52+
@workflow.defn
class StreamingAdkWorkflow:
    """Test workflow that opts into streaming via RunConfig.streaming_mode."""

    @workflow.init
    def __init__(self, prompt: str) -> None:
        # Host a WorkflowStream so the streaming activity's publishes have a
        # receiver; without one the publish signals are unhandled and dropped.
        self.stream = WorkflowStream()

    @workflow.run
    async def run(self, prompt: str) -> str:
        # Topic name must match what the external test subscriber listens on.
        model = TemporalModel("streaming_test_model", streaming_event_topic="events")
        agent = Agent(
            name="test_agent",
            model=model,
            instruction="You are a test agent.",
        )

        runner = InMemoryRunner(agent=agent, app_name="test-app")
        session = await runner.session_service.create_session(
            app_name="test-app", user_id="test"
        )

        final_text = ""
        # StreamingMode.SSE opts the ADK runner into streaming, which the
        # model bridge translates into stream=True on the activity.
        async for event in runner.run_async(
            user_id="test",
            session_id=session.id,
            new_message=Content(role="user", parts=[Part(text=prompt)]),
            run_config=RunConfig(streaming_mode=StreamingMode.SSE),
        ):
            if event.content and event.content.parts:
                for part in event.content.parts:
                    if part.text:
                        # Keep only the most recent text part: each streamed
                        # chunk overwrites the previous one, so the last wins.
                        final_text = part.text

        return final_text
87+
88+
89+
@pytest.mark.asyncio
async def test_streaming_publishes_events(client: Client):
    """Streaming activity publishes raw LlmResponse chunks to the topic.

    Starts the workflow, subscribes to its ``events`` stream topic from
    outside, and verifies both simulated chunks arrive in order.
    """
    LLMRegistry.register(StreamingTestModel)

    # Rebuild the client with the ADK plugin installed.
    new_config = client.config()
    new_config["plugins"] = [GoogleAdkPlugin()]
    client = Client(**new_config)

    workflow_id = f"adk-streaming-test-{uuid.uuid4()}"

    async with Worker(
        client,
        task_queue="adk-streaming-test",
        workflows=[StreamingAdkWorkflow],
        max_cached_workflows=0,
    ):
        handle = await client.start_workflow(
            StreamingAdkWorkflow.run,
            "Hello",
            id=workflow_id,
            task_queue="adk-streaming-test",
            execution_timeout=timedelta(seconds=30),
        )

        stream = WorkflowStreamClient.create(client, workflow_id)
        responses: list[LlmResponse] = []

        async def collect_events() -> None:
            # from_offset=0 replays from the beginning, so subscribing after
            # the workflow has already published is not a race.
            async for item in stream.subscribe(
                ["events"],
                from_offset=0,
                result_type=LlmResponse,
                poll_cooldown=timedelta(milliseconds=50),
            ):
                responses.append(item.data)
                if len(responses) >= 2:
                    break

        collect_task = asyncio.create_task(collect_events())
        try:
            result = await handle.result()
            await asyncio.wait_for(collect_task, timeout=10.0)
        finally:
            # Bug fix: if the workflow fails (handle.result() raises) we never
            # reach wait_for, leaving collect_task pending and leaked. Cancel
            # it on any exit path so no "Task was destroyed" warning fires.
            if not collect_task.done():
                collect_task.cancel()

    # Workflow assembles streamed parts; the last part it observes is "world!".
    assert result == "world!"

    texts: list[str] = []
    for r in responses:
        if r.content and r.content.parts:
            for part in r.content.parts:
                if part.text:
                    texts.append(part.text)
    assert texts == ["Hello ", "world!"], f"Unexpected text deltas: {texts}"
142+
143+
144+
@workflow.defn
class StreamingAdkRequiresTopicWorkflow:
    """Invokes streaming generation without a configured
    ``streaming_event_topic``; the bridge must raise before scheduling
    any activity."""

    @workflow.run
    async def run(self, prompt: str) -> str:
        # Deliberately omit streaming_event_topic so stream=True fails fast.
        agent = Agent(
            name="test_agent",
            model=TemporalModel("streaming_test_model"),
            instruction="You are a test agent.",
        )
        runner = InMemoryRunner(agent=agent, app_name="test-app")
        session = await runner.session_service.create_session(
            app_name="test-app", user_id="test"
        )
        message = Content(role="user", parts=[Part(text=prompt)])
        event_iter = runner.run_async(
            user_id="test",
            session_id=session.id,
            new_message=message,
            run_config=RunConfig(streaming_mode=StreamingMode.SSE),
        )
        async for _ in event_iter:
            pass
        return "should not reach"
170+
171+
172+
@pytest.mark.asyncio
async def test_streaming_requires_topic(client: Client):
    """``stream=True`` fails fast when ``TemporalModel`` has no streaming
    topic configured; the error surfaces from the workflow before any
    streaming activity is scheduled."""
    LLMRegistry.register(StreamingTestModel)

    config = client.config()
    config["plugins"] = [GoogleAdkPlugin()]
    client = Client(**config)

    task_queue = "adk-streaming-requires-topic"
    async with Worker(
        client,
        task_queue=task_queue,
        workflows=[StreamingAdkRequiresTopicWorkflow],
        max_cached_workflows=0,
    ):
        with pytest.raises(WorkflowFailureError) as exc_info:
            await client.execute_workflow(
                StreamingAdkRequiresTopicWorkflow.run,
                "Hi",
                id=f"{task_queue}-{uuid.uuid4()}",
                task_queue=task_queue,
                execution_timeout=timedelta(seconds=30),
            )

    # The fail-fast ApplicationError names the missing parameter.
    assert "streaming_event_topic" in str(exc_info.value.cause)

0 commit comments

Comments
 (0)