Skip to content

Commit 587f638

Browse files
jssmithclaude
andcommitted
workflow_streams: parametrize WorkflowStreamItem on decoded data type
Make WorkflowStreamItem generic in T so subscribers get a typed data field that matches the result_type passed to subscribe: - subscribe(result_type=T) -> WorkflowStreamItem[T] - subscribe() -> WorkflowStreamItem[Any] - subscribe(result_type=RawValue) -> WorkflowStreamItem[RawValue] Adds def-style overloads to WorkflowStreamClient.subscribe (matching the existing TopicHandle/WorkflowTopicHandle generic style) and tightens TopicHandle.subscribe to AsyncIterator[WorkflowStreamItem[T]]. The internal workflow-side _log is annotated as list[WorkflowStreamItem[Payload]] since the workflow does not decode. No runtime behavior change; existing tests (which use unparameterized WorkflowStreamItem) continue to type-check as WorkflowStreamItem[Any]. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
1 parent 1b46541 commit 587f638

4 files changed

Lines changed: 33 additions & 12 deletions

File tree

temporalio/contrib/workflow_streams/_client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,14 +456,33 @@ async def _run_flusher(self) -> None:
456456
self._flush_event.clear()
457457
await self._flush()
458458

459+
@overload
460+
def subscribe(
461+
self,
462+
topics: str | list[str] | None = ...,
463+
from_offset: int = ...,
464+
*,
465+
result_type: type[T],
466+
poll_cooldown: timedelta = ...,
467+
) -> AsyncIterator[WorkflowStreamItem[T]]: ...
468+
@overload
469+
def subscribe(
470+
self,
471+
topics: str | list[str] | None = ...,
472+
from_offset: int = ...,
473+
*,
474+
result_type: None = None,
475+
poll_cooldown: timedelta = ...,
476+
) -> AsyncIterator[WorkflowStreamItem[Any]]: ...
477+
459478
async def subscribe(
460479
self,
461480
topics: str | list[str] | None = None,
462481
from_offset: int = 0,
463482
*,
464483
result_type: type | None = None,
465484
poll_cooldown: timedelta = timedelta(milliseconds=100),
466-
) -> AsyncIterator[WorkflowStreamItem]:
485+
) -> AsyncIterator[WorkflowStreamItem[Any]]:
467486
"""Async iterator that polls for new items.
468487
469488
Automatically follows continue-as-new chains when the client

temporalio/contrib/workflow_streams/_stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(self, prior_state: WorkflowStreamState | None = None) -> None:
136136
)
137137

138138
if prior_state is not None:
139-
self._log: list[WorkflowStreamItem] = [
139+
self._log: list[WorkflowStreamItem[Payload]] = [
140140
WorkflowStreamItem(topic=item.topic, data=_decode_payload(item.data))
141141
for item in prior_state.log
142142
]

temporalio/contrib/workflow_streams/_topic_handle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def subscribe(
9191
from_offset: int = 0,
9292
*,
9393
poll_cooldown: timedelta = timedelta(milliseconds=100),
94-
) -> AsyncIterator[WorkflowStreamItem]:
94+
) -> AsyncIterator[WorkflowStreamItem[T]]:
9595
"""Async iterator over items on this topic, decoded as ``T``.
9696
9797
For raw ``Payload`` access, or any other decode type that

temporalio/contrib/workflow_streams/_types.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
import base64
2121
from dataclasses import dataclass, field
2222
from datetime import datetime
23-
from typing import Any
23+
from typing import Generic, TypeVar
2424

2525
from temporalio.api.common.v1 import Payload
2626

27+
T = TypeVar("T")
28+
2729

2830
# basedpyright flags _-prefixed module-level functions as unused even when
2931
# sibling modules import them (_stream.py, _client.py). Vanilla pyright does
@@ -41,26 +43,26 @@ def _decode_payload(wire: str) -> Payload: # pyright: ignore[reportUnusedFuncti
4143

4244

4345
@dataclass
44-
class WorkflowStreamItem:
46+
class WorkflowStreamItem(Generic[T]):
4547
"""A single item in the workflow stream's log.
4648
4749
.. warning::
4850
This class is experimental and may change in future versions.
4951
50-
The ``data`` field always carries the decoded value produced by
51-
:meth:`WorkflowStreamClient.subscribe`: the converter's default
52-
``Any`` decoding when ``result_type`` is omitted, an instance of
53-
``T`` when ``result_type=T`` is passed, or a
52+
The ``data`` field carries the decoded value produced by
53+
:meth:`WorkflowStreamClient.subscribe`. The generic parameter ``T``
54+
matches the ``result_type`` passed to ``subscribe``: an instance of
55+
``T`` when ``result_type=T``, the converter's default ``Any``
56+
decoding when ``result_type`` is omitted, or a
5457
:class:`temporalio.common.RawValue` wrapping the original
55-
``Payload`` when ``result_type=RawValue`` is passed. The dataclass
56-
is typed as ``Any`` to accommodate all three.
58+
``Payload`` when ``result_type=RawValue``.
5759
5860
The ``offset`` field is populated at poll time from the item's
5961
position in the global log.
6062
"""
6163

6264
topic: str
63-
data: Any
65+
data: T
6466
offset: int = 0
6567

6668

0 commit comments

Comments
 (0)