Skip to content
Open
109 changes: 67 additions & 42 deletions libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from langchain.agents.middleware.types import (
AgentMiddleware,
AgentRuntime,
AgentState,
JumpTo,
ModelRequest,
Expand Down Expand Up @@ -703,6 +704,48 @@ def check_weather(location: str) -> str:
else:
default_tools = list(built_in_tools)

# Derive model_name for AgentRuntime (best-effort; None for dynamic callables)
_agent_model_name: str | None = None
if isinstance(model, str):
_agent_model_name = model
elif hasattr(model, "model_name"):
_agent_model_name = model.model_name
elif hasattr(model, "model"):
_agent_model_name = model.model

# Wrappers that convert the LangGraph Runtime into an AgentRuntime before
# dispatching to each middleware hook. Middleware may further enrich the
# runtime by overriding _build_runtime (private, not a public extension point).
def _wrap_hook(hook, mw):
if hook is None:
return None

def _wrapped(state: AgentState, runtime: Runtime[ContextT]):
agent_runtime = AgentRuntime.from_runtime(
name or "agent",
runtime,
model_name=_agent_model_name,
tools=default_tools,
)
return hook(state, mw._build_runtime(agent_runtime))
Comment thread
open-swe[bot] marked this conversation as resolved.
Outdated

return _wrapped

def _wrap_async_hook(hook, mw):
if hook is None:
return None

async def _wrapped(state: AgentState, runtime: Runtime[ContextT]):
agent_runtime = AgentRuntime.from_runtime(
name or "agent",
runtime,
model_name=_agent_model_name,
tools=default_tools,
)
return await hook(state, mw._build_runtime(agent_runtime))

return _wrapped

# validate middleware
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
Expand Down Expand Up @@ -1018,6 +1061,9 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:

def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
# Create flat AgentRuntime with all runtime properties
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime, model_name=_agent_model_name, tools=default_tools)

request = ModelRequest(
model=model,
tools=default_tools,
Expand All @@ -1026,7 +1072,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
messages=state["messages"],
tool_choice=None,
state=state,
runtime=runtime,
runtime=agent_runtime,
)

if wrap_model_call_handler is None:
Expand Down Expand Up @@ -1071,6 +1117,9 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:

async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
# Create flat AgentRuntime with all runtime properties
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime, model_name=_agent_model_name, tools=default_tools)

request = ModelRequest(
model=model,
tools=default_tools,
Expand All @@ -1079,7 +1128,7 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
messages=state["messages"],
tool_choice=None,
state=state,
runtime=runtime,
runtime=agent_runtime,
)

