5555 'UserPromptNode' ,
5656 'ModelRequestNode' ,
5757 'CallToolsNode' ,
58+ 'ContinueRequestNode' ,
5859 'build_run_context' ,
5960 'capture_run_messages' ,
6061 'HistoryProcessor' ,
# Generic type variables and module-level constants for the agent graph.
S = TypeVar('S')
NoneType = type(None)
# Strategy for ending a run; exact semantics defined where it is consumed — TODO confirm.
EndStrategy = Literal['early', 'exhaustive']

_MAX_CONTINUATIONS = 50
"""Maximum number of continuations allowed for incomplete responses (e.g., Anthropic pause_turn)."""

DepsT = TypeVar('DepsT')  # type of the user-supplied dependencies object
OutputT = TypeVar('OutputT')  # type of the agent run's final output
7075
@@ -77,6 +82,7 @@ class GraphAgentState:
    # Accumulated usage (tokens/requests) for the whole run.
    usage: _usage.RunUsage = dataclasses.field(default_factory=_usage.RunUsage)
    # Retry counter — presumably reset/raised by response handling; verify against callers.
    retries: int = 0
    # Index of the current request/response step in the run.
    run_step: int = 0
    # Number of continuation requests made for suspended responses
    # (incremented by ContinueRequestNode and capped by _MAX_CONTINUATIONS).
    continuations: int = 0
    # Unique identifier for this run, generated lazily.
    run_id: str = dataclasses.field(default_factory=lambda: str(uuid.uuid4()))
    # Optional user-supplied metadata attached to the run.
    metadata: dict[str, Any] | None = None
    # NOTE(review): looks like the max_tokens from the most recent request — confirm.
    last_max_tokens: int | None = None
