Skip to content

Commit 39e46e5

Browse files
committed
updates
1 parent c199f15 commit 39e46e5

File tree

10 files changed

+1249
-36
lines changed

10 files changed

+1249
-36
lines changed

libs/langchain_v1/langchain/agents/middleware/todo.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from __future__ import annotations
55

6-
from typing import TYPE_CHECKING, Annotated, Literal
6+
from typing import TYPE_CHECKING, Annotated, Literal, cast
77

88
if TYPE_CHECKING:
99
from collections.abc import Awaitable, Callable
@@ -194,12 +194,16 @@ def wrap_model_call(
194194
handler: Callable[[ModelRequest], ModelResponse],
195195
) -> ModelCallResult:
196196
"""Update the system message to include the todo system prompt."""
197-
new_system_content = (
198-
request.system_message.content + "\n\n" + self.system_prompt
199-
if request.system_message
200-
else self.system_prompt
197+
if request.system_message is not None:
198+
new_system_content = [
199+
*request.system_message.content_blocks,
200+
{"type": "text", "text": self.system_prompt},
201+
]
202+
else:
203+
new_system_content = [{"type": "text", "text": self.system_prompt}]
204+
new_system_message = SystemMessage(
205+
content=cast("list[str | dict[str, str]]", new_system_content)
201206
)
202-
new_system_message = SystemMessage(content=new_system_content)
203207
return handler(request.override(system_message=new_system_message))
204208

205209
async def awrap_model_call(
@@ -208,10 +212,14 @@ async def awrap_model_call(
208212
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
209213
) -> ModelCallResult:
210214
"""Update the system message to include the todo system prompt (async version)."""
211-
new_system_content = (
212-
request.system_message.content + "\n\n" + self.system_prompt
213-
if request.system_message
214-
else self.system_prompt
215+
if request.system_message is not None:
216+
new_system_content = [
217+
*request.system_message.content_blocks,
218+
{"type": "text", "text": self.system_prompt},
219+
]
220+
else:
221+
new_system_content = [{"type": "text", "text": self.system_prompt}]
222+
new_system_message = SystemMessage(
223+
content=cast("list[str | dict[str, str]]", new_system_content)
215224
)
216-
new_system_message = SystemMessage(content=new_system_content)
217225
return await handler(request.override(system_message=new_system_message))

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,11 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
242242
raise ValueError(msg)
243243

244244
if "system_prompt" in overrides:
245-
system_prompt = overrides["system_prompt"]
245+
system_prompt = cast("str", overrides.pop("system_prompt")) # type: ignore[typeddict-item]
246246
if system_prompt is None:
247247
overrides["system_message"] = None
248248
else:
249249
overrides["system_message"] = SystemMessage(content=system_prompt)
250-
# Remove system_prompt from overrides to avoid conflict
251-
overrides = {k: v for k, v in overrides.items() if k != "system_prompt"}
252250

253251
return replace(self, **overrides)
254252

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
"""Unit tests for dynamic system prompt middleware with SystemMessage support.
2+
3+
These tests replicate the functionality from langchainjs PR #9459:
4+
- Middleware accepting functions that return SystemMessage
5+
- Error handling for invalid return types
6+
"""
7+
8+
from typing import cast
9+
10+
import pytest
11+
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
12+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
13+
from langgraph.runtime import Runtime
14+
15+
from langchain.agents.middleware.types import ModelRequest, ModelResponse
16+
17+
18+
def _fake_runtime(context: dict | None = None) -> Runtime:
    """Build a stand-in Runtime, optionally carrying attribute-style context."""
    if not context:
        # No context requested: any plain object suffices as a placeholder.
        return cast(Runtime, object())

    class _StubRuntime:
        def __init__(self) -> None:
            # Expose the supplied mapping as an attribute-access namespace.
            self.context = type("Context", (), context)()

    return cast(Runtime, _StubRuntime())
28+
29+
30+
def _make_request(
    system_message: SystemMessage | None = None,
    system_prompt: str | None = None,
) -> ModelRequest:
    """Assemble a bare-bones ModelRequest suitable for unit tests."""
    fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
    request_kwargs = {
        "model": fake_model,
        "system_message": system_message,
        "system_prompt": system_prompt,
        "messages": [],
        "tool_choice": None,
        "tools": [],
        "response_format": None,
        "state": cast("AgentState", {"messages": []}),  # type: ignore[name-defined]
        "runtime": _fake_runtime(),
        "model_settings": {},
    }
    return ModelRequest(**request_kwargs)
48+
49+
50+
class TestDynamicSystemPromptWithSystemMessage:
    """Test middleware that accepts SystemMessage return types.

    These tests verify that middleware can work with SystemMessage objects,
    not just strings, enabling richer metadata handling.
    """

    def test_middleware_can_return_system_message(self) -> None:
        """Test that middleware can return a SystemMessage instead of string.

        This replicates the JS test: "should support returning a SystemMessage"
        """

        def prompt_builder(request: ModelRequest) -> SystemMessage:
            """Return a SystemMessage with dynamic content."""
            region = getattr(request.runtime.context, "region", "n/a")
            return SystemMessage(content=f"You are a helpful assistant. Region: {region}")

        # Request carrying runtime context the middleware reads from.
        request = ModelRequest(
            model=GenericFakeChatModel(messages=iter([AIMessage(content="response")])),
            system_message=None,
            messages=[HumanMessage(content="Hello")],
            tool_choice=None,
            tools=[],
            response_format=None,
            state=cast("AgentState", {"messages": []}),  # type: ignore[name-defined]
            runtime=_fake_runtime(context={"region": "EU"}),
            model_settings={},
        )

        produced = prompt_builder(request)

        assert isinstance(produced, SystemMessage)
        blocks = produced.content_blocks
        assert len(blocks) == 1
        assert blocks[0]["text"] == "You are a helpful assistant. Region: EU"

    def test_middleware_rejects_invalid_return_types(self) -> None:
        """Test that middleware properly validates return types.

        This replicates the JS test for error handling with invalid return types.
        """

        def bad_middleware(request: ModelRequest) -> int:
            """Return an invalid type (should raise error)."""
            return 123

        outcome = bad_middleware(_make_request(system_prompt="Base prompt"))
        # Python relies on static type checking here; confirm the value is
        # neither of the accepted runtime types.
        assert not isinstance(outcome, (str, SystemMessage))
        # In a real implementation, this would be caught by type checking or runtime validation

    def test_middleware_can_use_system_message_with_metadata(self) -> None:
        """Test middleware creating SystemMessage with additional metadata."""

        def with_metadata(request: ModelRequest) -> SystemMessage:
            """Return SystemMessage with metadata."""
            return SystemMessage(
                content="You are a helpful assistant",
                additional_kwargs={"temperature": 0.7, "model": "gpt-4"},
                response_metadata={"region": "us-east"},
            )

        message = with_metadata(_make_request())

        blocks = message.content_blocks
        assert len(blocks) == 1
        assert blocks[0]["text"] == "You are a helpful assistant"
        assert message.additional_kwargs == {"temperature": 0.7, "model": "gpt-4"}
        assert message.response_metadata == {"region": "us-east"}

    def test_middleware_handles_none_system_message(self) -> None:
        """Test middleware creating new SystemMessage when none exists."""

        def default_when_missing(request: ModelRequest) -> SystemMessage:
            """Create a system message if none exists."""
            existing = request.system_message
            if existing is not None:
                return existing
            return SystemMessage(content="Default system prompt")

        message = default_when_missing(_make_request(system_message=None))

        assert isinstance(message, SystemMessage)
        assert len(message.content_blocks) == 1
        assert message.content_blocks[0]["text"] == "Default system prompt"

    def test_middleware_with_content_blocks(self) -> None:
        """Test middleware creating SystemMessage with content blocks."""

        def blocks_builder(request: ModelRequest) -> SystemMessage:
            """Create SystemMessage with content blocks including cache control."""
            cached_block = {
                "type": "text",
                "text": "Cached instructions",
                "cache_control": {"type": "ephemeral"},
            }
            return SystemMessage(
                content=[{"type": "text", "text": "Base instructions"}, cached_block]
            )

        message = blocks_builder(_make_request())

        blocks = message.content_blocks
        assert isinstance(blocks, list)
        assert len(blocks) == 2
        assert blocks[0]["text"] == "Base instructions"
        assert blocks[1]["cache_control"] == {"type": "ephemeral"}
173+
174+
175+
class TestSystemMessageMiddlewareIntegration:
    """Test integration of SystemMessage with middleware chain."""

    def test_multiple_middleware_can_modify_system_message(self) -> None:
        """Test that multiple middleware can modify system message in sequence."""

        def add_base(request: ModelRequest) -> ModelRequest:
            """First middleware adds base system message."""
            return request.override(
                system_message=SystemMessage(
                    content="You are an assistant.",
                    additional_kwargs={"middleware_1": "applied"},
                )
            )

        def append_suffix(request: ModelRequest) -> ModelRequest:
            """Second middleware appends to system message."""
            prior = request.system_message
            # Carry forward the first middleware's kwargs and tag our own.
            combined_kwargs = dict(prior.additional_kwargs)
            combined_kwargs["middleware_2"] = "applied"
            return request.override(
                system_message=SystemMessage(
                    content=prior.text + " Be helpful.",
                    additional_kwargs=combined_kwargs,
                )
            )

        # Start with no system message and apply the chain in order.
        request = add_base(_make_request(system_message=None))
        blocks = request.system_message.content_blocks
        assert len(blocks) == 1
        assert blocks[0]["text"] == "You are an assistant."
        assert request.system_message.additional_kwargs["middleware_1"] == "applied"

        request = append_suffix(request)
        blocks = request.system_message.content_blocks
        assert len(blocks) == 1
        assert blocks[0]["text"] == "You are an assistant. Be helpful."
        assert request.system_message.additional_kwargs["middleware_1"] == "applied"
        assert request.system_message.additional_kwargs["middleware_2"] == "applied"

    def test_middleware_preserves_system_message_metadata(self) -> None:
        """Test that metadata is preserved when middleware modifies system message."""
        original = SystemMessage(
            content="Base prompt",
            additional_kwargs={"key1": "value1", "key2": "value2"},
            response_metadata={"model": "gpt-4"},
        )

        def extend_keeping_metadata(request: ModelRequest) -> ModelRequest:
            """Middleware that preserves existing metadata."""
            prior = request.system_message
            replacement = SystemMessage(
                content=prior.text + " Extended.",
                additional_kwargs=prior.additional_kwargs,
                response_metadata=prior.response_metadata,
            )
            return request.override(system_message=replacement)

        updated = extend_keeping_metadata(_make_request(system_message=original))

        blocks = updated.system_message.content_blocks
        assert len(blocks) == 1
        assert blocks[0]["text"] == "Base prompt Extended."
        assert updated.system_message.additional_kwargs == {
            "key1": "value1",
            "key2": "value2",
        }
        assert updated.system_message.response_metadata == {"model": "gpt-4"}

    def test_backward_compatibility_with_string_system_prompt(self) -> None:
        """Test that middleware still works with string system prompts."""

        def extend_prompt(request: ModelRequest) -> ModelRequest:
            """Middleware using string system prompt (backward compatible)."""
            base = request.system_prompt or ""
            return request.override(
                system_prompt=(base + " Additional instructions.").strip()
            )

        updated = extend_prompt(_make_request(system_prompt="Base prompt"))

        assert updated.system_prompt == "Base prompt Additional instructions."
        # The system_prompt should be converted to SystemMessage internally.
        assert isinstance(updated.system_message, SystemMessage)

    def test_middleware_can_switch_between_string_and_system_message(self) -> None:
        """Test middleware can work with both string and SystemMessage.

        Note: In the Python implementation, system_prompt is automatically
        converted to SystemMessage, so middleware always sees a SystemMessage.
        """

        def adapt(request: ModelRequest) -> ModelRequest:
            """Middleware that works with both formats."""
            existing = request.system_message
            if not existing:
                # Create new SystemMessage if none exists.
                return request.override(system_message=SystemMessage(content="[created]"))
            # Work with the existing SystemMessage.
            return request.override(
                system_message=SystemMessage(content=existing.text + " [modified]")
            )

        # Explicit SystemMessage input.
        modified = adapt(_make_request(system_message=SystemMessage(content="Hello")))
        assert len(modified.system_message.content_blocks) == 1
        assert modified.system_message.content_blocks[0]["text"] == "Hello [modified]"

        # String prompt input (converted to SystemMessage internally).
        from_string = adapt(_make_request(system_prompt="Hello"))
        assert len(from_string.system_message.content_blocks) == 1
        assert from_string.system_message.content_blocks[0]["text"] == "Hello [modified]"

        # No system message at all.
        created = adapt(_make_request(system_message=None))
        assert len(created.system_message.content_blocks) == 1
        assert created.system_message.content_blocks[0]["text"] == "[created]"

0 commit comments

Comments
 (0)