Skip to content
Open
149 changes: 91 additions & 58 deletions libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from langchain.agents.middleware.types import (
AgentMiddleware,
AgentRuntime,
AgentState,
JumpTo,
ModelRequest,
Expand Down Expand Up @@ -519,6 +520,7 @@ def create_agent( # noqa: PLR0915
debug: bool = False,
name: str | None = None,
cache: BaseCache | None = None,
backend: object | None = None,
) -> CompiledStateGraph[
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
]:
Expand Down Expand Up @@ -703,6 +705,53 @@ def check_weather(location: str) -> str:
else:
default_tools = list(built_in_tools)

# Derive model_name for AgentRuntime (best-effort; None for dynamic callables)
_agent_model_name: str | None = None
if isinstance(model, str):
_agent_model_name = model
elif hasattr(model, "model_name"):
_agent_model_name = model.model_name
elif hasattr(model, "model"):
_agent_model_name = model.model

# Wrappers that convert the LangGraph Runtime into an AgentRuntime before
# dispatching to each middleware hook. _build_runtime calls are accumulated
# across all middlewares in order, so a subpackage that prepends a specialised
# middleware can inject an enriched runtime subclass for all hooks downstream.
def _accumulate_runtime(ar: AgentRuntime[ContextT]) -> AgentRuntime[ContextT]:
for mw in middleware:
ar = mw._build_runtime(ar)
return ar

def _build_hook_runtime(runtime: Runtime[ContextT]) -> AgentRuntime[ContextT]:
return _accumulate_runtime(
AgentRuntime.from_runtime(
name or "agent",
runtime,
model_name=_agent_model_name,
tools=default_tools,
backend=backend,
)
)

def _wrap_hook(hook):
if hook is None:
return None

def _wrapped(state: AgentState, runtime: Runtime[ContextT]):
return hook(state, _build_hook_runtime(runtime))

return _wrapped

def _wrap_async_hook(hook):
if hook is None:
return None

async def _wrapped(state: AgentState, runtime: Runtime[ContextT]):
return await hook(state, _build_hook_runtime(runtime))

return _wrapped

# validate middleware
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
Expand Down Expand Up @@ -1018,16 +1067,20 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:

def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
state=state,
runtime=runtime,
agent_runtime = _accumulate_runtime(
AgentRuntime.from_runtime(
name or "agent",
runtime,
model_name=_agent_model_name,
model=model if isinstance(model, BaseChatModel) else None,
system_prompt=system_prompt,
tools=default_tools,
tool_choice=None,
response_format=initial_response_format,
backend=backend,
)
)
request = ModelRequest.from_runtime(agent_runtime, messages=state["messages"], state=state)

if wrap_model_call_handler is None:
# No handlers - execute directly
Expand Down Expand Up @@ -1071,16 +1124,20 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:

async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
state=state,
runtime=runtime,
agent_runtime = _accumulate_runtime(
AgentRuntime.from_runtime(
name or "agent",
runtime,
model_name=_agent_model_name,
model=model if isinstance(model, BaseChatModel) else None,
system_prompt=system_prompt,
tools=default_tools,
tool_choice=None,
response_format=initial_response_format,
backend=backend,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 Async agents crash on backend reference

The backend parameter was removed from create_agent and AgentRuntime.from_runtime, but the async model node still passes backend=backend. Any agent.ainvoke() now fails before the model runs with NameError: name 'backend' is not defined (I verified this with a minimal create_agent(...).ainvoke(...)). Please remove this argument here as well, matching the sync path above.

(Refers to line 1134)


Was this helpful? React with 👍 or 👎 to provide feedback.

)
)
request = ModelRequest.from_runtime(agent_runtime, messages=state["messages"], state=state)

if awrap_model_call_handler is None:
# No async handlers - execute directly
Expand Down Expand Up @@ -1109,17 +1166,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.before_agent is not AgentMiddleware.before_agent
or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before_agent = (
m.before_agent
if m.__class__.before_agent is not AgentMiddleware.before_agent
else None
sync_before_agent = _wrap_hook(
m.before_agent if m.__class__.before_agent is not AgentMiddleware.before_agent else None
)
async_before_agent = (
m.abefore_agent
if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
else None
async_before_agent = _wrap_async_hook(
m.abefore_agent if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent else None
)
before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
graph.add_node(
Expand All @@ -1130,17 +1181,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before = (
m.before_model
if m.__class__.before_model is not AgentMiddleware.before_model
else None
sync_before = _wrap_hook(
m.before_model if m.__class__.before_model is not AgentMiddleware.before_model else None
)
async_before = (
m.abefore_model
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
else None
async_before = _wrap_async_hook(
m.abefore_model if m.__class__.abefore_model is not AgentMiddleware.abefore_model else None
)
before_node = RunnableCallable(sync_before, async_before, trace=False)
graph.add_node(
Expand All @@ -1151,17 +1196,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after = (
m.after_model
if m.__class__.after_model is not AgentMiddleware.after_model
else None
sync_after = _wrap_hook(
m.after_model if m.__class__.after_model is not AgentMiddleware.after_model else None
)
async_after = (
m.aafter_model
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
else None
async_after = _wrap_async_hook(
m.aafter_model if m.__class__.aafter_model is not AgentMiddleware.aafter_model else None
)
after_node = RunnableCallable(sync_after, async_after, trace=False)
graph.add_node(f"{m.name}.after_model", after_node, input_schema=resolved_state_schema)
Expand All @@ -1170,17 +1209,11 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after_agent = (
m.after_agent
if m.__class__.after_agent is not AgentMiddleware.after_agent
else None
sync_after_agent = _wrap_hook(
m.after_agent if m.__class__.after_agent is not AgentMiddleware.after_agent else None
)
async_after_agent = (
m.aafter_agent
if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
else None
async_after_agent = _wrap_async_hook(
m.aafter_agent if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent else None
)
after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
graph.add_node(
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain_v1/langchain/agents/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentRuntime,
AgentState,
ModelRequest,
ModelResponse,
Expand All @@ -47,6 +48,7 @@

__all__ = [
"AgentMiddleware",
"AgentRuntime",
"AgentState",
"ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",
Expand Down
Loading