Skip to content

Commit 195a985

Browse files
committed
nits from review
1 parent 24b1bb6 commit 195a985

File tree

1 file changed

+13
-10
lines changed
  • libs/langchain_v1/langchain/agents/middleware

1 file changed

+13
-10
lines changed

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,13 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
230230
)
231231
```
232232
233-
!!! example "Override system prompt (backward compatible)"
233+
!!! example "Override multiple attributes"
234234
235235
```python
236-
new_request = request.override(system_prompt="New instructions")
236+
new_request = request.override(
237+
model=ChatOpenAI(model="gpt-4o"),
238+
system_message=SystemMessage(content="New instructions"),
239+
)
237240
```
238241
"""
239242
# Handle system_prompt/system_message conversion
@@ -695,7 +698,7 @@ def __call__(
695698
...
696699

697700

698-
class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
701+
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
699702
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
700703

701704
def __call__(
@@ -1382,24 +1385,24 @@ def wrapped(
13821385

13831386
@overload
13841387
def dynamic_prompt(
1385-
func: _CallableReturningPromptString[StateT, ContextT],
1388+
func: _CallableReturningSystemMessage[StateT, ContextT],
13861389
) -> AgentMiddleware[StateT, ContextT]: ...
13871390

13881391

13891392
@overload
13901393
def dynamic_prompt(
13911394
func: None = None,
13921395
) -> Callable[
1393-
[_CallableReturningPromptString[StateT, ContextT]],
1396+
[_CallableReturningSystemMessage[StateT, ContextT]],
13941397
AgentMiddleware[StateT, ContextT],
13951398
]: ...
13961399

13971400

13981401
def dynamic_prompt(
1399-
func: _CallableReturningPromptString[StateT, ContextT] | None = None,
1402+
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
14001403
) -> (
14011404
Callable[
1402-
[_CallableReturningPromptString[StateT, ContextT]],
1405+
[_CallableReturningSystemMessage[StateT, ContextT]],
14031406
AgentMiddleware[StateT, ContextT],
14041407
]
14051408
| AgentMiddleware[StateT, ContextT]
@@ -1453,7 +1456,7 @@ def context_aware_prompt(request: ModelRequest) -> str:
14531456
"""
14541457

14551458
def decorator(
1456-
func: _CallableReturningPromptString[StateT, ContextT],
1459+
func: _CallableReturningSystemMessage[StateT, ContextT],
14571460
) -> AgentMiddleware[StateT, ContextT]:
14581461
is_async = iscoroutinefunction(func)
14591462

@@ -1488,7 +1491,7 @@ def wrapped(
14881491
request: ModelRequest,
14891492
handler: Callable[[ModelRequest], ModelResponse],
14901493
) -> ModelCallResult:
1491-
prompt = func(request)
1494+
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
14921495
if isinstance(prompt, SystemMessage):
14931496
request = request.override(system_message=prompt)
14941497
else:
@@ -1501,7 +1504,7 @@ async def async_wrapped_from_sync(
15011504
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
15021505
) -> ModelCallResult:
15031506
# Delegate to sync function
1504-
prompt = func(request)
1507+
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
15051508
if isinstance(prompt, SystemMessage):
15061509
request = request.override(system_message=prompt)
15071510
else:

0 commit comments

Comments
 (0)