diff --git a/libs/langchain_v1/langchain/agents/__init__.py b/libs/langchain_v1/langchain/agents/__init__.py
index 5e5c545fe218a..f628c864e2c0c 100644
--- a/libs/langchain_v1/langchain/agents/__init__.py
+++ b/libs/langchain_v1/langchain/agents/__init__.py
@@ -1,9 +1,17 @@
 """Entrypoint to building [Agents](https://docs.langchain.com/oss/python/langchain/agents) with LangChain."""  # noqa: E501
 
+from langchain.agents._middleware_transformer import (
+    MiddlewareEvent,
+    MiddlewarePhase,
+    MiddlewareTransformer,
+)
 from langchain.agents.factory import create_agent
 from langchain.agents.middleware.types import AgentState
 
 __all__ = [
     "AgentState",
+    "MiddlewareEvent",
+    "MiddlewarePhase",
+    "MiddlewareTransformer",
     "create_agent",
 ]
diff --git a/libs/langchain_v1/langchain/agents/_middleware_transformer.py b/libs/langchain_v1/langchain/agents/_middleware_transformer.py
new file mode 100644
index 0000000000000..b2178e1a8fb0c
--- /dev/null
+++ b/libs/langchain_v1/langchain/agents/_middleware_transformer.py
@@ -0,0 +1,158 @@
+"""Project graph `updates` events into typed `MiddlewareEvent` handles.
+
+Exposed via `run.middleware` on `AgentStreamer` runs.
+
+`create_agent` wires each middleware hook into the graph as a real
+node named `"{middleware_name}.{phase}"`, where `phase` is one of
+`before_agent` / `before_model` / `after_model` / `after_agent`.
+When any of those nodes run, Pregel emits an `updates` event with
+the node's state delta — we subscribe to `updates`, filter by the
+phase-name suffix, and re-emit as `MiddlewareEvent` objects.
+
+No cooperation from middleware implementations is required — the
+node-name convention is the contract, and the transformer is a pure
+consumer of the graph's existing `updates` stream mode.
+
+`run.middleware` is an in-process-only projection: it surfaces
+lifecycle information for local consumers and is not intended to
+cross a wire boundary.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
+
+from langgraph.stream._event_log import EventLog
+from langgraph.stream._types import StreamTransformer
+
+if TYPE_CHECKING:
+    from langgraph.stream._types import ProtocolEvent
+
+
+MiddlewarePhase = Literal["before_agent", "before_model", "after_model", "after_agent"]
+"""Lifecycle phase a middleware node represents."""
+
+
+_PHASES: frozenset[str] = frozenset(("before_agent", "before_model", "after_model", "after_agent"))
+
+
+@dataclass(frozen=True)
+class MiddlewareEvent:
+    """One lifecycle transition from a middleware node.
+
+    Emitted once per node execution, carrying the phase, the name of
+    the middleware that produced it, and the state delta returned by
+    the hook.
+
+    Attributes:
+        phase: Which hook fired — `before_agent` / `before_model` /
+            `after_model` / `after_agent`.
+        middleware_name: The `name` attribute of the middleware class
+            that produced this event (the part before the `.` in the
+            graph node name).
+        state_delta: The state diff the middleware hook returned.
+            Empty dict if the hook didn't return an update.
+        timestamp: Milliseconds since epoch, taken from the underlying
+            protocol event.
+    """
+
+    phase: MiddlewarePhase
+    middleware_name: str
+    state_delta: dict[str, Any]
+    timestamp: int
+
+
+def _split_middleware_node(node_name: str) -> tuple[str, MiddlewarePhase] | None:
+    """Split `"{name}.{phase}"` into `(name, phase)` if `phase` is valid.
+
+    Splits on the last `.` so middleware names may themselves contain dots.
+
+    Returns `None` for any node name that doesn't match the convention
+    (no `.`, an empty name before the `.`, or an unknown phase), so
+    non-middleware nodes are ignored.
+    """
+    name, _, suffix = node_name.rpartition(".")
+    # `rpartition` yields an empty `name` both when there is no `.` at all
+    # and when the `.` is the first character; neither names a middleware.
+    if not name or suffix not in _PHASES:
+        return None
+    return name, suffix  # type: ignore[return-value]
+
+
+class MiddlewareTransformer(StreamTransformer):
+    """Project `updates` events from middleware nodes into `MiddlewareEvent`s.
+
+    Watches `updates` events at its mux scope, inspects the node name,
+    and emits one `MiddlewareEvent` per middleware-node execution onto
+    the `middleware` projection.
+
+    Native transformer — the `middleware` projection is exposed as a
+    direct attribute (`run.middleware`). `create_agent` passes this
+    class in its `transformers` list, so it's on by default for any
+    `create_agent` graph.
+
+    `scope` (inherited from `StreamTransformer`) lets the transformer
+    scope its filtering to its own mux — root mux sees root middleware,
+    subgraph mini-muxes see their own middleware, via the same
+    machinery `MessagesTransformer` uses.
+    """
+
+    _native: ClassVar[bool] = True
+    required_stream_modes: ClassVar[tuple[str, ...]] = ("updates",)
+
+    def __init__(self, scope: tuple[str, ...] = ()) -> None:
+        """Create the transformer and the log backing the projection.
+
+        Args:
+            scope: Mux scope forwarded to `StreamTransformer`.
+        """
+        super().__init__(scope)
+        self._log: EventLog[MiddlewareEvent] = EventLog()
+
+    def init(self) -> dict[str, Any]:
+        """Return the projections this transformer contributes.
+
+        Returns:
+            A mapping with the single key `"middleware"` bound to the
+            event log that `process` pushes `MiddlewareEvent`s onto.
+        """
+        return {"middleware": self._log}
+
+    def process(self, event: ProtocolEvent) -> bool:
+        """Record a `MiddlewareEvent` for each middleware node in `event`.
+
+        Events whose method is not `updates`, or whose `data` payload is
+        not a dict, are left untouched.
+
+        Args:
+            event: Protocol event delivered to this transformer.
+
+        Returns:
+            Always `True`, so the raw event keeps flowing to other consumers.
+        """
+        # Namespace filtering is handled by the mux via `scope_exact`.
+        if event["method"] != "updates":
+            return True
+
+        params = event["params"]
+        data = params.get("data")
+        if not isinstance(data, dict):
+            return True
+
+        timestamp = int(params.get("timestamp", 0))
+        for node_name, state_delta in data.items():
+            split = _split_middleware_node(node_name)
+            if split is None:
+                continue
+            name, phase = split
+            self._log.push(
+                MiddlewareEvent(
+                    phase=phase,
+                    middleware_name=name,
+                    state_delta=state_delta if isinstance(state_delta, dict) else {},
+                    timestamp=timestamp,
+                )
+            )
+        # Pass through — wire consumers still see the raw `updates` event.
+        return True
diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py
index 4161e5cfd73ea..d2e3d317b77ab 100644
--- a/libs/langchain_v1/langchain/agents/factory.py
+++ b/libs/langchain_v1/langchain/agents/factory.py
@@ -28,6 +28,7 @@
 from langsmith import traceable
 from typing_extensions import NotRequired, Required, TypedDict
 
+from langchain.agents._middleware_transformer import MiddlewareTransformer
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
@@ -1667,7 +1668,7 @@ async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> lis
         debug=debug,
         name=name,
         cache=cache,
-        transformers=[ToolCallTransformer, *(transformers or ())],
+        transformers=[ToolCallTransformer, MiddlewareTransformer, *(transformers or ())],
     ).with_config(config)
 
 
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_transformer.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_transformer.py
new file mode 100644
index 0000000000000..46f53081ce281
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_transformer.py
@@ -0,0 +1,189 @@
+"""Unit tests for `MiddlewareTransformer`."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain.agents import (
+    MiddlewareEvent,
+    MiddlewareTransformer,
+    create_agent,
+)
+from langchain.agents.middleware import AgentMiddleware
+from tests.unit_tests.agents.model import FakeToolCallingModel
+
+if TYPE_CHECKING:
+    from langchain.agents.middleware.types import AgentState
+
+
+# ---------------------------------------------------------------------------
+# Unit-level: feed the transformer synthetic protocol events
+# ---------------------------------------------------------------------------
+
+
+def _updates_event(
+    node_updates: dict[str, Any], *, namespace: list[str] | None = None
+) -> dict[str, Any]:
+    """Build a synthetic `updates` protocol event for testing `process`."""
+    return {
+        "type": "event",
+        "method": "updates",
+        "params": {
+            "namespace": namespace or [],
+            "timestamp": 1234,
+            "data": node_updates,
+        },
+    }
+
+
+def _fresh_transformer() -> MiddlewareTransformer:
+    """Build a bound-and-subscribed transformer so `push` accumulates."""
+    tr = MiddlewareTransformer()
+    log = tr.init()["middleware"]
+    log._bind(is_async=False)
+    # Mark subscribed so `log.push` records items. Normally this happens via
+    # `iter(log)`; for unit-level tests we want to inspect `_items` after
+    # `process` without consuming the iterator mid-test.
+    log._subscribed = True
+    return tr
+
+
+class TestMiddlewareTransformerProcess:
+    """`MiddlewareTransformer.process` in isolation."""
+
+    def test_emits_event_for_middleware_node(self) -> None:
+        tr = _fresh_transformer()
+        kept = tr.process(_updates_event({"MyMW.before_model": {"messages": ["hi"]}}))
+        assert kept is True  # event passes through
+        events = list(tr._log._items)
+        assert len(events) == 1
+        ev = events[0]
+        assert isinstance(ev, MiddlewareEvent)
+        assert ev.phase == "before_model"
+        assert ev.middleware_name == "MyMW"
+        assert ev.state_delta == {"messages": ["hi"]}
+        assert ev.timestamp == 1234
+
+    def test_all_four_phases(self) -> None:
+        tr = _fresh_transformer()
+        for phase in (
+            "before_agent",
+            "before_model",
+            "after_model",
+            "after_agent",
+        ):
+            tr.process(_updates_event({f"MW.{phase}": {"k": phase}}))
+        phases = [ev.phase for ev in tr._log._items]
+        assert phases == [
+            "before_agent",
+            "before_model",
+            "after_model",
+            "after_agent",
+        ]
+
+    def test_ignores_non_middleware_nodes(self) -> None:
+        tr = _fresh_transformer()
+        tr.process(_updates_event({"model": {"messages": []}}))
+        tr.process(_updates_event({"tools": {"messages": []}}))
+        tr.process(_updates_event({"MW.weird_phase": {}}))  # not a real phase
+        tr.process(_updates_event({"NoDot": {}}))
+        tr.process(_updates_event({".before_model": {}}))  # empty middleware name
+        assert list(tr._log._items) == []
+
+    def test_ignores_non_updates_methods(self) -> None:
+        tr = _fresh_transformer()
+        kept = tr.process(
+            {
+                "type": "event",
+                "method": "values",
+                "params": {
+                    "namespace": [],
+                    "timestamp": 0,
+                    "data": {"anything": "goes"},
+                },
+            }
+        )
+        assert kept is True
+        assert list(tr._log._items) == []
+
+    def test_multi_node_batch_in_single_event(self) -> None:
+        tr = _fresh_transformer()
+        tr.process(
+            _updates_event(
+                {
+                    "A.before_model": {"x": 1},
+                    "B.after_model": {"y": 2},
+                    "regular_node": {"z": 3},
+                }
+            )
+        )
+        events = list(tr._log._items)
+        assert len(events) == 2
+        names_by_phase = {ev.phase: ev.middleware_name for ev in events}
+        assert names_by_phase == {"before_model": "A", "after_model": "B"}
+
+    def test_non_dict_state_delta_becomes_empty_dict(self) -> None:
+        tr = _fresh_transformer()
+        tr.process(_updates_event({"MW.before_agent": None}))
+        events = list(tr._log._items)
+        assert len(events) == 1
+        assert events[0].state_delta == {}
+
+    def test_dotted_middleware_name_uses_rpartition(self) -> None:
+        # Middleware name with a `.` in it — rpartition splits on the last one.
+        tr = _fresh_transformer()
+        tr.process(_updates_event({"my.name.spaced.MW.before_model": {"ok": 1}}))
+        events = list(tr._log._items)
+        assert len(events) == 1
+        assert events[0].middleware_name == "my.name.spaced.MW"
+        assert events[0].phase == "before_model"
+
+
+# ---------------------------------------------------------------------------
+# Integration: create_agent auto-registers it; run.middleware works
+# ---------------------------------------------------------------------------
+
+
+class _MarkerMiddleware(AgentMiddleware):
+    """Touches state in `before_model` so we can see an updates event fire."""
+
+    name = "MarkerMW"
+
+    def before_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
+        return {"messages": []}
+
+
+class TestCreateAgentRegistersMiddlewareTransformer:
+    def test_middleware_projection_present(self) -> None:
+        model = FakeToolCallingModel(tool_calls=[[], []])
+        agent = create_agent(model, [], middleware=[_MarkerMiddleware()])
+        run = agent.stream_v2({"messages": [HumanMessage("hi")]})
+        assert "middleware" in run.extensions
+        assert hasattr(run, "middleware")
+        # Drive to completion.
+        run.output  # noqa: B018
+
+    def test_middleware_events_emitted_for_before_model_hook(self) -> None:
+        model = FakeToolCallingModel(
+            tool_calls=[[], []],
+            # Have the model respond once and stop.
+            responses=[AIMessage(content="done", id="a1", tool_calls=[])],
+        )
+        agent = create_agent(model, [], middleware=[_MarkerMiddleware()])
+        run = agent.stream_v2({"messages": [HumanMessage("hi")]})
+        events: list[MiddlewareEvent] = list(run.middleware)
+        assert any(
+            ev.phase == "before_model" and ev.middleware_name == "MarkerMW" for ev in events
+        ), f"expected a MarkerMW.before_model event; got {events}"
+
+    def test_no_events_when_no_middleware(self) -> None:
+        model = FakeToolCallingModel(
+            tool_calls=[[]],
+            responses=[AIMessage(content="done", id="a1", tool_calls=[])],
+        )
+        agent = create_agent(model, [])
+        run = agent.stream_v2({"messages": [HumanMessage("hi")]})
+        events = list(run.middleware)
+        assert events == []