if awrap_model_call_handler is None:
Expand Down Expand Up @@ -1109,17 +1158,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.before_agent is not AgentMiddleware.before_agent
or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before_agent = (
m.before_agent
if m.__class__.before_agent is not AgentMiddleware.before_agent
else None
sync_before_agent = _wrap_hook(
m.before_agent if m.__class__.before_agent is not AgentMiddleware.before_agent else None, m
)
async_before_agent = (
m.abefore_agent
if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
else None
async_before_agent = _wrap_async_hook(
m.abefore_agent if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent else None, m
)
before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
graph.add_node(
Expand All @@ -1130,17 +1173,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before = (
m.before_model
if m.__class__.before_model is not AgentMiddleware.before_model
else None
sync_before = _wrap_hook(
m.before_model if m.__class__.before_model is not AgentMiddleware.before_model else None, m
)
async_before = (
m.abefore_model
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
else None
async_before = _wrap_async_hook(
m.abefore_model if m.__class__.abefore_model is not AgentMiddleware.abefore_model else None, m
)
before_node = RunnableCallable(sync_before, async_before, trace=False)
graph.add_node(
Expand All @@ -1151,17 +1188,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after = (
m.after_model
if m.__class__.after_model is not AgentMiddleware.after_model
else None
sync_after = _wrap_hook(
m.after_model if m.__class__.after_model is not AgentMiddleware.after_model else None, m
)
async_after = (
m.aafter_model
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
else None
async_after = _wrap_async_hook(
m.aafter_model if m.__class__.aafter_model is not AgentMiddleware.aafter_model else None, m
)
after_node = RunnableCallable(sync_after, async_after, trace=False)
graph.add_node(f"{m.name}.after_model", after_node, input_schema=resolved_state_schema)
Expand All @@ -1170,17 +1201,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after_agent = (
m.after_agent
if m.__class__.after_agent is not AgentMiddleware.after_agent
else None
sync_after_agent = _wrap_hook(
m.after_agent if m.__class__.after_agent is not AgentMiddleware.after_agent else None, m
)
async_after_agent = (
m.aafter_agent
if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
else None
async_after_agent = _wrap_async_hook(
m.aafter_agent if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent else None, m
)
after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
graph.add_node(
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain_v1/langchain/agents/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentRuntime,
AgentState,
ModelRequest,
ModelResponse,
Expand All @@ -47,6 +48,7 @@

__all__ = [
"AgentMiddleware",
"AgentRuntime",
"AgentState",
"ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",
Expand Down
120 changes: 116 additions & 4 deletions libs/langchain_v1/langchain/agents/middleware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.types import Command # noqa: TC002
from langgraph.store.base import BaseStore # noqa: TC002
from langgraph.types import Command, StreamWriter # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack

Expand All @@ -40,6 +41,7 @@

__all__ = [
"AgentMiddleware",
"AgentRuntime",
"AgentState",
"ContextT",
"ModelRequest",
Expand All @@ -60,6 +62,90 @@
ResponseT = TypeVar("ResponseT")


@dataclass
class AgentRuntime(Generic[ContextT]):
"""Runtime context for agent execution, extending LangGraph's Runtime.

This class provides agent-specific execution context to middleware, including
the name of the currently executing graph and all Runtime properties flattened
for convenient access.

The AgentRuntime follows the same pattern as ToolRuntime, providing a flat
structure with all runtime properties directly accessible.

Attributes:
agent_name: The name of the currently executing graph/agent. This is the
name passed to `create_agent(name=...)` or defaults to "LangGraph".
context: Static context for the graph run (e.g., `user_id`, `db_conn`).
store: Store for persistence and memory, if configured.
stream_writer: Function for writing to the custom stream.
previous: The previous return value for the given thread (functional API only).

Example:
```python
from langchain.agents.middleware import wrap_model_call, AgentRuntime
from langchain.agents.middleware.types import ModelRequest, ModelResponse


@wrap_model_call
def log_agent_name(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
'''Log which agent is making the model call.'''
agent_name = request.runtime.agent_name
print(f"Agent '{agent_name}' is calling the model")

# Access runtime context directly (flattened)
user_id = request.runtime.context.get("user_id")
print(f"User: {user_id}")

return handler(request)
```
"""

agent_name: str
"""The name of the currently executing graph/agent."""

context: ContextT = field(default=None) # type: ignore[assignment]
"""Static context for the graph run, like `user_id`, `db_conn`, etc."""

store: BaseStore | None = field(default=None)
"""Store for the graph run, enabling persistence and memory."""

stream_writer: StreamWriter = field(default=None) # type: ignore[assignment]
"""Function that writes to the custom stream."""

previous: Any = field(default=None)
"""The previous return value for the given thread."""

model_name: str | None = field(default=None)
"""Name of the model being used, if statically known."""

tools: list[BaseTool] = field(default_factory=list)
"""Tools registered with the agent."""

@classmethod
def from_runtime(
cls,
name: str,
runtime: Runtime[ContextT],
*,
model_name: str | None = None,
tools: list[BaseTool] | None = None,
) -> AgentRuntime[ContextT]:
"""Create an AgentRuntime from a Runtime."""
return AgentRuntime[ContextT](
agent_name=name,
context=runtime.context,
store=runtime.store,
stream_writer=runtime.stream_writer,
previous=runtime.previous,
model_name=model_name,
tools=tools or [],
)


class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for ModelRequest.override() method."""

Expand All @@ -74,7 +160,23 @@ class _ModelRequestOverrides(TypedDict, total=False):

@dataclass
class ModelRequest:
"""Model request information for the agent."""
"""Model request information for the agent.

This dataclass contains all the information needed for a model invocation,
including the model, messages, tools, and runtime context.

Attributes:
model: The chat model to invoke.
system_prompt: Optional system prompt to prepend to messages.
messages: List of conversation messages (excluding system prompt).
tool_choice: Tool selection configuration for the model.
tools: Available tools for the model to use.
response_format: Structured output format specification.
state: Complete agent state at the time of model invocation.
runtime: Agent runtime context including agent name and underlying
LangGraph Runtime with context, store, and stream_writer.
model_settings: Additional model-specific settings.
"""

model: BaseChatModel
system_prompt: str | None
Expand All @@ -83,7 +185,7 @@ class ModelRequest:
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
state: AgentState
runtime: Runtime[ContextT] # type: ignore[valid-type]
runtime: AgentRuntime[ContextT] # type: ignore[valid-type]
Comment on lines +160 to +162
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 ModelRequest constructor now rejects existing calls

Removing these dataclass fields from ModelRequest.__init__ breaks existing code that constructs a request directly with the documented invocation fields. This is not just theoretical: the current test suite still has many call sites such as ModelRequest(model=..., system_prompt=..., tools=..., response_format=...), and at this HEAD that raises TypeError: ModelRequest.__init__() got an unexpected keyword argument 'model'. Since ModelRequest is exported from the middleware package and used by middleware tests/users, please preserve the old constructor shape (for example with a custom __init__ that builds an AgentRuntime) while still delegating properties to the runtime internally.

(Refers to lines 167-169)


Was this helpful? React with 👍 or 👎 to provide feedback.

model_settings: dict[str, Any] = field(default_factory=dict)

def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
Expand Down Expand Up @@ -209,6 +311,16 @@ def name(self) -> str:
"""
return self.__class__.__name__

def _build_runtime(self, runtime: AgentRuntime[ContextT]) -> AgentRuntime[ContextT]:
"""Enrich AgentRuntime before it is passed to hook methods.

Called by the agent factory for every hook node (before_agent, before_model,
after_model, after_agent). The default is identity. Subpackages that need
extra fields on the runtime (e.g. a resolved backend) override this privately
— it is not a public extension point for end-user middleware.
"""
return runtime

def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the agent execution starts."""

Expand Down Expand Up @@ -932,7 +1044,7 @@ def before_agent(
```python
@before_agent
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
print(f"Starting agent with {len(state['messages'])} messages")
print(f"Starting agent '{runtime.agent_name}' with {len(state['messages'])} messages")
```

With conditional jumping:
Expand Down
Loading
Loading