Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions libs/langchain_v1/langchain/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain.""" # noqa: E501

from langchain.agents._middleware_transformer import (
MiddlewareEvent,
MiddlewarePhase,
MiddlewareTransformer,
)
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentState

__all__ = [
"AgentState",
"MiddlewareEvent",
"MiddlewarePhase",
"MiddlewareTransformer",
"create_agent",
]
131 changes: 131 additions & 0 deletions libs/langchain_v1/langchain/agents/_middleware_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Project graph `updates` events into typed `MiddlewareEvent` handles.

Exposed via `run.middleware` on `AgentStreamer` runs.

`create_agent` wires each middleware hook into the graph as a real
node named `"{middleware_name}.{phase}"`, where `phase` is one of
`before_agent` / `before_model` / `after_model` / `after_agent`.
When any of those nodes run, Pregel emits an `updates` event with
the node's state delta — we subscribe to `updates`, filter by the
phase-name suffix, and re-emit as `MiddlewareEvent` objects.

No cooperation from middleware implementations is required — the
node-name convention is the contract, and the transformer is a pure
consumer of the graph's existing `updates` stream mode.

`run.middleware` is an in-process-only projection: it surfaces
lifecycle information for local consumers and is not intended to
cross a wire boundary.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from langgraph.stream._event_log import EventLog
from langgraph.stream._types import StreamTransformer

if TYPE_CHECKING:
from langgraph.stream._types import ProtocolEvent


MiddlewarePhase = Literal["before_agent", "before_model", "after_model", "after_agent"]
"""Lifecycle phase a middleware node represents."""


_PHASES: frozenset[str] = frozenset(("before_agent", "before_model", "after_model", "after_agent"))


@dataclass(frozen=True)
class MiddlewareEvent:
"""One lifecycle transition from a middleware node.

Emitted once per node execution, carrying the phase, the name of
the middleware that produced it, and the state delta returned by
the hook.

Attributes:
phase: Which hook fired — `before_agent` / `before_model` /
`after_model` / `after_agent`.
middleware_name: The `name` attribute of the middleware class
that produced this event (the part before the `.` in the
graph node name).
state_delta: The state diff the middleware hook returned.
Empty dict if the hook didn't return an update.
timestamp: Milliseconds since epoch, taken from the underlying
protocol event.
"""

phase: MiddlewarePhase
middleware_name: str
state_delta: dict[str, Any]
timestamp: int


def _split_middleware_node(node_name: str) -> tuple[str, MiddlewarePhase] | None:
"""Split `"{name}.{phase}"` into `(name, phase)` if `phase` is valid.

Returns `None` for any node name that doesn't match the convention,
so non-middleware nodes are ignored.
"""
name, sep, suffix = node_name.rpartition(".")
if not sep or suffix not in _PHASES:
return None
return name, suffix # type: ignore[return-value]


class MiddlewareTransformer(StreamTransformer):
"""Project `updates` events from middleware nodes into `MiddlewareEvent`s.

Watches `updates` events at its mux scope, inspects the node name,
and emits one `MiddlewareEvent` per middleware-node execution onto
the `middleware` projection.

Native transformer — the `middleware` projection is exposed as a
direct attribute (`run.middleware`). Pre-registered by
`AgentStreamer.builtin_factories` so it's on by default for any
`create_agent` graph.

