Skip to content

Commit 4ab9457

Browse files
feat(langchain): support SystemMessage in create_agent's system_prompt (#34055)
* `create_agent`'s `system_prompt` allows `str | SystemMessage` * added `system_message: SystemMessage` on `ModelRequest` * `ModelRequest.system_prompt` is a function of `system_message.text`, now deprecated * disallow setting `system_prompt` and `system_message` * `ModelRequest.system_prompt` can still be set (w/ custom setattr) for custom backwards compat, but the updates just get propagated to the `ModelRequest.system_message` --------- Co-authored-by: Chester Curme <[email protected]>
1 parent eb0545a commit 4ab9457

File tree

9 files changed

+1411
-83
lines changed

9 files changed

+1411
-83
lines changed

libs/langchain_v1/langchain/agents/factory.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def create_agent( # noqa: PLR0915
542542
model: str | BaseChatModel,
543543
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
544544
*,
545-
system_prompt: str | None = None,
545+
system_prompt: str | SystemMessage | None = None,
546546
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
547547
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
548548
state_schema: type[AgentState[ResponseT]] | None = None,
@@ -588,9 +588,9 @@ def create_agent( # noqa: PLR0915
588588
docs for more information.
589589
system_prompt: An optional system prompt for the LLM.
590590
591-
Prompts are converted to a
592-
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
593-
beginning of the message list.
591+
Can be a `str` (which will be converted to a `SystemMessage`) or a
592+
`SystemMessage` instance directly. The system message is added to the
593+
beginning of the message list when calling the model.
594594
middleware: A sequence of middleware instances to apply to the agent.
595595
596596
Middleware can intercept and modify agent behavior at various stages.
@@ -685,6 +685,14 @@ def check_weather(location: str) -> str:
685685
if isinstance(model, str):
686686
model = init_chat_model(model)
687687

688+
# Convert system_prompt to SystemMessage if needed
689+
system_message: SystemMessage | None = None
690+
if system_prompt is not None:
691+
if isinstance(system_prompt, SystemMessage):
692+
system_message = system_prompt
693+
else:
694+
system_message = SystemMessage(content=system_prompt)
695+
688696
# Handle tools being None or empty
689697
if tools is None:
690698
tools = []
@@ -1088,8 +1096,8 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:
10881096
# Get the bound model (with auto-detection if needed)
10891097
model_, effective_response_format = _get_bound_model(request)
10901098
messages = request.messages
1091-
if request.system_prompt:
1092-
messages = [SystemMessage(request.system_prompt), *messages]
1099+
if request.system_message:
1100+
messages = [request.system_message, *messages]
10931101

10941102
output = model_.invoke(messages)
10951103

@@ -1108,7 +1116,7 @@ def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
11081116
request = ModelRequest(
11091117
model=model,
11101118
tools=default_tools,
1111-
system_prompt=system_prompt,
1119+
system_message=system_message,
11121120
response_format=initial_response_format,
11131121
messages=state["messages"],
11141122
tool_choice=None,
@@ -1141,8 +1149,8 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:
11411149
# Get the bound model (with auto-detection if needed)
11421150
model_, effective_response_format = _get_bound_model(request)
11431151
messages = request.messages
1144-
if request.system_prompt:
1145-
messages = [SystemMessage(request.system_prompt), *messages]
1152+
if request.system_message:
1153+
messages = [request.system_message, *messages]
11461154

11471155
output = await model_.ainvoke(messages)
11481156

@@ -1161,7 +1169,7 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
11611169
request = ModelRequest(
11621170
model=model,
11631171
tools=default_tools,
1164-
system_prompt=system_prompt,
1172+
system_message=system_message,
11651173
response_format=initial_response_format,
11661174
messages=state["messages"],
11671175
tool_choice=None,

libs/langchain_v1/langchain/agents/middleware/context_editing.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
AIMessage,
1919
AnyMessage,
2020
BaseMessage,
21-
SystemMessage,
2221
ToolMessage,
2322
)
2423
from langchain_core.messages.utils import count_tokens_approximately
@@ -230,9 +229,7 @@ def wrap_model_call(
230229
def count_tokens(messages: Sequence[BaseMessage]) -> int:
231230
return count_tokens_approximately(messages)
232231
else:
233-
system_msg = (
234-
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
235-
)
232+
system_msg = [request.system_message] if request.system_message else []
236233

237234
def count_tokens(messages: Sequence[BaseMessage]) -> int:
238235
return request.model.get_num_tokens_from_messages(
@@ -259,9 +256,7 @@ async def awrap_model_call(
259256
def count_tokens(messages: Sequence[BaseMessage]) -> int:
260257
return count_tokens_approximately(messages)
261258
else:
262-
system_msg = (
263-
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
264-
)
259+
system_msg = [request.system_message] if request.system_message else []
265260

266261
def count_tokens(messages: Sequence[BaseMessage]) -> int:
267262
return request.model.get_num_tokens_from_messages(

libs/langchain_v1/langchain/agents/middleware/todo.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
from __future__ import annotations
55

6-
from typing import TYPE_CHECKING, Annotated, Literal
6+
from typing import TYPE_CHECKING, Annotated, Literal, cast
77

88
if TYPE_CHECKING:
99
from collections.abc import Awaitable, Callable
1010

11-
from langchain_core.messages import ToolMessage
11+
from langchain_core.messages import SystemMessage, ToolMessage
1212
from langchain_core.tools import tool
1313
from langgraph.types import Command
1414
from typing_extensions import NotRequired, TypedDict
@@ -193,23 +193,33 @@ def wrap_model_call(
193193
request: ModelRequest,
194194
handler: Callable[[ModelRequest], ModelResponse],
195195
) -> ModelCallResult:
196-
"""Update the system prompt to include the todo system prompt."""
197-
new_system_prompt = (
198-
request.system_prompt + "\n\n" + self.system_prompt
199-
if request.system_prompt
200-
else self.system_prompt
196+
"""Update the system message to include the todo system prompt."""
197+
if request.system_message is not None:
198+
new_system_content = [
199+
*request.system_message.content_blocks,
200+
{"type": "text", "text": f"\n\n{self.system_prompt}"},
201+
]
202+
else:
203+
new_system_content = [{"type": "text", "text": self.system_prompt}]
204+
new_system_message = SystemMessage(
205+
content=cast("list[str | dict[str, str]]", new_system_content)
201206
)
202-
return handler(request.override(system_prompt=new_system_prompt))
207+
return handler(request.override(system_message=new_system_message))
203208

204209
async def awrap_model_call(
205210
self,
206211
request: ModelRequest,
207212
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
208213
) -> ModelCallResult:
209-
"""Update the system prompt to include the todo system prompt (async version)."""
210-
new_system_prompt = (
211-
request.system_prompt + "\n\n" + self.system_prompt
212-
if request.system_prompt
213-
else self.system_prompt
214+
"""Update the system message to include the todo system prompt (async version)."""
215+
if request.system_message is not None:
216+
new_system_content = [
217+
*request.system_message.content_blocks,
218+
{"type": "text", "text": f"\n\n{self.system_prompt}"},
219+
]
220+
else:
221+
new_system_content = [{"type": "text", "text": self.system_prompt}]
222+
new_system_message = SystemMessage(
223+
content=cast("list[str | dict[str, str]]", new_system_content)
214224
)
215-
return await handler(request.override(system_prompt=new_system_prompt))
225+
return await handler(request.override(system_message=new_system_message))

0 commit comments

Comments
 (0)