Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion temporalio/contrib/langgraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__all__ = [
"LangGraphPlugin",
"entrypoint",
"cache",
"entrypoint",
"graph",
]
Comment thread
brianstrauch marked this conversation as resolved.
44 changes: 30 additions & 14 deletions temporalio/contrib/langgraph/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Awaitable
from dataclasses import dataclass
from datetime import timedelta
from inspect import iscoroutinefunction, signature
from typing import Any, Callable

Expand All @@ -19,6 +20,7 @@
cache_lookup,
cache_put,
)
from temporalio.contrib.workflow_streams import WorkflowStreamClient

# Per-run dedupe so we only warn once when a user passes a Store via
# graph.compile(store=...) / @entrypoint(store=...). Cleared by
Expand Down Expand Up @@ -51,6 +53,9 @@ class ActivityOutput:

def wrap_activity(
func: Callable,
*,
streaming_topic: str | None = None,
streaming_batch_interval: timedelta = timedelta(milliseconds=100),
) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]:
"""Wrap a function as a Temporal activity that handles LangGraph config and interrupts."""
# Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks
Expand All @@ -59,20 +64,31 @@ def wrap_activity(
accepts_runtime = "runtime" in signature(func).parameters

async def wrapper(input: ActivityInput) -> ActivityOutput:
runtime = set_langgraph_config(input.langgraph_config)
kwargs = dict(input.kwargs)
if accepts_runtime:
kwargs["runtime"] = runtime
try:
if iscoroutinefunction(func):
result = await func(*input.args, **kwargs)
else:
result = func(*input.args, **kwargs)
if isinstance(result, Command):
return ActivityOutput(langgraph_command=result)
return ActivityOutput(result=result)
except GraphInterrupt as e:
return ActivityOutput(langgraph_interrupts=e.args[0])
async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput:
runtime = set_langgraph_config(
input.langgraph_config, stream_writer=stream_writer
)
kwargs = dict(input.kwargs)
if accepts_runtime:
kwargs["runtime"] = runtime
try:
if iscoroutinefunction(func):
result = await func(*input.args, **kwargs)
else:
result = func(*input.args, **kwargs)
if isinstance(result, Command):
return ActivityOutput(langgraph_command=result)
return ActivityOutput(result=result)
except GraphInterrupt as e:
return ActivityOutput(langgraph_interrupts=e.args[0])

if streaming_topic is None:
return await run(stream_writer=None)
async with WorkflowStreamClient.from_within_activity(
batch_interval=streaming_batch_interval,
) as client:
topic = client.topic(streaming_topic)
return await run(stream_writer=topic.publish)
Comment thread
brianstrauch marked this conversation as resolved.

return wrapper

Expand Down
10 changes: 7 additions & 3 deletions temporalio/contrib/langgraph/_langgraph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pyright: reportMissingTypeStubs=false

import dataclasses
from typing import Any
from typing import Any, Callable

from langchain_core.runnables.config import var_child_runnable_config
from langgraph._internal._constants import (
Expand Down Expand Up @@ -93,7 +93,11 @@ def get_langgraph_config() -> dict[str, Any]:
}


def set_langgraph_config(config: dict[str, Any]) -> Runtime:
def set_langgraph_config(
config: dict[str, Any],
*,
stream_writer: Callable[[Any], None] | None = None,
) -> Runtime:
"""Restore a LangGraph runnable config from a serialized dict.

Returns the reconstructed Runtime so callers can re-inject it into the
Expand All @@ -112,7 +116,7 @@ def get_null_resume(consume: bool = False) -> Any:
execution_info_dict = config.get("execution_info")
runtime = Runtime(
context=config.get("context"),
stream_writer=lambda _: None,
stream_writer=stream_writer or (lambda _: None),
previous=config.get("previous"),
execution_info=(
ExecutionInfo(**execution_info_dict) if execution_info_dict else None
Expand Down
19 changes: 17 additions & 2 deletions temporalio/contrib/langgraph/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import warnings
from dataclasses import replace
from datetime import timedelta
from typing import Any, Callable

from langgraph._internal._runnable import RunnableCallable
Expand Down Expand Up @@ -58,8 +59,15 @@ def __init__(
# TODO: Remove activity_options when we have support for @task(metadata=...)
activity_options: dict[str, dict[str, Any]] | None = None,
default_activity_options: dict[str, Any] | None = None,
streaming_topic: str | None = None,
streaming_batch_interval: timedelta = timedelta(milliseconds=100),
):
"""Initialize the LangGraph plugin with graphs, entrypoints, and tasks."""
"""Initialize the LangGraph plugin with graphs, entrypoints, and tasks.

.. warning::
Streaming support is experimental and may change in
future versions.
"""
Comment thread
brianstrauch marked this conversation as resolved.
if sys.version_info < (3, 11):
warnings.warn( # type: ignore[reportUnreachable]
"LangGraphPlugin requires Python >= 3.11 for full async support. "
Expand All @@ -79,6 +87,8 @@ def __init__(
)

self.activities: list = []
self._streaming_topic = streaming_topic
self._streaming_batch_interval = streaming_batch_interval

# Graph API: Wrap graph nodes as Temporal Activities.
if graphs:
Expand Down Expand Up @@ -197,7 +207,12 @@ def execute(
execute_in = opts.pop("execute_in")

if execute_in == "activity":
a = activity.defn(name=activity_name)(wrap_activity(func))
wrapped = wrap_activity(
func,
streaming_topic=self._streaming_topic,
streaming_batch_interval=self._streaming_batch_interval,
)
a = activity.defn(name=activity_name)(wrapped)
self.activities.append(a)
return wrap_execute_activity(a, task_id=task_id(func), **opts)
elif execute_in == "workflow":
Expand Down
135 changes: 117 additions & 18 deletions tests/contrib/langgraph/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,166 @@
from typing import Any
from uuid import uuid4

from langgraph.config import (
get_stream_writer, # pyright: ignore[reportMissingTypeStubs]
)
from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs]
from typing_extensions import TypedDict

from temporalio import workflow
from temporalio.client import Client
from temporalio.contrib.langgraph import LangGraphPlugin, graph
from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
from temporalio.worker import Worker


class State(TypedDict):
value: str


async def node_a(state: State) -> dict[str, str]:
return {"value": state["value"] + "a"}
async def token_node(state: State) -> dict[str, str]:
tokens = ["a", "b", "c"]
writer = get_stream_writer()
for token in tokens:
writer({"token": token})
writer({"done": True})
return {"value": state["value"] + "".join(tokens)}


async def node_b(state: State) -> dict[str, str]:
return {"value": state["value"] + "b"}
@workflow.defn
class StreamingWorkflowStreamsWorkflow:
def __init__(self) -> None:
_ = WorkflowStream()
self.app = graph("streaming-ws").compile()

@workflow.run
async def run(self, input: str) -> str:
result = await self.app.ainvoke({"value": input})
return result["value"]


async def test_streaming_via_workflow_streams(client: Client):
g = StateGraph(State)
g.add_node("token_node", token_node, metadata={"execute_in": "activity"})
g.add_edge(START, "token_node")

task_queue = f"streaming-ws-{uuid4()}"

async with Worker(
client,
task_queue=task_queue,
workflows=[StreamingWorkflowStreamsWorkflow],
plugins=[
LangGraphPlugin(
graphs={"streaming-ws": g},
default_activity_options={
"start_to_close_timeout": timedelta(seconds=10)
},
streaming_topic="tokens",
)
],
):
handle = await client.start_workflow(
StreamingWorkflowStreamsWorkflow.run,
"",
id=f"test-streaming-ws-{uuid4()}",
task_queue=task_queue,
)

ws_client = WorkflowStreamClient.create(client, handle.id)
chunks: list[dict[str, Any]] = []
async for item in ws_client.topic("tokens", type=dict).subscribe(
from_offset=0,
poll_cooldown=timedelta(0),
):
chunks.append(item.data)
if chunks[-1].get("done"):
break

result = await handle.result()

assert result == "abc"
assert chunks == [
{"token": "a"},
{"token": "b"},
{"token": "c"},
{"done": True},
]


# ---------------------------------------------------------------------------
# Workflow-side publish: iterate astream() in the workflow and forward each
# chunk via self.stream.topic("astream").publish(...) so external subscribers
# see node-level progress alongside any activity-emitted tokens.
# ---------------------------------------------------------------------------


@workflow.defn
class StreamingWorkflow:
class AstreamPublishWorkflow:
def __init__(self) -> None:
self.app = graph("streaming").compile()
self.stream = WorkflowStream()
self.app = graph("astream-publish").compile()

@workflow.run
async def run(self, input: str) -> Any:
chunks = []
async def run(self, input: str) -> str:
topic = self.stream.topic("astream")
async for chunk in self.app.astream({"value": input}):
chunks.append(chunk)
return chunks
topic.publish(chunk)
topic.publish({"done": True})
return "done"


async def node_a(state: State) -> dict[str, str]:
return {"value": state["value"] + "a"}


async def node_b(state: State) -> dict[str, str]:
return {"value": state["value"] + "b"}


async def test_streaming(client: Client):
async def test_workflow_publishes_astream_chunks(client: Client):
g = StateGraph(State)
g.add_node("node_a", node_a, metadata={"execute_in": "activity"})
g.add_node("node_b", node_b, metadata={"execute_in": "activity"})
g.add_edge(START, "node_a")
g.add_edge("node_a", "node_b")

task_queue = f"streaming-{uuid4()}"
task_queue = f"astream-publish-{uuid4()}"

async with Worker(
client,
task_queue=task_queue,
workflows=[StreamingWorkflow],
workflows=[AstreamPublishWorkflow],
plugins=[
LangGraphPlugin(
graphs={"streaming": g},
graphs={"astream-publish": g},
default_activity_options={
"start_to_close_timeout": timedelta(seconds=10)
},
)
],
):
chunks = await client.execute_workflow(
StreamingWorkflow.run,
handle = await client.start_workflow(
AstreamPublishWorkflow.run,
"",
id=f"test-streaming-{uuid4()}",
id=f"test-astream-publish-{uuid4()}",
task_queue=task_queue,
)

assert chunks == [{"node_a": {"value": "a"}}, {"node_b": {"value": "ab"}}]
ws_client = WorkflowStreamClient.create(client, handle.id)
chunks: list[dict[str, Any]] = []
async for item in ws_client.topic("astream", type=dict).subscribe(
from_offset=0,
poll_cooldown=timedelta(0),
):
chunks.append(item.data)
if chunks[-1].get("done"):
break

await handle.result()

assert chunks == [
{"node_a": {"value": "a"}},
{"node_b": {"value": "ab"}},
{"done": True},
]
Loading