Skip to content

Commit cf595dc

Browse files
chore(langchain): Support for SystemMessage in create_agent (#33640)
- **Description:** Updated the function signature of `create_agent` so that `system_prompt` accepts either a plain string or a `SystemMessage` object. I see no harm in doing this, since `SystemMessage` itself already accepts both string and list content. - **Issue:** #33630 --------- Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
1 parent d27211c commit cf595dc

6 files changed

Lines changed: 313 additions & 31 deletions

File tree

libs/langchain_v1/langchain/agents/factory.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def create_agent( # noqa: PLR0915
516516
model: str | BaseChatModel,
517517
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
518518
*,
519-
system_prompt: str | None = None,
519+
system_prompt: str | SystemMessage | None = None,
520520
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
521521
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
522522
state_schema: type[AgentState[ResponseT]] | None = None,
@@ -548,11 +548,9 @@ def create_agent( # noqa: PLR0915
548548
549549
If `None` or an empty list, the agent will consist of a model node without a
550550
tool calling loop.
551-
system_prompt: An optional system prompt for the LLM.
551+
system_prompt: An optional system prompt for the LLM. May be provided
552+
as a plain string or as a [`SystemMessage`][langchain.messages.SystemMessage] object.
552553
553-
Prompts are converted to a
554-
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
555-
beginning of the message list.
556554
middleware: A sequence of middleware instances to apply to the agent.
557555
558556
Middleware can intercept and modify agent behavior at various stages. See
@@ -1040,8 +1038,10 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:
10401038
# Get the bound model (with auto-detection if needed)
10411039
model_, effective_response_format = _get_bound_model(request)
10421040
messages = request.messages
1043-
if request.system_prompt:
1044-
messages = [SystemMessage(request.system_prompt), *messages]
1041+
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
1042+
messages = [SystemMessage(content=request.system_prompt), *messages]
1043+
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
1044+
messages = [request.system_prompt, *messages]
10451045

10461046
output = model_.invoke(messages)
10471047

@@ -1093,8 +1093,10 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:
10931093
# Get the bound model (with auto-detection if needed)
10941094
model_, effective_response_format = _get_bound_model(request)
10951095
messages = request.messages
1096-
if request.system_prompt:
1097-
messages = [SystemMessage(request.system_prompt), *messages]
1096+
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
1097+
messages = [SystemMessage(content=request.system_prompt), *messages]
1098+
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
1099+
messages = [request.system_prompt, *messages]
10981100

10991101
output = await model_.ainvoke(messages)
11001102

libs/langchain_v1/langchain/agents/middleware/context_editing.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,11 @@ def wrap_model_call(
225225
def count_tokens(messages: Sequence[BaseMessage]) -> int:
226226
return count_tokens_approximately(messages)
227227
else:
228-
system_msg = (
229-
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
230-
)
228+
system_msg = []
229+
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
230+
system_msg = [SystemMessage(content=request.system_prompt)]
231+
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
232+
system_msg = [request.system_prompt]
231233

232234
def count_tokens(messages: Sequence[BaseMessage]) -> int:
233235
return request.model.get_num_tokens_from_messages(
@@ -253,9 +255,12 @@ async def awrap_model_call(
253255
def count_tokens(messages: Sequence[BaseMessage]) -> int:
254256
return count_tokens_approximately(messages)
255257
else:
256-
system_msg = (
257-
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
258-
)
258+
system_msg = []
259+
260+
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
261+
system_msg = [SystemMessage(content=request.system_prompt)]
262+
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
263+
system_msg = [request.system_prompt]
259264

260265
def count_tokens(messages: Sequence[BaseMessage]) -> int:
261266
return request.model.get_num_tokens_from_messages(

libs/langchain_v1/langchain/agents/middleware/todo.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
if TYPE_CHECKING:
99
from collections.abc import Awaitable, Callable
1010

11-
from langchain_core.messages import ToolMessage
11+
from langchain_core.messages import SystemMessage, ToolMessage
1212
from langchain_core.tools import tool
1313
from langgraph.types import Command
1414
from typing_extensions import NotRequired, TypedDict
@@ -199,11 +199,22 @@ def wrap_model_call(
199199
handler: Callable[[ModelRequest], ModelResponse],
200200
) -> ModelCallResult:
201201
"""Update the system prompt to include the todo system prompt."""
202-
request.system_prompt = (
203-
request.system_prompt + "\n\n" + self.system_prompt
204-
if request.system_prompt
205-
else self.system_prompt
206-
)
202+
if request.system_prompt is None:
203+
request.system_prompt = self.system_prompt
204+
elif isinstance(request.system_prompt, str):
205+
request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt
206+
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
207+
request.system_prompt.content, str
208+
):
209+
request.system_prompt = SystemMessage(
210+
content=request.system_prompt.content + self.system_prompt
211+
)
212+
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
213+
request.system_prompt.content, list
214+
):
215+
request.system_prompt = SystemMessage(
216+
content=[*request.system_prompt.content, self.system_prompt]
217+
)
207218
return handler(request)
208219

209220
async def awrap_model_call(
@@ -212,9 +223,20 @@ async def awrap_model_call(
212223
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
213224
) -> ModelCallResult:
214225
"""Update the system prompt to include the todo system prompt (async version)."""
215-
request.system_prompt = (
216-
request.system_prompt + "\n\n" + self.system_prompt
217-
if request.system_prompt
218-
else self.system_prompt
219-
)
226+
if request.system_prompt is None:
227+
request.system_prompt = self.system_prompt
228+
elif isinstance(request.system_prompt, str):
229+
request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt
230+
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
231+
request.system_prompt.content, str
232+
):
233+
request.system_prompt = SystemMessage(
234+
content=request.system_prompt.content + self.system_prompt
235+
)
236+
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
237+
request.system_prompt.content, list
238+
):
239+
request.system_prompt = SystemMessage(
240+
content=[*request.system_prompt.content, self.system_prompt]
241+
)
220242
return await handler(request)

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
AIMessage,
2727
AnyMessage,
2828
BaseMessage,
29+
SystemMessage,
2930
ToolMessage,
3031
)
3132
from langgraph.channels.ephemeral_value import EphemeralValue
@@ -85,7 +86,7 @@ class ModelRequest:
8586
"""Model request information for the agent."""
8687

8788
model: BaseChatModel
88-
system_prompt: str | None
89+
system_prompt: str | SystemMessage | None
8990
messages: list[AnyMessage] # excluding system prompt
9091
tool_choice: Any | None
9192
tools: list[BaseTool | dict]
@@ -103,7 +104,7 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
103104
Args:
104105
**overrides: Keyword arguments for attributes to override. Supported keys:
105106
- model: BaseChatModel instance
106-
- system_prompt: Optional system prompt string
107+
- system_prompt: Optional system prompt string or SystemMessage object
107108
- messages: List of messages
108109
- tool_choice: Tool choice configuration
109110
- tools: List of available tools
@@ -1256,7 +1257,7 @@ def wrapped(
12561257
request: ModelRequest,
12571258
handler: Callable[[ModelRequest], ModelResponse],
12581259
) -> ModelCallResult:
1259-
prompt = cast("str", func(request))
1260+
prompt = cast("str | SystemMessage", func(request))
12601261
request.system_prompt = prompt
12611262
return handler(request)
12621263

@@ -1266,7 +1267,7 @@ async def async_wrapped_from_sync(
12661267
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
12671268
) -> ModelCallResult:
12681269
# Delegate to sync function
1269-
prompt = cast("str", func(request))
1270+
prompt = cast("str | SystemMessage", func(request))
12701271
request.system_prompt = prompt
12711272
return await handler(request)
12721273

libs/langchain_v1/tests/unit_tests/agents/test_context_editing_middleware.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from langchain_core.messages import (
1414
AIMessage,
1515
MessageLikeRepresentation,
16+
SystemMessage,
1617
ToolMessage,
1718
)
1819
from langgraph.runtime import Runtime
@@ -399,3 +400,126 @@ async def mock_handler(req: ModelRequest) -> AIMessage:
399400

400401
assert isinstance(calc_tool, ToolMessage)
401402
assert calc_tool.content == "[cleared]"
403+
404+
405+
# ==============================================================================
406+
# SystemMessage Tests
407+
# ==============================================================================
408+
409+
410+
def test_handles_system_message_prompt() -> None:
411+
"""Test that middleware handles SystemMessage as system_prompt correctly."""
412+
tool_call_id = "call-1"
413+
ai_message = AIMessage(
414+
content="",
415+
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
416+
)
417+
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
418+
419+
system_prompt = SystemMessage(content="You are a helpful assistant.")
420+
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
421+
# Manually set SystemMessage as system_prompt
422+
request.system_prompt = system_prompt
423+
424+
middleware = ContextEditingMiddleware(
425+
edits=[ClearToolUsesEdit(trigger=50)],
426+
token_count_method="model",
427+
)
428+
429+
def mock_handler(req: ModelRequest) -> AIMessage:
430+
return AIMessage(content="mock response")
431+
432+
# Call wrap_model_call - should not fail with SystemMessage
433+
middleware.wrap_model_call(request, mock_handler)
434+
435+
# Request should have processed without errors
436+
assert request.system_prompt == system_prompt
437+
assert isinstance(request.system_prompt, SystemMessage)
438+
439+
440+
def test_does_not_double_wrap_system_message() -> None:
441+
"""Test that middleware doesn't wrap SystemMessage in another SystemMessage."""
442+
tool_call_id = "call-1"
443+
ai_message = AIMessage(
444+
content="",
445+
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
446+
)
447+
tool_message = ToolMessage(content="x" * 100, tool_call_id=tool_call_id)
448+
449+
system_prompt = SystemMessage(content="Original system prompt")
450+
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
451+
request.system_prompt = system_prompt
452+
453+
middleware = ContextEditingMiddleware(
454+
edits=[ClearToolUsesEdit(trigger=50)],
455+
token_count_method="model",
456+
)
457+
458+
def mock_handler(req: ModelRequest) -> AIMessage:
459+
return AIMessage(content="mock response")
460+
461+
middleware.wrap_model_call(request, mock_handler)
462+
463+
# System prompt should still be the same SystemMessage, not wrapped
464+
assert request.system_prompt == system_prompt
465+
assert isinstance(request.system_prompt, SystemMessage)
466+
assert request.system_prompt.content == "Original system prompt"
467+
468+
469+
async def test_handles_system_message_prompt_async() -> None:
470+
"""Test async version - middleware handles SystemMessage as system_prompt correctly."""
471+
tool_call_id = "call-1"
472+
ai_message = AIMessage(
473+
content="",
474+
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
475+
)
476+
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
477+
478+
system_prompt = SystemMessage(content="You are a helpful assistant.")
479+
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
480+
# Manually set SystemMessage as system_prompt
481+
request.system_prompt = system_prompt
482+
483+
middleware = ContextEditingMiddleware(
484+
edits=[ClearToolUsesEdit(trigger=50)],
485+
token_count_method="model",
486+
)
487+
488+
async def mock_handler(req: ModelRequest) -> AIMessage:
489+
return AIMessage(content="mock response")
490+
491+
# Call awrap_model_call - should not fail with SystemMessage
492+
await middleware.awrap_model_call(request, mock_handler)
493+
494+
# Request should have processed without errors
495+
assert request.system_prompt == system_prompt
496+
assert isinstance(request.system_prompt, SystemMessage)
497+
498+
499+
async def test_does_not_double_wrap_system_message_async() -> None:
500+
"""Test async version - middleware doesn't wrap SystemMessage in another SystemMessage."""
501+
tool_call_id = "call-1"
502+
ai_message = AIMessage(
503+
content="",
504+
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
505+
)
506+
tool_message = ToolMessage(content="x" * 100, tool_call_id=tool_call_id)
507+
508+
system_prompt = SystemMessage(content="Original system prompt")
509+
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
510+
request.system_prompt = system_prompt
511+
512+
middleware = ContextEditingMiddleware(
513+
edits=[ClearToolUsesEdit(trigger=50)],
514+
token_count_method="model",
515+
)
516+
517+
async def mock_handler(req: ModelRequest) -> AIMessage:
518+
return AIMessage(content="mock response")
519+
520+
await middleware.awrap_model_call(request, mock_handler)
521+
522+
# System prompt should still be the same SystemMessage, not wrapped
523+
assert request.system_prompt == system_prompt
524+
assert isinstance(request.system_prompt, SystemMessage)
525+
assert request.system_prompt.content == "Original system prompt"

0 commit comments

Comments
 (0)