-
Notifications
You must be signed in to change notification settings - Fork 23k
feat(agents): AgentRuntime fields + _build_runtime hook for middleware #37879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 5 commits
98122b0
22e7deb
28c0278
ea3ec45
3b458bf
d50e54f
0889a1e
9fe6113
35be8df
7824218
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,8 @@ | |
| from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002 | ||
| from langgraph.channels.ephemeral_value import EphemeralValue | ||
| from langgraph.graph.message import add_messages | ||
| from langgraph.types import Command # noqa: TC002 | ||
| from langgraph.store.base import BaseStore # noqa: TC002 | ||
| from langgraph.types import Command, StreamWriter # noqa: TC002 | ||
| from langgraph.typing import ContextT | ||
| from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack | ||
|
|
||
|
|
@@ -40,6 +41,7 @@ | |
|
|
||
| __all__ = [ | ||
| "AgentMiddleware", | ||
| "AgentRuntime", | ||
| "AgentState", | ||
| "ContextT", | ||
| "ModelRequest", | ||
|
|
@@ -60,6 +62,90 @@ | |
| ResponseT = TypeVar("ResponseT") | ||
|
|
||
|
|
||
| @dataclass | ||
| class AgentRuntime(Generic[ContextT]): | ||
| """Runtime context for agent execution, extending LangGraph's Runtime. | ||
|
|
||
| This class provides agent-specific execution context to middleware, including | ||
| the name of the currently executing graph and all Runtime properties flattened | ||
| for convenient access. | ||
|
|
||
| The AgentRuntime follows the same pattern as ToolRuntime, providing a flat | ||
| structure with all runtime properties directly accessible. | ||
|
|
||
| Attributes: | ||
| agent_name: The name of the currently executing graph/agent. This is the | ||
| name passed to `create_agent(name=...)` or defaults to "LangGraph". | ||
| context: Static context for the graph run (e.g., `user_id`, `db_conn`). | ||
| store: Store for persistence and memory, if configured. | ||
| stream_writer: Function for writing to the custom stream. | ||
| previous: The previous return value for the given thread (functional API only). | ||
|
|
||
| Example: | ||
| ```python | ||
| from langchain.agents.middleware import wrap_model_call, AgentRuntime | ||
| from langchain.agents.middleware.types import ModelRequest, ModelResponse | ||
|
|
||
|
|
||
| @wrap_model_call | ||
| def log_agent_name( | ||
| request: ModelRequest, | ||
| handler: Callable[[ModelRequest], ModelResponse], | ||
| ) -> ModelResponse: | ||
| '''Log which agent is making the model call.''' | ||
| agent_name = request.runtime.agent_name | ||
| print(f"Agent '{agent_name}' is calling the model") | ||
|
|
||
| # Access runtime context directly (flattened) | ||
| user_id = request.runtime.context.get("user_id") | ||
| print(f"User: {user_id}") | ||
|
|
||
| return handler(request) | ||
| ``` | ||
| """ | ||
|
|
||
| agent_name: str | ||
| """The name of the currently executing graph/agent.""" | ||
|
|
||
| context: ContextT = field(default=None) # type: ignore[assignment] | ||
| """Static context for the graph run, like `user_id`, `db_conn`, etc.""" | ||
|
|
||
| store: BaseStore | None = field(default=None) | ||
| """Store for the graph run, enabling persistence and memory.""" | ||
|
|
||
| stream_writer: StreamWriter = field(default=None) # type: ignore[assignment] | ||
| """Function that writes to the custom stream.""" | ||
|
|
||
| previous: Any = field(default=None) | ||
| """The previous return value for the given thread.""" | ||
|
|
||
| model_name: str | None = field(default=None) | ||
| """Name of the model being used, if statically known.""" | ||
|
|
||
| tools: list[BaseTool] = field(default_factory=list) | ||
| """Tools registered with the agent.""" | ||
|
|
||
| @classmethod | ||
| def from_runtime( | ||
| cls, | ||
| name: str, | ||
| runtime: Runtime[ContextT], | ||
| *, | ||
| model_name: str | None = None, | ||
| tools: list[BaseTool] | None = None, | ||
| ) -> AgentRuntime[ContextT]: | ||
| """Create an AgentRuntime from a Runtime.""" | ||
| return AgentRuntime[ContextT]( | ||
| agent_name=name, | ||
| context=runtime.context, | ||
| store=runtime.store, | ||
| stream_writer=runtime.stream_writer, | ||
| previous=runtime.previous, | ||
| model_name=model_name, | ||
| tools=tools or [], | ||
| ) | ||
|
|
||
|
|
||
| class _ModelRequestOverrides(TypedDict, total=False): | ||
| """Possible overrides for ModelRequest.override() method.""" | ||
|
|
||
|
|
@@ -74,7 +160,23 @@ class _ModelRequestOverrides(TypedDict, total=False): | |
|
|
||
| @dataclass | ||
| class ModelRequest: | ||
| """Model request information for the agent.""" | ||
| """Model request information for the agent. | ||
|
|
||
| This dataclass contains all the information needed for a model invocation, | ||
| including the model, messages, tools, and runtime context. | ||
|
|
||
| Attributes: | ||
| model: The chat model to invoke. | ||
| system_prompt: Optional system prompt to prepend to messages. | ||
| messages: List of conversation messages (excluding system prompt). | ||
| tool_choice: Tool selection configuration for the model. | ||
| tools: Available tools for the model to use. | ||
| response_format: Structured output format specification. | ||
| state: Complete agent state at the time of model invocation. | ||
| runtime: Agent runtime context including agent name and underlying | ||
| LangGraph Runtime with context, store, and stream_writer. | ||
| model_settings: Additional model-specific settings. | ||
| """ | ||
|
|
||
| model: BaseChatModel | ||
| system_prompt: str | None | ||
|
|
@@ -83,7 +185,7 @@ class ModelRequest: | |
| tools: list[BaseTool | dict] | ||
| response_format: ResponseFormat | None | ||
| state: AgentState | ||
| runtime: Runtime[ContextT] # type: ignore[valid-type] | ||
| runtime: AgentRuntime[ContextT] # type: ignore[valid-type] | ||
|
Comment on lines
+160
to
+162
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟠 ModelRequest constructor now rejects existing calls Removing these dataclass fields from (Refers to lines 167-169) Was this helpful? React with 👍 or 👎 to provide feedback. |
||
| model_settings: dict[str, Any] = field(default_factory=dict) | ||
|
|
||
| def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest: | ||
|
|
@@ -209,6 +311,16 @@ def name(self) -> str: | |
| """ | ||
| return self.__class__.__name__ | ||
|
|
||
| def _build_runtime(self, runtime: AgentRuntime[ContextT]) -> AgentRuntime[ContextT]: | ||
| """Enrich AgentRuntime before it is passed to hook methods. | ||
|
|
||
| Called by the agent factory for every hook node (before_agent, before_model, | ||
| after_model, after_agent). The default is identity. Subpackages that need | ||
| extra fields on the runtime (e.g. a resolved backend) override this privately | ||
| — it is not a public extension point for end-user middleware. | ||
| """ | ||
| return runtime | ||
|
|
||
| def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: | ||
| """Logic to run before the agent execution starts.""" | ||
|
|
||
|
|
@@ -932,7 +1044,7 @@ def before_agent( | |
| ```python | ||
| @before_agent | ||
| def log_before_agent(state: AgentState, runtime: Runtime) -> None: | ||
| print(f"Starting agent with {len(state['messages'])} messages") | ||
| print(f"Starting agent '{runtime.agent_name}' with {len(state['messages'])} messages") | ||
| ``` | ||
|
|
||
| With conditional jumping: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.