Commit 14446ac

continuation support: ContinueRequestNode, ModelResponseState, fallback continuation pinning

Adds general-purpose continuation infrastructure for models that pause mid-turn:

- ContinueRequestNode in the agent graph for automatic continuation requests
- ModelResponseState type (complete/suspended) on ModelResponse
- Fallback model continuation pinning: pin to the model that started a continuation
- Message rewinding for fallback recovery

This enables Anthropic pause_turn and OpenAI background mode (added in follow-up).
1 parent 96276fd · commit 14446ac

25 files changed: +5975 −83 lines

PLAN.md

Lines changed: 56 additions & 0 deletions
# PLAN: Continuation Support — ContinueRequestNode, ModelResponseState, Fallback Pinning

> **Issue refs:** #3365 (Anthropic/OpenAI Skills), #3963 (Shell/Bash builtin)
> **Stack:** `continuation-support` -> `skill-support-v2` -> `local-tools`

---

## Scope

General-purpose agent graph infrastructure for models that pause mid-turn and expect a continuation request. This enables Anthropic `pause_turn` and OpenAI background mode (both added in the follow-up `skill-support-v2` change).

---
## 1. ModelResponseState

New type on `ModelResponse` indicating whether the response is final or requires further action:

```python
ModelResponseState: TypeAlias = Literal['complete', 'suspended']
```

- `'complete'` — default, response is done
- `'suspended'` — model paused mid-turn, expects continuation

Added as a `state` field on both `ModelResponse` (`messages.py`) and `StreamedResponse` (`models/__init__.py`). Also adds a `metadata` field on `StreamedResponse` for fallback model stamping.
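As a rough illustration of the two states, calling code can branch on the new field like this (a minimal sketch using a hypothetical `FakeModelResponse` stand-in rather than the real `ModelResponse`):

```python
from dataclasses import dataclass, field
from typing import Literal

# Hypothetical stand-in for pydantic_ai's ModelResponse; only the new
# `state` field matters for this sketch.
ModelResponseState = Literal['complete', 'suspended']


@dataclass
class FakeModelResponse:
    parts: list[str] = field(default_factory=list)
    state: ModelResponseState = 'complete'  # 'complete' is the default


def needs_continuation(response: FakeModelResponse) -> bool:
    """A suspended response means the model paused mid-turn and expects a follow-up request."""
    return response.state == 'suspended'


print(needs_continuation(FakeModelResponse(state='suspended')))  # True
print(needs_continuation(FakeModelResponse()))                   # False
```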
## 2. ContinueRequestNode

New node in the agent graph (`_agent_graph.py`) that handles automatic continuation when a model response has `state='suspended'`:

- Merges parts from the suspended response with the continuation response
- Tracks continuation count in `GraphAgentState.continuations`
- Enforces a `_MAX_CONTINUATIONS = 50` safety limit
- Supports both streaming and non-streaming paths
- If the continuation response is still suspended, chains to another `ContinueRequestNode`
- If complete, transitions to `CallToolsNode`
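The chaining behaviour can be sketched as a plain loop (the `Resp` and `run_with_continuations` names are invented for illustration; the real implementation is node-based and async):

```python
from dataclasses import dataclass
from typing import Callable, Literal

_MAX_CONTINUATIONS = 50  # mirrors the safety limit in _agent_graph.py


@dataclass
class Resp:
    parts: list[str]
    state: Literal['complete', 'suspended']


def run_with_continuations(first: Resp, continue_request: Callable[[], Resp]) -> Resp:
    """Keep requesting while suspended, merging parts, until the response is complete."""
    merged = first
    continuations = 0
    while merged.state == 'suspended':
        continuations += 1
        if continuations > _MAX_CONTINUATIONS:
            raise RuntimeError(f'Exceeded maximum continuations ({_MAX_CONTINUATIONS})')
        nxt = continue_request()
        # Accumulate parts; the latest state decides whether to keep going.
        merged = Resp(parts=[*merged.parts, *nxt.parts], state=nxt.state)
    return merged


# The model pauses once, then finishes:
responses = iter([Resp(['second half'], 'complete')])
out = run_with_continuations(Resp(['first half'], 'suspended'), lambda: next(responses))
print(out.parts)  # ['first half', 'second half']
```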
## 3. Fallback Model Continuation Pinning

When using `FallbackModel`, a model that starts a continuation must handle subsequent continuation requests — you can't switch models mid-continuation.

- `_stamp_continuation()` writes the model name into `response.metadata` under the `__pydantic_ai__` key
- `_get_continuation_model()` reads the stamp from message history to find the pinned model
- `_rewind_messages()` strips the suspended response and trailing request when a pinned model fails, allowing fallback to proceed cleanly
- Both `request()` and `request_stream()` check for a pinned continuation before entering the normal fallback chain
## 4. Test Coverage

18 new tests in `test_fallback.py` covering:

- Primary model continuation success (single and multiple pauses)
- Secondary model continuation after the primary fails
- Continuation failure propagation
- Non-fallback error propagation during continuation
- Recovery with message rewinding
- Streaming variants of all of the above
- Stamp/metadata edge cases

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 283 additions & 4 deletions
```diff
@@ -55,6 +55,7 @@
     'UserPromptNode',
     'ModelRequestNode',
     'CallToolsNode',
+    'ContinueRequestNode',
     'build_run_context',
     'capture_run_messages',
     'HistoryProcessor',
@@ -65,6 +66,10 @@
 S = TypeVar('S')
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
+
+_MAX_CONTINUATIONS = 50
+"""Maximum number of continuations allowed for incomplete responses (e.g., Anthropic pause_turn)."""
+
 DepsT = TypeVar('DepsT')
 OutputT = TypeVar('OutputT')
@@ -77,6 +82,7 @@ class GraphAgentState:
     usage: _usage.RunUsage = dataclasses.field(default_factory=_usage.RunUsage)
     retries: int = 0
     run_step: int = 0
+    continuations: int = 0
     run_id: str = dataclasses.field(default_factory=lambda: str(uuid.uuid4()))
     metadata: dict[str, Any] | None = None
     last_max_tokens: int | None = None
@@ -792,13 +798,20 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
     """

     _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
-    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
-        default=None, init=False, repr=False
-    )
+    _next_node: (
+        ModelRequestNode[DepsT, NodeRunEndT]
+        | ContinueRequestNode[DepsT, NodeRunEndT]
+        | End[result.FinalResult[NodeRunEndT]]
+        | None
+    ) = field(default=None, init=False, repr=False)

     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
+    ) -> (
+        ModelRequestNode[DepsT, NodeRunEndT]
+        | ContinueRequestNode[DepsT, NodeRunEndT]
+        | End[result.FinalResult[NodeRunEndT]]
+    ):
         async with self.stream(ctx):
             pass
         assert self._next_node is not None, 'the stream should set `self._next_node` before it ends'
@@ -825,6 +838,12 @@ async def _run_stream(  # noqa: C901
         output_schema = ctx.deps.output_schema

         async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa: C901
+            if self.model_response.state == 'suspended':
+                # Some providers (e.g. Anthropic pause_turn, OpenAI background mode) pause mid-turn
+                # and expect us to continue.
+                self._next_node = ContinueRequestNode[DepsT, NodeRunEndT](self.model_response)
+                return
+
             is_empty = not self.model_response.parts
             is_thinking_only = not is_empty and all(
                 isinstance(p, _messages.ThinkingPart) for p in self.model_response.parts
```
```diff
@@ -1075,6 +1094,265 @@ async def run(
         return End(self.final_result)


+@dataclasses.dataclass
+class ContinueRequestNode(AgentNode[DepsT, NodeRunEndT]):
+    """A node that makes a single continuation request and transitions accordingly.
+
+    This handles providers that pause mid-turn (e.g. Anthropic `pause_turn`, OpenAI background mode).
+    Each node makes one continuation request: if the response is still suspended, it transitions
+    to a new `ContinueRequestNode`; if complete, it transitions to `CallToolsNode`.
+    This keeps each continuation visible as a discrete graph node transition.
+
+    Note: `agent.run_stream()` advances this node via `run()` (non-streaming), not `stream()`.
+    The `stream()` method is available for users who manually iterate the graph via `agent.iter()`
+    and want streaming events from continuation requests.
+    """
+
+    model_response: _messages.ModelResponse
+
+    _result: CallToolsNode[DepsT, NodeRunEndT] | ContinueRequestNode[DepsT, NodeRunEndT] | None = field(
+        repr=False, init=False, default=None
+    )
+    _did_stream: bool = field(repr=False, init=False, default=False)
+
+    async def run(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> CallToolsNode[DepsT, NodeRunEndT] | ContinueRequestNode[DepsT, NodeRunEndT]:
+        if self._result is not None:
+            return self._result
+
+        if self._did_stream:
+            raise exceptions.AgentRunError('You must finish streaming before calling run()')  # pragma: no cover
+
+        # Note: self.model_response is already the last entry in ctx.state.message_history
+        # (appended by HandleResponseNode). We pass message_history to model.request() and the
+        # model reads the suspended response from there to know how to continue.
+        new_response = await self._request(ctx)
+        merged_response = self._process_response(ctx, new_response)
+
+        if new_response.state == 'suspended':
+            self._result = ContinueRequestNode(merged_response)
+        else:
+            self._result = CallToolsNode(merged_response)
+        return self._result
+
+    @asynccontextmanager
+    async def stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[AsyncIterator[_messages.AgentStreamEvent]]:
+        """Make a single continuation request with streaming, yielding model response events."""
+        assert not self._did_stream, 'stream() should only be called once per node'
+
+        stream = self._run_stream(ctx)
+        yield stream
+
+        # Run the stream to completion if it was not finished:
+        async for _event in stream:
+            pass
+
+    async def _run_stream(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> AsyncIterator[_messages.AgentStreamEvent]:
+        self._check_continuation_limit(ctx)
+        run_context, request_context = await self._prepare_continuation(ctx)
+
+        # Cooperative hand-off between this generator and the wrap_model_request task,
+        # following the same pattern as ModelRequestNode._run_stream:
+        # 1. The task runs capability middleware, then the handler opens the stream.
+        # 2. The handler signals stream_ready, then waits on stream_done.
+        # 3. This generator yields events to the caller, then signals stream_done.
+        # 4. The handler resumes, the stream closes, and the task completes.
+        stream_ready = asyncio.Event()
+        stream_done = asyncio.Event()
+        streamed_response_holder: list[models.StreamedResponse] = []
+
+        async def _streaming_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse:
+            with set_current_run_context(run_context):
+                async with ctx.deps.model.request_stream(
+                    req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters, run_context
+                ) as sr:
+                    self._did_stream = True
+                    streamed_response_holder.append(sr)
+                    stream_ready.set()
+                    await stream_done.wait()
+                    return sr.get()
+
+        wrap_task = asyncio.create_task(
+            ctx.deps.root_capability.wrap_model_request(
+                run_context,
+                request_context=request_context,
+                handler=_streaming_handler,
+            )
+        )
+
+        ready_waiter = asyncio.create_task(stream_ready.wait())
+        await asyncio.wait({ready_waiter, wrap_task}, return_when=asyncio.FIRST_COMPLETED)
+        ready_waiter.cancel()
+
+        if wrap_task.done() and not stream_ready.is_set():  # pragma: lax no cover
+            # wrap_model_request completed without calling handler (short-circuited or error)
+            try:
+                new_response = wrap_task.result()
+            except exceptions.SkipModelRequest as e:
+                new_response = e.response
+            except Exception as e:
+                new_response = await ctx.deps.root_capability.on_model_request_error(
+                    run_context, request_context=request_context, error=e
+                )
+        else:
+            # Normal path: stream is ready, yield events
+            stream_error: BaseException | None = None
+            try:
+                async for event in streamed_response_holder[0]:
+                    yield event
+            except BaseException as exc:
+                stream_error = exc
+            finally:
+                stream_done.set()
+
+            if stream_error is not None:
+                wrap_task.cancel()
+                try:
+                    await wrap_task
+                except (asyncio.CancelledError, BaseException):
+                    pass
+                raise stream_error
+
+            try:
+                new_response = await wrap_task
+            except Exception as e:
+                new_response = await ctx.deps.root_capability.on_model_request_error(
+                    run_context, request_context=request_context, error=e
+                )
+
+        new_response = await ctx.deps.root_capability.after_model_request(
+            run_context, request_context=request_context, response=new_response
+        )
+        merged_response = self._process_response(ctx, new_response)
+
+        if new_response.state == 'suspended':
+            self._result = ContinueRequestNode(merged_response)
+        else:
+            self._result = CallToolsNode(merged_response)
+
+    async def _request(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> _messages.ModelResponse:
+        """Make a single non-streaming continuation request."""
+        self._check_continuation_limit(ctx)
+        run_context, request_context = await self._prepare_continuation(ctx)
+
+        async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse:
+            with set_current_run_context(run_context):
+                return await ctx.deps.model.request(
+                    req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters
+                )
+
+        try:
+            response = await ctx.deps.root_capability.wrap_model_request(
+                run_context,
+                request_context=request_context,
+                handler=model_handler,
+            )
+        except exceptions.SkipModelRequest as e:  # pragma: lax no cover
+            response = e.response
+        except Exception as e:
+            response = await ctx.deps.root_capability.on_model_request_error(
+                run_context, request_context=request_context, error=e
+            )
+
+        response = await ctx.deps.root_capability.after_model_request(
+            run_context, request_context=request_context, response=response
+        )
+        return response
+
+    async def _prepare_continuation(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> tuple[RunContext[DepsT], ModelRequestContext]:
+        """Prepare common state for a continuation request."""
+        ctx.deps.usage_limits.check_before_request(ctx.state.usage)
+
+        model_request_parameters = await _prepare_request_parameters(ctx)
+        run_context = build_run_context(ctx)
+        model_settings = ctx.deps.get_model_settings(run_context) or ModelSettings()
+        run_context.model_settings = model_settings
+
+        request_context = ModelRequestContext(
+            messages=ctx.state.message_history,
+            model_settings=model_settings,
+            model_request_parameters=model_request_parameters,
+        )
+        request_context = await ctx.deps.root_capability.before_model_request(run_context, request_context)
+
+        return run_context, request_context
+
+    def _check_continuation_limit(
+        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
+    ) -> None:
+        ctx.state.continuations += 1
+        if ctx.state.continuations > _MAX_CONTINUATIONS:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Exceeded maximum continuations ({_MAX_CONTINUATIONS}) for incomplete responses'
+            )
+
+    def _process_response(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        new_response: _messages.ModelResponse,
+    ) -> _messages.ModelResponse:
+        """Process a continuation response: track usage, merge, update history.
+
+        Returns the merged response.
+        """
+        ctx.state.usage.incr(new_response.usage)
+        if ctx.deps.usage_limits:  # pragma: no branch
+            ctx.deps.usage_limits.check_tokens(ctx.state.usage)
+
+        merged_response = self._merge_response(self.model_response, new_response)
+        merged_response.run_id = merged_response.run_id or ctx.state.run_id
+
+        # Intentionally replace the last message in-place: the suspended ModelResponse in message_history
+        # is progressively updated as continuation responses are merged, so users inspecting
+        # message_history after the run see the final merged response rather than each intermediate one.
+        assert isinstance(ctx.state.message_history[-1], _messages.ModelResponse), (
+            f'Expected last message to be ModelResponse, got {type(ctx.state.message_history[-1])}'
+        )
+        ctx.state.message_history[-1] = merged_response
+
+        return merged_response
+
+    @staticmethod
+    def _merge_response(existing: _messages.ModelResponse, new: _messages.ModelResponse) -> _messages.ModelResponse:
+        """Merge a new response into an existing one.
+
+        If same `provider_response_id`, replace entirely with the new response.
+        If the model changed between responses, replace entirely (incompatible responses should not be merged).
+        Otherwise, accumulate parts, sum usage, and use other fields from the new response.
+        """
+        # Same response ID → the new response is a full replacement (e.g. OpenAI background retrieve).
+        if existing.provider_response_id and existing.provider_response_id == new.provider_response_id:
+            return new
+
+        # Different model → replace (accumulating parts from different models is always wrong).
+        # When either model_name is None/empty, we fall through to accumulation — this is intentional
+        # because providers may not always populate model_name on continuation responses.
+        if existing.model_name and new.model_name and existing.model_name != new.model_name:
+            return new
+
+        # Same model, different response → accumulate parts and sum usage.
+        # Preserve existing provider response IDs when continuation responses omit them
+        # (e.g. resumed OpenAI streams that start after a sequence number).
+        merged_usage = existing.usage + new.usage
+        return replace(
+            new,
+            parts=[*existing.parts, *new.parts],
+            usage=merged_usage,
+            provider_response_id=new.provider_response_id or existing.provider_response_id,
+        )
+
+    __repr__ = dataclasses_no_defaults_repr
+
+
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
     run_context = RunContext[DepsT](
```
```diff
@@ -1643,6 +1921,7 @@ def build_agent_graph(
         g.node(UserPromptNode[DepsT, OutputT]),
         g.node(ModelRequestNode[DepsT, OutputT]),
         g.node(CallToolsNode[DepsT, OutputT]),
+        g.node(ContinueRequestNode[DepsT, OutputT]),
         g.node(
             SetFinalResult[DepsT, OutputT],
         ),
```