Skip to content

Commit a8d3dfc

Browse files
committed
workflow_streams: remove WorkflowStream(Client).publish in favor of topic handles
Drops the un-typed publish(topic, value) entry points on WorkflowStream and WorkflowStreamClient. Publishers now go through WorkflowStream.topic(name, type=T) and WorkflowStreamClient.topic(name, type=T), which return typed TopicHandle / WorkflowTopicHandle objects. Topic handles are the only supported publish API.

Per Codex review on PR #1423: per-instance type uniqueness is a factory-level invariant — handle-based publishing on a single publisher cannot mix Ts on a topic, while still allowing escape hatches (type=typing.Any for heterogeneous topics, pre-built Payload values via the zero-copy fast path on an any-typed handle).

Migrates the internal users:
- temporalio/contrib/openai_agents/_invoke_model_activity.py uses type=Any (TResponseStreamEvent is an annotated union, not a class)
- temporalio/contrib/google_adk_agents/_model.py uses type=LlmResponse
- All workflow_streams tests migrated to the handle form, preserving existing topic names and types
1 parent 4831c9a commit a8d3dfc

5 files changed

Lines changed: 50 additions & 75 deletions

File tree

temporalio/contrib/google_adk_agents/_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ async def invoke_model_streaming(
7979
stream = WorkflowStreamClient.from_within_activity(
8080
batch_interval=input.streaming_event_batch_interval,
8181
)
82+
events = stream.topic(input.streaming_event_topic, type=LlmResponse)
8283
async with stream:
8384
async for response in llm.generate_content_async(
8485
llm_request=llm_request, stream=True
8586
):
8687
activity.heartbeat()
8788
responses.append(response)
88-
stream.publish(input.streaming_event_topic, response)
89+
events.publish(response)
8990

9091
return responses
9192

temporalio/contrib/openai_agents/_invoke_model_activity.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import enum
77
from dataclasses import dataclass
88
from datetime import timedelta
9-
from typing import Any, NoReturn
9+
from typing import Any, NoReturn, cast
1010

1111
from agents import (
1212
AgentOutputSchemaBase,
@@ -373,6 +373,11 @@ async def invoke_model_activity_streaming(
373373
stream = WorkflowStreamClient.from_within_activity(
374374
batch_interval=batch_interval
375375
)
376+
# TResponseStreamEvent is a typing.Annotated[Union[...]] — a typing
377+
# special form, not a class — so it cannot be passed as type[T].
378+
# Use Any here; subscribers that want typed decode can pass
379+
# result_type=TResponseStreamEvent on their own subscribe call.
380+
events_topic = stream.topic(topic, type=cast("type[Any]", cast(object, Any)))
376381
async with stream:
377382
try:
378383
async for event in model.stream_response(
@@ -388,7 +393,7 @@ async def invoke_model_activity_streaming(
388393
prompt=input.get("prompt"),
389394
):
390395
events.append(event)
391-
stream.publish(topic, event)
396+
events_topic.publish(event)
392397
except APIStatusError as e:
393398
_raise_for_openai_status(e)
394399

temporalio/contrib/workflow_streams/_client.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,14 @@ class WorkflowStreamClient:
6262
:py:meth:`from_within_activity` (infer both from the current activity
6363
context), or by passing a handle directly to the constructor.
6464
65-
For publishing, use as an async context manager to get automatic
66-
batching::
65+
For publishing, bind a typed topic handle and use the client as
66+
an async context manager to get automatic batching::
6767
6868
client = WorkflowStreamClient.create(temporal_client, workflow_id)
69+
events = client.topic("events", type=MyEvent)
6970
async with client:
70-
client.publish("events", my_event)
71-
client.publish("events", another_event, force_flush=True)
71+
events.publish(my_event)
72+
events.publish(another_event, force_flush=True)
7273
... # more publishing
7374
# Buffer is flushed automatically on context manager exit.
7475
@@ -230,34 +231,14 @@ async def __aexit__(self, *_exc: object) -> None:
230231
while self._pending is not None or self._buffer:
231232
await self._flush()
232233

233-
def publish(self, topic: str, value: Any, force_flush: bool = False) -> None:
234-
"""Buffer a message for publishing.
235-
236-
.. deprecated::
237-
Prefer :meth:`topic` and :meth:`TopicHandle.publish`. The
238-
handle form carries the value type, which is needed for
239-
cross-language SDK consistency.
240-
241-
``value`` may be any Python value the client's payload
242-
converter can handle, or a pre-built
243-
:class:`temporalio.api.common.v1.Payload` for zero-copy. The
244-
codec chain is not applied per item — it runs once on the
245-
signal envelope that delivers the batch.
246-
247-
Args:
248-
topic: Topic string.
249-
value: Value to publish. Converted to a ``Payload`` via
250-
the client's sync payload converter at flush time.
251-
Pre-built ``Payload`` instances bypass conversion.
252-
force_flush: If True, wake the flusher to send immediately
253-
(fire-and-forget — does not block the caller).
254-
"""
255-
self._publish_to_topic(topic, value, force_flush=force_flush)
256-
257234
def _publish_to_topic(
258235
self, topic: str, value: Any, *, force_flush: bool = False
259236
) -> None:
260-
"""Internal publish path shared by :meth:`publish` and topic handles."""
237+
"""Internal publish path used by :class:`TopicHandle`.
238+
239+
Not part of the public API — call
240+
:meth:`TopicHandle.publish` instead.
241+
"""
261242
self._buffer.append((topic, value))
262243
if force_flush or (
263244
self._max_batch_size is not None

temporalio/contrib/workflow_streams/_stream.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -156,26 +156,12 @@ def __init__(self, prior_state: WorkflowStreamState | None = None) -> None:
156156
)
157157
workflow.set_query_handler(_OFFSET_QUERY, self._on_offset)
158158

159-
def publish(self, topic: str, value: Any) -> None:
160-
"""Publish an item from within workflow code.
161-
162-
.. deprecated::
163-
Prefer :meth:`topic` and :meth:`WorkflowTopicHandle.publish`.
164-
The handle form carries the value type, which is needed
165-
for cross-language SDK consistency.
166-
167-
``value`` may be any Python value the workflow's payload
168-
converter can handle, or a pre-built
169-
:class:`temporalio.api.common.v1.Payload` for zero-copy.
159+
def _publish_to_topic(self, topic: str, value: Any) -> None:
160+
"""Internal publish path used by :class:`WorkflowTopicHandle`.
170161
171-
The codec chain is not applied here (it runs on the
172-
``__temporal_workflow_stream_poll`` update envelope that later
173-
delivers the item to a subscriber).
162+
Not part of the public API — call
163+
:meth:`WorkflowTopicHandle.publish` instead.
174164
"""
175-
self._publish_to_topic(topic, value)
176-
177-
def _publish_to_topic(self, topic: str, value: Any) -> None:
178-
"""Internal publish path shared by :meth:`publish` and topic handles."""
179165
if isinstance(value, Payload):
180166
payload = value
181167
else:

tests/contrib/workflow_streams/test_workflow_streams.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def run(self, count: int) -> None:
101101
start_to_close_timeout=timedelta(seconds=30),
102102
heartbeat_timeout=timedelta(seconds=10),
103103
)
104-
self.stream.publish("status", b"activity_done")
104+
self.stream.topic("status", type=bytes).publish(b"activity_done")
105105
await workflow.wait_condition(lambda: self._closed)
106106

107107

@@ -125,7 +125,9 @@ def close(self) -> None:
125125
@workflow.run
126126
async def run(self, count: int) -> None:
127127
for i in range(count):
128-
self.stream.publish("events", AgentEvent(kind="tick", payload={"i": i}))
128+
self.stream.topic("events", type=AgentEvent).publish(
129+
AgentEvent(kind="tick", payload={"i": i})
130+
)
129131
await workflow.wait_condition(lambda: self._closed)
130132

131133

@@ -164,7 +166,7 @@ def close(self) -> None:
164166
@workflow.run
165167
async def run(self, count: int) -> None:
166168
for i in range(count):
167-
self.stream.publish("events", f"item-{i}".encode())
169+
self.stream.topic("events", type=bytes).publish(f"item-{i}".encode())
168170
await workflow.wait_condition(lambda: self._closed)
169171

170172

@@ -203,14 +205,14 @@ def close(self) -> None:
203205

204206
@workflow.run
205207
async def run(self, count: int) -> None:
206-
self.stream.publish("status", b"started")
208+
self.stream.topic("status", type=bytes).publish(b"started")
207209
await workflow.execute_activity(
208210
"publish_items",
209211
count,
210212
start_to_close_timeout=timedelta(seconds=30),
211213
heartbeat_timeout=timedelta(seconds=10),
212214
)
213-
self.stream.publish("status", b"done")
215+
self.stream.topic("status", type=bytes).publish(b"done")
214216
await workflow.wait_condition(lambda: self._closed)
215217

216218

@@ -280,7 +282,7 @@ async def run(self, count: int) -> None:
280282
start_to_close_timeout=timedelta(seconds=30),
281283
heartbeat_timeout=timedelta(seconds=10),
282284
)
283-
self.stream.publish("status", b"activity_done")
285+
self.stream.topic("status", type=bytes).publish(b"activity_done")
284286
await workflow.wait_condition(lambda: self._closed)
285287

286288

@@ -349,7 +351,7 @@ async def publish_items(count: int) -> None:
349351
async with client:
350352
for i in range(count):
351353
activity.heartbeat()
352-
client.publish("events", f"item-{i}".encode())
354+
client.topic("events", type=bytes).publish(f"item-{i}".encode())
353355

354356

355357
@activity.defn(name="publish_multi_topic")
@@ -362,7 +364,7 @@ async def publish_multi_topic(count: int) -> None:
362364
for i in range(count):
363365
activity.heartbeat()
364366
topic = topics[i % len(topics)]
365-
client.publish(topic, f"{topic}-{i}".encode())
367+
client.topic(topic, type=bytes).publish(f"{topic}-{i}".encode())
366368

367369

368370
@activity.defn(name="publish_with_priority")
@@ -376,9 +378,9 @@ async def publish_with_priority() -> None:
376378
batch_interval=timedelta(seconds=60)
377379
)
378380
async with client:
379-
client.publish("events", b"normal-0")
380-
client.publish("events", b"normal-1")
381-
client.publish("events", b"priority", force_flush=True)
381+
client.topic("events", type=bytes).publish(b"normal-0")
382+
client.topic("events", type=bytes).publish(b"normal-1")
383+
client.topic("events", type=bytes).publish(b"priority", force_flush=True)
382384
for _ in range(100):
383385
activity.heartbeat()
384386
await asyncio.sleep(0.1)
@@ -392,7 +394,7 @@ async def publish_batch_test(count: int) -> None:
392394
async with client:
393395
for i in range(count):
394396
activity.heartbeat()
395-
client.publish("events", f"item-{i}".encode())
397+
client.topic("events", type=bytes).publish(f"item-{i}".encode())
396398

397399

398400
@activity.defn(name="publish_with_max_batch")
@@ -403,7 +405,7 @@ async def publish_with_max_batch(count: int) -> None:
403405
async with client:
404406
for i in range(count):
405407
activity.heartbeat()
406-
client.publish("events", f"item-{i}".encode())
408+
client.topic("events", type=bytes).publish(f"item-{i}".encode())
407409
# Yield so the flusher task can run when max_batch_size triggers
408410
# _flush_event. Real workloads (e.g. agents awaiting LLM streams)
409411
# yield constantly; a tight loop with no awaits would never let
@@ -1085,9 +1087,9 @@ async def test_explicit_flush_barrier(client: Client) -> None:
10851087

10861088
# 2. Flush makes prior publishes visible without waiting on
10871089
# the 60s batch timer.
1088-
stream.publish("events", b"a")
1089-
stream.publish("events", b"b")
1090-
stream.publish("events", b"c")
1090+
stream.topic("events", type=bytes).publish(b"a")
1091+
stream.topic("events", type=bytes).publish(b"b")
1092+
stream.topic("events", type=bytes).publish(b"c")
10911093
await stream.flush()
10921094
assert await stream.get_offset() == 3
10931095

@@ -1279,14 +1281,14 @@ async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any:
12791281
return await real_signal(*args, **kwargs)
12801282

12811283
with patch.object(handle, "signal", side_effect=maybe_failing_signal):
1282-
stream.publish("events", b"item-0")
1283-
stream.publish("events", b"item-1")
1284+
stream.topic("events", type=bytes).publish(b"item-0")
1285+
stream.topic("events", type=bytes).publish(b"item-1")
12841286
with pytest.raises(RuntimeError):
12851287
await stream._flush()
12861288

12871289
# Publish more during the failed state — must not overtake the
12881290
# pending retry on eventual delivery.
1289-
stream.publish("events", b"item-2")
1291+
stream.topic("events", type=bytes).publish(b"item-2")
12901292
with pytest.raises(RuntimeError):
12911293
await stream._flush()
12921294

@@ -1336,7 +1338,7 @@ async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any:
13361338
),
13371339
patch.object(handle, "signal", side_effect=maybe_failing_signal),
13381340
):
1339-
stream.publish("events", b"lost")
1341+
stream.topic("events", type=bytes).publish(b"lost")
13401342

13411343
# First flush fails and enters the pending-retry state.
13421344
with pytest.raises(RuntimeError):
@@ -1351,7 +1353,7 @@ async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any:
13511353

13521354
# Stop failing signals; subsequent publishes must succeed.
13531355
fail_signals = False
1354-
stream.publish("events", b"kept")
1356+
stream.topic("events", type=bytes).publish(b"kept")
13551357
await stream._flush()
13561358

13571359
items = await collect_items(client, handle, None, 0, 1)
@@ -1586,7 +1588,7 @@ def __init__(self, prepub_count: int = 0) -> None:
15861588
self.stream = WorkflowStream()
15871589
self._closed = False
15881590
for i in range(prepub_count):
1589-
self.stream.publish("events", f"item-{i}".encode())
1591+
self.stream.topic("events", type=bytes).publish(f"item-{i}".encode())
15901592

15911593
@workflow.signal
15921594
def close(self) -> None:
@@ -1890,7 +1892,7 @@ def close(self) -> None:
18901892
@workflow.run
18911893
async def run(self, count: int) -> None:
18921894
for i in range(count):
1893-
self.stream.publish("events", f"broker-{i}".encode())
1895+
self.stream.topic("events", type=bytes).publish(f"broker-{i}".encode())
18941896
await workflow.wait_condition(lambda: self._closed)
18951897

18961898

@@ -2001,7 +2003,7 @@ async def standalone_publish_to_broker(input: StandalonePublishInput) -> None:
20012003
async with client:
20022004
for i in range(input.count):
20032005
activity.heartbeat()
2004-
client.publish("events", f"standalone-{i}".encode())
2006+
client.topic("events", type=bytes).publish(f"standalone-{i}".encode())
20052007

20062008

20072009
@activity.defn(name="standalone_subscribe_to_broker")
@@ -2187,7 +2189,7 @@ def close(self) -> None:
21872189
@workflow.run
21882190
async def run(self, count: int) -> str:
21892191
for i in range(count):
2190-
self.stream.publish("events", f"nexus-{i}".encode())
2192+
self.stream.topic("events", type=bytes).publish(f"nexus-{i}".encode())
21912193
await workflow.wait_condition(lambda: self._closed)
21922194
return "done"
21932195

0 commit comments

Comments (0)