@@ -230,10 +230,13 @@ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
230230 )
231231 ```
232232
233- !!! example "Override system prompt (backward compatible) "
233+ !!! example "Override multiple attributes "
234234
235235 ```python
236- new_request = request.override(system_prompt="New instructions")
236+ new_request = request.override(
237+ model=ChatOpenAI(model="gpt-4o"),
238+ system_message=SystemMessage(content="New instructions"),
239+ )
237240 ```
238241 """
239242 # Handle system_prompt/system_message conversion
@@ -695,7 +698,7 @@ def __call__(
695698 ...
696699
697700
698- class _CallableReturningPromptString (Protocol [StateT_contra , ContextT ]): # type: ignore[misc]
701+ class _CallableReturningSystemMessage (Protocol [StateT_contra , ContextT ]): # type: ignore[misc]
699702 """Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
700703
701704 def __call__ (
@@ -1382,24 +1385,24 @@ def wrapped(
13821385
13831386@overload
13841387def dynamic_prompt (
1385- func : _CallableReturningPromptString [StateT , ContextT ],
1388+ func : _CallableReturningSystemMessage [StateT , ContextT ],
13861389) -> AgentMiddleware [StateT , ContextT ]: ...
13871390
13881391
13891392@overload
13901393def dynamic_prompt (
13911394 func : None = None ,
13921395) -> Callable [
1393- [_CallableReturningPromptString [StateT , ContextT ]],
1396+ [_CallableReturningSystemMessage [StateT , ContextT ]],
13941397 AgentMiddleware [StateT , ContextT ],
13951398]: ...
13961399
13971400
13981401def dynamic_prompt (
1399- func : _CallableReturningPromptString [StateT , ContextT ] | None = None ,
1402+ func : _CallableReturningSystemMessage [StateT , ContextT ] | None = None ,
14001403) -> (
14011404 Callable [
1402- [_CallableReturningPromptString [StateT , ContextT ]],
1405+ [_CallableReturningSystemMessage [StateT , ContextT ]],
14031406 AgentMiddleware [StateT , ContextT ],
14041407 ]
14051408 | AgentMiddleware [StateT , ContextT ]
@@ -1453,7 +1456,7 @@ def context_aware_prompt(request: ModelRequest) -> str:
14531456 """
14541457
14551458 def decorator (
1456- func : _CallableReturningPromptString [StateT , ContextT ],
1459+ func : _CallableReturningSystemMessage [StateT , ContextT ],
14571460 ) -> AgentMiddleware [StateT , ContextT ]:
14581461 is_async = iscoroutinefunction (func )
14591462
@@ -1488,7 +1491,7 @@ def wrapped(
14881491 request : ModelRequest ,
14891492 handler : Callable [[ModelRequest ], ModelResponse ],
14901493 ) -> ModelCallResult :
1491- prompt = func (request )
1494+ prompt = cast ( "Callable[[ModelRequest], SystemMessage | str]" , func ) (request )
14921495 if isinstance (prompt , SystemMessage ):
14931496 request = request .override (system_message = prompt )
14941497 else :
@@ -1501,7 +1504,7 @@ async def async_wrapped_from_sync(
15011504 handler : Callable [[ModelRequest ], Awaitable [ModelResponse ]],
15021505 ) -> ModelCallResult :
15031506 # Delegate to sync function
1504- prompt = func (request )
1507+ prompt = cast ( "Callable[[ModelRequest], SystemMessage | str]" , func ) (request )
15051508 if isinstance (prompt , SystemMessage ):
15061509 request = request .override (system_message = prompt )
15071510 else :
0 commit comments