Skip to content

Commit cae9333

Browse files
strawgateclaude
andauthored
fix: cap consecutive final_response validation retries (#3851)
* Cap consecutive final_response validation retries to 3 Previously, when the LLM repeatedly called final_response with data that failed validation, the retry loop would continue up to 100 times (the shared max_iterations limit), wasting tokens on a model that cannot satisfy the schema. Add _MAX_VALIDATION_RETRIES (default 3) that caps consecutive validation failures. The counter resets when the LLM calls other tools (not final_response), so the cap only applies to consecutive failures. Fixes #3848 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Add tests for consecutive validation retry cap Tests cover: - Validation failures within cap followed by success - Consecutive validation failures exceeding cap (raises RuntimeError) - Counter reset when LLM calls other tools between validation failures Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Slim down validation retry cap tests Reduce boilerplate with helper functions. Simplify counter-reset test from 5 calls to 4. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Fix static analysis: move imports to module level and format Move CreateMessageResultWithTools and ToolUseContent imports to the top of the test file so ty can resolve the names used in return-type annotations of the helper functions. Also fix ruff import sorting and formatting issues. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Align validation retry semantics with text-response retries Change `>=` to `>` so _MAX_VALIDATION_RETRIES means "number of retries after the initial attempt" (total = N+1), matching the convention used by _MAX_TEXT_RESPONSE_RETRIES in the text-response retry path. Before: _MAX=3 meant 3 total attempts (>= comparison) After: _MAX=3 means 1 initial + 3 retries = 4 total (> comparison) 🤖 Generated with Claude Code Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d5a3d54 commit cae9333

2 files changed

Lines changed: 148 additions & 1 deletion

File tree

src/fastmcp/server/sampling/run.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646

4747
ResultT = TypeVar("ResultT")
4848

49+
# Maximum number of consecutive final_response validation retries (not
50+
# counting the initial attempt) before aborting. Total attempts = N + 1.
51+
_MAX_VALIDATION_RETRIES = 3
52+
4953
# Simplified tool choice type - just the mode string instead of the full MCP object
5054
ToolChoiceOption = Literal["auto", "required", "none"]
5155

@@ -615,6 +619,7 @@ async def sample_impl(
615619
current_messages: str | Sequence[str | SamplingMessage] = messages
616620

617621
text_response_retries = 0
622+
consecutive_validation_failures = 0
618623

619624
for _iteration in range(max_iterations):
620625
step = await sample_step_impl(
@@ -631,9 +636,11 @@ async def sample_impl(
631636
)
632637

633638
# Check for final_response tool call for structured output
639+
had_final_response = False
634640
if result_type is not None and result_type is not str and step.is_tool_use:
635641
for tool_call in step.tool_calls:
636642
if tool_call.name == "final_response":
643+
had_final_response = True
637644
# Validate and return the structured result
638645
type_adapter = get_cached_typeadapter(result_type)
639646

@@ -660,6 +667,13 @@ async def sample_impl(
660667
history=step.history,
661668
)
662669
except ValidationError as e:
670+
consecutive_validation_failures += 1
671+
if consecutive_validation_failures > _MAX_VALIDATION_RETRIES:
672+
raise RuntimeError(
673+
f"Structured output validation failed "
674+
f"{consecutive_validation_failures} consecutive "
675+
f"times for type {result_type.__name__}: {e}"
676+
) from e
663677
# Validation failed - add error as tool result
664678
step.history.append(
665679
SamplingMessage(
@@ -683,6 +697,10 @@ async def sample_impl(
683697
)
684698
)
685699

700+
# The LLM called tools but not final_response — reset validation counter
701+
if not had_final_response:
702+
consecutive_validation_failures = 0
703+
686704
# If not a tool use response, we're done
687705
if not step.is_tool_use:
688706
# For structured output, the LLM must use the final_response tool

tests/client/test_sampling_result_types.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from mcp.types import TextContent
2+
from mcp.types import CreateMessageResultWithTools, TextContent, ToolUseContent
33

44
from fastmcp import Client, Context, FastMCP
55
from fastmcp.client.sampling import RequestContext, SamplingMessage, SamplingParams
@@ -550,3 +550,132 @@ async def t(context: Context) -> str:
550550

551551
assert call_count == 1
552552
assert result.data == "hello"
553+
554+
555+
def _final_response(call_id: str, input_data: dict) -> CreateMessageResultWithTools:
556+
"""Build a final_response tool-use reply."""
557+
return CreateMessageResultWithTools(
558+
role="assistant",
559+
content=[
560+
ToolUseContent(
561+
type="tool_use", id=call_id, name="final_response", input=input_data
562+
)
563+
],
564+
model="test-model",
565+
stopReason="toolUse",
566+
)
567+
568+
569+
def _tool_call(
570+
call_id: str, name: str, input_data: dict
571+
) -> CreateMessageResultWithTools:
572+
"""Build a regular tool-use reply."""
573+
return CreateMessageResultWithTools(
574+
role="assistant",
575+
content=[
576+
ToolUseContent(type="tool_use", id=call_id, name=name, input=input_data)
577+
],
578+
model="test-model",
579+
stopReason="toolUse",
580+
)
581+
582+
583+
class TestValidationRetryCap:
584+
"""Tests for the consecutive validation retry cap (PR #3851)."""
585+
586+
async def test_validation_failures_within_cap_then_success(self):
587+
"""Two consecutive failures followed by a valid response succeeds."""
588+
from pydantic import BaseModel
589+
590+
class R(BaseModel):
591+
value: int
592+
593+
call_count = 0
594+
595+
def handler(messages, params, ctx):
596+
nonlocal call_count
597+
call_count += 1
598+
if call_count <= 2:
599+
return _final_response(f"c{call_count}", {"value": "bad"})
600+
return _final_response(f"c{call_count}", {"value": 99})
601+
602+
mcp = FastMCP(sampling_handler=handler)
603+
604+
@mcp.tool
605+
async def t(context: Context) -> str:
606+
r = await context.sample(messages="go", result_type=R)
607+
return str(r.result.value)
608+
609+
async with Client(mcp) as client:
610+
result = await client.call_tool("t", {})
611+
612+
assert call_count == 3
613+
assert result.data == "99"
614+
615+
async def test_consecutive_validation_failures_exceed_cap(self):
616+
"""Always-invalid responses raise ToolError after exceeding the cap."""
617+
from pydantic import BaseModel
618+
619+
from fastmcp.exceptions import ToolError
620+
from fastmcp.server.sampling.run import _MAX_VALIDATION_RETRIES
621+
622+
class R(BaseModel):
623+
value: int
624+
625+
call_count = 0
626+
627+
def handler(messages, params, ctx):
628+
nonlocal call_count
629+
call_count += 1
630+
return _final_response(f"c{call_count}", {"value": "wrong"})
631+
632+
mcp = FastMCP(sampling_handler=handler)
633+
634+
@mcp.tool
635+
async def t(context: Context) -> str:
636+
return str((await context.sample(messages="go", result_type=R)).result)
637+
638+
async with Client(mcp) as client:
639+
with pytest.raises(ToolError, match="consecutive"):
640+
await client.call_tool("t", {})
641+
642+
# 1 initial attempt + _MAX_VALIDATION_RETRIES retries
643+
assert call_count == _MAX_VALIDATION_RETRIES + 1
644+
645+
async def test_validation_counter_resets_after_other_tool_call(self):
646+
"""A tool call between validation failures resets the counter."""
647+
from pydantic import BaseModel
648+
649+
class R(BaseModel):
650+
value: int
651+
652+
def helper_tool(x: int) -> str:
653+
"""A helper tool."""
654+
return f"result:{x}"
655+
656+
call_count = 0
657+
658+
def handler(messages, params, ctx):
659+
nonlocal call_count
660+
call_count += 1
661+
# fail -> other tool (resets counter) -> fail -> succeed
662+
if call_count == 1:
663+
return _final_response("c1", {"value": "bad"})
664+
if call_count == 2:
665+
return _tool_call("c2", "helper_tool", {"x": 1})
666+
if call_count == 3:
667+
return _final_response("c3", {"value": "bad"})
668+
return _final_response("c4", {"value": 42})
669+
670+
mcp = FastMCP(sampling_handler=handler)
671+
672+
@mcp.tool
673+
async def t(context: Context) -> str:
674+
r = await context.sample(messages="go", tools=[helper_tool], result_type=R)
675+
return str(r.result.value)
676+
677+
async with Client(mcp) as client:
678+
result = await client.call_tool("t", {})
679+
680+
assert call_count == 4
681+
assert result.data == "42"

0 commit comments

Comments
 (0)