`scope` (inherited from `StreamTransformer`) lets the transformer
scope its filtering to its own mux — root mux sees root middleware,
subgraph mini-muxes see their own middleware, via the same
machinery `MessagesTransformer` uses.
"""

_native: ClassVar[bool] = True
required_stream_modes: ClassVar[tuple[str, ...]] = ("updates",)

def __init__(self, scope: tuple[str, ...] = ()) -> None:
super().__init__(scope)
self._log: EventLog[MiddlewareEvent] = EventLog()

def init(self) -> dict[str, Any]:
return {"middleware": self._log}

def process(self, event: ProtocolEvent) -> bool:
# Namespace filtering is handled by the mux via `scope_exact`.
if event["method"] != "updates":
return True

params = event["params"]
data = params.get("data")
if not isinstance(data, dict):
return True

timestamp = int(params.get("timestamp", 0))
for node_name, state_delta in data.items():
split = _split_middleware_node(node_name)
if split is None:
continue
name, phase = split
self._log.push(
MiddlewareEvent(
phase=phase,
middleware_name=name,
state_delta=state_delta if isinstance(state_delta, dict) else {},
timestamp=timestamp,
)
)
# Pass through — wire consumers still see the raw `updates` event.
return True
3 changes: 2 additions & 1 deletion libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from langsmith import traceable
from typing_extensions import NotRequired, Required, TypedDict

from langchain.agents._middleware_transformer import MiddlewareTransformer
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
Expand Down Expand Up @@ -1667,7 +1668,7 @@ async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> lis
debug=debug,
name=name,
cache=cache,
transformers=[ToolCallTransformer, *(transformers or ())],
transformers=[ToolCallTransformer, MiddlewareTransformer, *(transformers or ())],
).with_config(config)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Unit tests for `MiddlewareTransformer`."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from langchain_core.messages import AIMessage, HumanMessage

from langchain.agents import (
MiddlewareEvent,
MiddlewareTransformer,
create_agent,
)
from langchain.agents.middleware import AgentMiddleware
from tests.unit_tests.agents.model import FakeToolCallingModel

if TYPE_CHECKING:
from langchain.agents.middleware.types import AgentState


# ---------------------------------------------------------------------------
# Unit-level: feed the transformer synthetic protocol events
# ---------------------------------------------------------------------------


def _updates_event(
node_updates: dict[str, Any], *, namespace: list[str] | None = None
) -> dict[str, Any]:
"""Build a synthetic `updates` protocol event for testing `process`."""
return {
"type": "event",
"method": "updates",
"params": {
"namespace": namespace or [],
"timestamp": 1234,
"data": node_updates,
},
}


def _fresh_transformer() -> MiddlewareTransformer:
"""Build a bound-and-subscribed transformer so `push` accumulates."""
tr = MiddlewareTransformer()
log = tr.init()["middleware"]
log._bind(is_async=False)
# Mark subscribed so `log.push` records items. Normally this happens via
# `iter(log)`; for unit-level tests we want to inspect `_items` after
# `process` without consuming the iterator mid-test.
log._subscribed = True
return tr


class TestMiddlewareTransformerProcess:
"""`MiddlewareTransformer.process` in isolation."""

def test_emits_event_for_middleware_node(self) -> None:
tr = _fresh_transformer()
kept = tr.process(_updates_event({"MyMW.before_model": {"messages": ["hi"]}}))
assert kept is True # event passes through
events = list(tr._log._items)
assert len(events) == 1
ev = events[0]
assert isinstance(ev, MiddlewareEvent)
assert ev.phase == "before_model"
assert ev.middleware_name == "MyMW"
assert ev.state_delta == {"messages": ["hi"]}
assert ev.timestamp == 1234

def test_all_four_phases(self) -> None:
tr = _fresh_transformer()
for phase in (
"before_agent",
"before_model",
"after_model",
"after_agent",
):
tr.process(_updates_event({f"MW.{phase}": {"k": phase}}))
phases = [ev.phase for ev in tr._log._items]
assert phases == [
"before_agent",
"before_model",
"after_model",
"after_agent",
]

def test_ignores_non_middleware_nodes(self) -> None:
tr = _fresh_transformer()
tr.process(_updates_event({"model": {"messages": []}}))
tr.process(_updates_event({"tools": {"messages": []}}))
tr.process(_updates_event({"MW.weird_phase": {}})) # not a real phase
tr.process(_updates_event({"NoDot": {}}))
assert list(tr._log._items) == []

def test_ignores_non_updates_methods(self) -> None:
tr = _fresh_transformer()
kept = tr.process(
{
"type": "event",
"method": "values",
"params": {
"namespace": [],
"timestamp": 0,
"data": {"anything": "goes"},
},
}
)
assert kept is True
assert list(tr._log._items) == []

def test_multi_node_batch_in_single_event(self) -> None:
tr = _fresh_transformer()
tr.process(
_updates_event(
{
"A.before_model": {"x": 1},
"B.after_model": {"y": 2},
"regular_node": {"z": 3},
}
)
)
events = list(tr._log._items)
assert len(events) == 2
names_by_phase = {ev.phase: ev.middleware_name for ev in events}
assert names_by_phase == {"before_model": "A", "after_model": "B"}

def test_non_dict_state_delta_becomes_empty_dict(self) -> None:
tr = _fresh_transformer()
tr.process(_updates_event({"MW.before_agent": None}))
events = list(tr._log._items)
assert len(events) == 1
assert events[0].state_delta == {}

def test_dotted_middleware_name_uses_rpartition(self) -> None:
# Middleware name with a `.` in it — rpartition splits on the last one.
tr = _fresh_transformer()
tr.process(_updates_event({"my.name.spaced.MW.before_model": {"ok": 1}}))
events = list(tr._log._items)
assert len(events) == 1
assert events[0].middleware_name == "my.name.spaced.MW"
assert events[0].phase == "before_model"


# ---------------------------------------------------------------------------
# Integration: create_agent auto-registers it; run.middleware works
# ---------------------------------------------------------------------------


class _MarkerMiddleware(AgentMiddleware):
"""Touches state in `before_model` so we can see an updates event fire."""

name = "MarkerMW"

def before_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
return {"messages": []}


class TestCreateAgentRegistersMiddlewareTransformer:
def test_middleware_projection_present(self) -> None:
model = FakeToolCallingModel(tool_calls=[[], []])
agent = create_agent(model, [], middleware=[_MarkerMiddleware()])
run = agent.stream_v2({"messages": [HumanMessage("hi")]})
assert "middleware" in run.extensions
assert hasattr(run, "middleware")
# Drive to completion.
run.output # noqa: B018

def test_middleware_events_emitted_for_before_model_hook(self) -> None:
model = FakeToolCallingModel(
tool_calls=[[], []],
# Have the model respond once and stop.
responses=[AIMessage(content="done", id="a1", tool_calls=[])],
)
agent = create_agent(model, [], middleware=[_MarkerMiddleware()])
run = agent.stream_v2({"messages": [HumanMessage("hi")]})
events: list[MiddlewareEvent] = list(run.middleware)
assert any(
ev.phase == "before_model" and ev.middleware_name == "MarkerMW" for ev in events
), f"expected a MarkerMW.before_model event; got {events}"

def test_no_events_when_no_middleware(self) -> None:
model = FakeToolCallingModel(
tool_calls=[[]],
responses=[AIMessage(content="done", id="a1", tool_calls=[])],
)
agent = create_agent(model, [])
run = agent.stream_v2({"messages": [HumanMessage("hi")]})
events = list(run.middleware)
assert events == []
Loading