@@ -792,13 +798,20 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
792798 """
793799
    # Iterator of handle-response events — presumably created while streaming; verify against stream().
    _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
    # The node to transition to next; run() asserts this is set before the event stream ends.
    _next_node: (
        ModelRequestNode[DepsT, NodeRunEndT]
        | ContinueRequestNode[DepsT, NodeRunEndT]
        | End[result.FinalResult[NodeRunEndT]]
        | None
    ) = field(default=None, init=False, repr=False)
798807
799808 async def run (
800809 self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
801- ) -> ModelRequestNode [DepsT , NodeRunEndT ] | End [result .FinalResult [NodeRunEndT ]]:
810+ ) -> (
811+ ModelRequestNode [DepsT , NodeRunEndT ]
812+ | ContinueRequestNode [DepsT , NodeRunEndT ]
813+ | End [result .FinalResult [NodeRunEndT ]]
814+ ):
802815 async with self .stream (ctx ):
803816 pass
804817 assert self ._next_node is not None , 'the stream should set `self._next_node` before it ends'
@@ -825,6 +838,12 @@ async def _run_stream( # noqa: C901
825838 output_schema = ctx .deps .output_schema
826839
827840 async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]: # noqa: C901
841+ if self .model_response .state == 'suspended' :
842+ # Some providers (e.g. Anthropic pause_turn, OpenAI background mode) pause mid-turn
843+ # and expect us to continue.
844+ self ._next_node = ContinueRequestNode [DepsT , NodeRunEndT ](self .model_response )
845+ return
846+
828847 is_empty = not self .model_response .parts
829848 is_thinking_only = not is_empty and all (
830849 isinstance (p , _messages .ThinkingPart ) for p in self .model_response .parts
@@ -1075,6 +1094,265 @@ async def run(
10751094 return End (self .final_result )
10761095
10771096
@dataclasses.dataclass
class ContinueRequestNode(AgentNode[DepsT, NodeRunEndT]):
    """A node that makes a single continuation request and transitions accordingly.

    This handles providers that pause mid-turn (e.g. Anthropic `pause_turn`, OpenAI background mode).
    Each node makes one continuation request: if the response is still suspended, it transitions
    to a new `ContinueRequestNode`; if complete, it transitions to `CallToolsNode`.
    This keeps each continuation visible as a discrete graph node transition.

    Note: `agent.run_stream()` advances this node via `run()` (non-streaming), not `stream()`.
    The `stream()` method is available for users who manually iterate the graph via `agent.iter()`
    and want streaming events from continuation requests.
    """

    # The suspended response to continue from; it is already the last entry in
    # `ctx.state.message_history` (see `_process_response`, which replaces it in place).
    model_response: _messages.ModelResponse

    # Cached next node, set once by `run()` or `_run_stream()` so repeated calls are idempotent.
    _result: CallToolsNode[DepsT, NodeRunEndT] | ContinueRequestNode[DepsT, NodeRunEndT] | None = field(
        repr=False, init=False, default=None
    )
    # Set once `stream()` has opened a model request stream; guards against mixing run()/stream().
    _did_stream: bool = field(repr=False, init=False, default=False)

    async def run(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> CallToolsNode[DepsT, NodeRunEndT] | ContinueRequestNode[DepsT, NodeRunEndT]:
        """Make one non-streaming continuation request and return the next node.

        Raises:
            AgentRunError: if `stream()` was started but not finished before calling `run()`.
        """
        if self._result is not None:
            return self._result

        if self._did_stream:
            raise exceptions.AgentRunError('You must finish streaming before calling run()')  # pragma: no cover

        # Note: self.model_response is already the last entry in ctx.state.message_history
        # (appended by HandleResponseNode). We pass message_history to model.request() and the
        # model reads the suspended response from there to know how to continue.
        new_response = await self._request(ctx)
        merged_response = self._process_response(ctx, new_response)

        if new_response.state == 'suspended':
            self._result = ContinueRequestNode(merged_response)
        else:
            self._result = CallToolsNode(merged_response)
        return self._result

    @asynccontextmanager
    async def stream(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[AsyncIterator[_messages.AgentStreamEvent]]:
        """Make a single continuation request with streaming, yielding model response events."""
        assert not self._did_stream, 'stream() should only be called once per node'

        stream = self._run_stream(ctx)
        yield stream

        # Run the stream to completion if it was not finished:
        async for _event in stream:
            pass

    async def _run_stream(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> AsyncIterator[_messages.AgentStreamEvent]:
        """Make one streaming continuation request, yielding events and setting `self._result`."""
        self._check_continuation_limit(ctx)
        run_context, request_context = await self._prepare_continuation(ctx)

        # Cooperative hand-off between this generator and the wrap_model_request task,
        # following the same pattern as ModelRequestNode._run_stream:
        # 1. The task runs capability middleware, then the handler opens the stream.
        # 2. The handler signals stream_ready, then waits on stream_done.
        # 3. This generator yields events to the caller, then signals stream_done.
        # 4. The handler resumes, the stream closes, and the task completes.
        stream_ready = asyncio.Event()
        stream_done = asyncio.Event()
        streamed_response_holder: list[models.StreamedResponse] = []

        async def _streaming_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse:
            with set_current_run_context(run_context):
                async with ctx.deps.model.request_stream(
                    req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters, run_context
                ) as sr:
                    self._did_stream = True
                    streamed_response_holder.append(sr)
                    stream_ready.set()
                    await stream_done.wait()
                return sr.get()

        wrap_task = asyncio.create_task(
            ctx.deps.root_capability.wrap_model_request(
                run_context,
                request_context=request_context,
                handler=_streaming_handler,
            )
        )

        # Wait for either the stream to become ready or the wrapped task to finish
        # without ever invoking the handler (short-circuit / error in middleware).
        ready_waiter = asyncio.create_task(stream_ready.wait())
        await asyncio.wait({ready_waiter, wrap_task}, return_when=asyncio.FIRST_COMPLETED)
        ready_waiter.cancel()

        if wrap_task.done() and not stream_ready.is_set():  # pragma: lax no cover
            # wrap_model_request completed without calling handler (short-circuited or error)
            try:
                new_response = wrap_task.result()
            except exceptions.SkipModelRequest as e:
                new_response = e.response
            except Exception as e:
                new_response = await ctx.deps.root_capability.on_model_request_error(
                    run_context, request_context=request_context, error=e
                )
        else:
            # Normal path: stream is ready, yield events
            stream_error: BaseException | None = None
            try:
                async for event in streamed_response_holder[0]:
                    yield event
            except BaseException as exc:
                stream_error = exc
            finally:
                stream_done.set()

            if stream_error is not None:
                wrap_task.cancel()
                try:
                    await wrap_task
                except BaseException:
                    # asyncio.CancelledError is a BaseException subclass, so a single
                    # BaseException clause covers cancellation and any teardown error;
                    # the original stream failure below takes precedence.
                    pass
                raise stream_error

            try:
                new_response = await wrap_task
            except Exception as e:
                new_response = await ctx.deps.root_capability.on_model_request_error(
                    run_context, request_context=request_context, error=e
                )

        new_response = await ctx.deps.root_capability.after_model_request(
            run_context, request_context=request_context, response=new_response
        )
        merged_response = self._process_response(ctx, new_response)

        if new_response.state == 'suspended':
            self._result = ContinueRequestNode(merged_response)
        else:
            self._result = CallToolsNode(merged_response)

    async def _request(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> _messages.ModelResponse:
        """Make a single non-streaming continuation request."""
        self._check_continuation_limit(ctx)
        run_context, request_context = await self._prepare_continuation(ctx)

        async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse:
            with set_current_run_context(run_context):
                return await ctx.deps.model.request(
                    req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters
                )

        try:
            response = await ctx.deps.root_capability.wrap_model_request(
                run_context,
                request_context=request_context,
                handler=model_handler,
            )
        except exceptions.SkipModelRequest as e:  # pragma: lax no cover
            response = e.response
        except Exception as e:
            response = await ctx.deps.root_capability.on_model_request_error(
                run_context, request_context=request_context, error=e
            )

        response = await ctx.deps.root_capability.after_model_request(
            run_context, request_context=request_context, response=response
        )
        return response

    async def _prepare_continuation(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> tuple[RunContext[DepsT], ModelRequestContext]:
        """Prepare common state for a continuation request."""
        ctx.deps.usage_limits.check_before_request(ctx.state.usage)

        model_request_parameters = await _prepare_request_parameters(ctx)
        run_context = build_run_context(ctx)
        model_settings = ctx.deps.get_model_settings(run_context) or ModelSettings()
        run_context.model_settings = model_settings

        request_context = ModelRequestContext(
            messages=ctx.state.message_history,
            model_settings=model_settings,
            model_request_parameters=model_request_parameters,
        )
        request_context = await ctx.deps.root_capability.before_model_request(run_context, request_context)

        return run_context, request_context

    def _check_continuation_limit(
        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
    ) -> None:
        """Increment the continuation counter and raise if `_MAX_CONTINUATIONS` is exceeded."""
        ctx.state.continuations += 1
        if ctx.state.continuations > _MAX_CONTINUATIONS:
            raise exceptions.UnexpectedModelBehavior(
                f'Exceeded maximum continuations ({_MAX_CONTINUATIONS}) for incomplete responses'
            )

    def _process_response(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
        new_response: _messages.ModelResponse,
    ) -> _messages.ModelResponse:
        """Process a continuation response: track usage, merge, update history.

        Returns the merged response.
        """
        ctx.state.usage.incr(new_response.usage)
        if ctx.deps.usage_limits:  # pragma: no branch
            ctx.deps.usage_limits.check_tokens(ctx.state.usage)

        merged_response = self._merge_response(self.model_response, new_response)
        merged_response.run_id = merged_response.run_id or ctx.state.run_id

        # Intentionally replace the last message in-place: the suspended ModelResponse in message_history
        # is progressively updated as continuation responses are merged, so users inspecting
        # message_history after the run see the final merged response rather than each intermediate one.
        assert isinstance(ctx.state.message_history[-1], _messages.ModelResponse), (
            f'Expected last message to be ModelResponse, got {type(ctx.state.message_history[-1])}'
        )
        ctx.state.message_history[-1] = merged_response

        return merged_response

    @staticmethod
    def _merge_response(existing: _messages.ModelResponse, new: _messages.ModelResponse) -> _messages.ModelResponse:
        """Merge a new response into an existing one.

        If same `provider_response_id`, replace entirely with the new response.
        If the model changed between responses, replace entirely (incompatible responses should not be merged).
        Otherwise, accumulate parts, sum usage, and use other fields from the new response.
        """
        # Same response ID → the new response is a full replacement (e.g. OpenAI background retrieve).
        if existing.provider_response_id and existing.provider_response_id == new.provider_response_id:
            return new

        # Different model → replace (accumulating parts from different models is always wrong).
        # When either model_name is None/empty, we fall through to accumulation — this is intentional
        # because providers may not always populate model_name on continuation responses.
        if existing.model_name and new.model_name and existing.model_name != new.model_name:
            return new

        # Same model, different response → accumulate parts and sum usage.
        # Preserve existing provider response IDs when continuation responses omit them
        # (e.g. resumed OpenAI streams that start after a sequence number).
        merged_usage = existing.usage + new.usage
        return replace(
            new,
            parts=[*existing.parts, *new.parts],
            usage=merged_usage,
            provider_response_id=new.provider_response_id or existing.provider_response_id,
        )

    __repr__ = dataclasses_no_defaults_repr
1355+
10781356def build_run_context (ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]) -> RunContext [DepsT ]:
10791357 """Build a `RunContext` object from the current agent graph run context."""
10801358 run_context = RunContext [DepsT ](
@@ -1643,6 +1921,7 @@ def build_agent_graph(
16431921 g .node (UserPromptNode [DepsT , OutputT ]),
16441922 g .node (ModelRequestNode [DepsT , OutputT ]),
16451923 g .node (CallToolsNode [DepsT , OutputT ]),
1924+ g .node (ContinueRequestNode [DepsT , OutputT ]),
16461925 g .node (
16471926 SetFinalResult [DepsT , OutputT ],
16481927 ),
0 commit comments